script gen post action (#3480)

This commit is contained in:
Shuchang Zheng
2025-09-19 08:50:21 -07:00
committed by GitHub
parent b4669f7477
commit c5280782b0
17 changed files with 536 additions and 264 deletions

View File

@@ -15,6 +15,7 @@ class SkyvernContext:
workflow_id: str | None = None
workflow_permanent_id: str | None = None
workflow_run_id: str | None = None
root_workflow_run_id: str | None = None
task_v2_id: str | None = None
max_steps_override: int | None = None
browser_session_id: str | None = None

View File

@@ -107,7 +107,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunStatus,
)
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunType
from skyvern.schemas.scripts import Script, ScriptBlock, ScriptFile, ScriptStatus
from skyvern.schemas.scripts import Script, ScriptBlock, ScriptFile, ScriptStatus, WorkflowScript
from skyvern.schemas.steps import AgentStepOutput
from skyvern.schemas.workflows import BlockStatus, BlockType, WorkflowStatus
from skyvern.webeye.actions.actions import Action
@@ -4039,6 +4039,44 @@ class AgentDB:
return convert_to_script_file(script_file) if script_file else None
async def get_script_file_by_path(
    self,
    script_revision_id: str,
    file_path: str,
    organization_id: str,
) -> ScriptFile | None:
    """Look up a single script file by its path within a script revision.

    Returns the converted ScriptFile, or None when no matching row exists
    for this organization.
    """
    query = select(ScriptFileModel).filter_by(
        script_revision_id=script_revision_id,
        file_path=file_path,
        organization_id=organization_id,
    )
    async with self.Session() as session:
        row = (await session.scalars(query)).first()
    if not row:
        return None
    return convert_to_script_file(row)
async def update_script_file(
    self,
    script_file_id: str,
    organization_id: str,
    artifact_id: str | None = None,
) -> ScriptFile:
    """Update mutable fields of a script file row.

    Only artifact_id is currently updatable; a falsy value leaves the row's
    artifact_id untouched (the row is still committed and refreshed).
    Raises NotFoundError when no row matches the id/organization pair.
    """
    async with self.Session() as session:
        record = (
            await session.scalars(
                select(ScriptFileModel).filter_by(file_id=script_file_id, organization_id=organization_id)
            )
        ).first()
        # Guard clause: fail fast when the file does not exist for this org.
        if record is None:
            raise NotFoundError("Script file not found")
        if artifact_id:
            record.artifact_id = artifact_id
        await session.commit()
        await session.refresh(record)
        return convert_to_script_file(record)
async def get_script_block(
self,
script_block_id: str,
@@ -4054,6 +4092,23 @@ class AgentDB:
).first()
return convert_to_script_block(record) if record else None
async def get_script_block_by_label(
    self,
    organization_id: str,
    script_revision_id: str,
    script_block_label: str,
) -> ScriptBlock | None:
    """Fetch a script block by its label within a given script revision.

    Returns None when no matching block exists for this organization.
    """
    query = select(ScriptBlockModel).filter_by(
        script_revision_id=script_revision_id,
        script_block_label=script_block_label,
        organization_id=organization_id,
    )
    async with self.Session() as session:
        row = (await session.scalars(query)).first()
    return convert_to_script_block(row) if row else None
async def get_script_blocks_by_script_revision_id(
self,
script_revision_id: str,
@@ -4080,6 +4135,7 @@ class AgentDB:
cache_key_value: str,
workflow_id: str | None = None,
workflow_run_id: str | None = None,
status: ScriptStatus = ScriptStatus.published,
) -> None:
"""Create a workflow->script cache mapping entry."""
try:
@@ -4092,6 +4148,7 @@ class AgentDB:
workflow_run_id=workflow_run_id,
cache_key=cache_key,
cache_key_value=cache_key_value,
status=status,
)
session.add(record)
await session.commit()
@@ -4102,12 +4159,32 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_workflow_script(
    self,
    organization_id: str,
    workflow_permanent_id: str,
    workflow_run_id: str,
    statuses: list[ScriptStatus] | None = None,
) -> WorkflowScript | None:
    """Return the workflow->script mapping row for one workflow run, if any.

    When *statuses* is a non-empty list, only rows whose status is in the
    list are considered. Returns None when nothing matches.
    """
    query = select(WorkflowScriptModel).filter_by(
        organization_id=organization_id,
        workflow_permanent_id=workflow_permanent_id,
        workflow_run_id=workflow_run_id,
    )
    if statuses:
        query = query.filter(WorkflowScriptModel.status.in_(statuses))
    async with self.Session() as session:
        record = (await session.scalars(query)).first()
    if record is None:
        return None
    return WorkflowScript.model_validate(record)
async def get_workflow_scripts_by_cache_key_value(
self,
*,
organization_id: str,
workflow_permanent_id: str,
cache_key_value: str,
workflow_run_id: str | None = None,
cache_key: str | None = None,
statuses: list[ScriptStatus] | None = None,
) -> list[Script]:
@@ -4122,6 +4199,10 @@ class AgentDB:
.where(WorkflowScriptModel.cache_key_value == cache_key_value)
.where(WorkflowScriptModel.deleted_at.is_(None))
)
if workflow_run_id:
ws_script_ids_subquery = ws_script_ids_subquery.where(
WorkflowScriptModel.workflow_run_id == workflow_run_id
)
if cache_key is not None:
ws_script_ids_subquery = ws_script_ids_subquery.where(WorkflowScriptModel.cache_key == cache_key)
@@ -4174,6 +4255,7 @@ class AgentDB:
.filter_by(workflow_permanent_id=workflow_permanent_id)
.filter_by(cache_key=cache_key)
.filter_by(deleted_at=None)
.filter_by(status="published")
)
if filter:
@@ -4205,6 +4287,7 @@ class AgentDB:
.filter_by(workflow_permanent_id=workflow_permanent_id)
.filter_by(cache_key=cache_key)
.filter_by(deleted_at=None)
.filter_by(status="published")
.offset((page - 1) * page_size)
.limit(page_size)
)
@@ -4220,45 +4303,6 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True)
raise
async def create_workflow_cache_key_value(
    self,
    organization_id: str,
    workflow_permanent_id: str,
    cache_key: str,
    cache_key_value: str,
    script_id: str,
    workflow_id: str | None = None,
    workflow_run_id: str | None = None,
) -> str:
    """
    Insert a new cache key value for a workflow.
    Returns the workflow_script_id of the created record.
    """
    try:
        async with self.Session() as session:
            # Build and persist the mapping row, then refresh to pick up
            # DB-generated fields (e.g. the primary key we return).
            record = WorkflowScriptModel(
                script_id=script_id,
                organization_id=organization_id,
                workflow_permanent_id=workflow_permanent_id,
                workflow_id=workflow_id,
                workflow_run_id=workflow_run_id,
                cache_key=cache_key,
                cache_key_value=cache_key_value,
            )
            session.add(record)
            await session.commit()
            await session.refresh(record)
            return record.workflow_script_id
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def delete_workflow_cache_key_value(
self,
organization_id: str,

View File

@@ -300,6 +300,7 @@ async def get_workflow_script_blocks(
scripts = await app.DATABASE.get_workflow_scripts_by_cache_key_value(
organization_id=current_org.organization_id,
workflow_permanent_id=workflow_permanent_id,
workflow_run_id=block_script_request.workflow_run_id,
cache_key_value=cache_key_value,
cache_key=cache_key,
statuses=[status] if status else None,

View File

@@ -3060,6 +3060,9 @@ class TaskV2Block(Block):
finally:
context: skyvern_context.SkyvernContext | None = skyvern_context.current()
current_run_id = context.run_id if context and context.run_id else workflow_run_id
root_workflow_run_id = (
context.root_workflow_run_id if context and context.root_workflow_run_id else workflow_run_id
)
skyvern_context.set(
skyvern_context.SkyvernContext(
organization_id=organization_id,
@@ -3067,6 +3070,7 @@ class TaskV2Block(Block):
workflow_id=workflow_run.workflow_id,
workflow_permanent_id=workflow_run.workflow_permanent_id,
workflow_run_id=workflow_run_id,
root_workflow_run_id=root_workflow_run_id,
run_id=current_run_id,
browser_session_id=browser_session_id,
max_screenshot_scrolls=workflow_run.max_screenshot_scrolls,

View File

@@ -1,5 +1,4 @@
import asyncio
import base64
import json
import uuid
from datetime import UTC, datetime
@@ -7,14 +6,11 @@ from typing import Any
import httpx
import structlog
from jinja2.sandbox import SandboxedEnvironment
from skyvern import analytics
from skyvern.client.types.output_parameter import OutputParameter as BlockOutputParameter
from skyvern.config import settings
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT, SAVE_DOWNLOADED_FILES_TIMEOUT
from skyvern.core.script_generations.generate_script import generate_workflow_script as generate_python_workflow_script
from skyvern.core.script_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input
from skyvern.exceptions import (
BlockNotFound,
BrowserSessionNotFound,
@@ -99,7 +95,6 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunStatus,
)
from skyvern.schemas.runs import ProxyLocation, RunStatus, RunType, WorkflowRunRequest, WorkflowRunResponse
from skyvern.schemas.scripts import FileEncoding, Script, ScriptFileCreate
from skyvern.schemas.workflows import (
BLOCK_YAML_TYPES,
BlockStatus,
@@ -109,7 +104,7 @@ from skyvern.schemas.workflows import (
WorkflowDefinitionYAML,
WorkflowStatus,
)
from skyvern.services import script_service
from skyvern.services import script_service, workflow_script_service
from skyvern.webeye.browser_factory import BrowserState
LOG = structlog.get_logger()
@@ -205,6 +200,7 @@ class WorkflowService:
request_id=request_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
root_workflow_run_id=workflow_run.workflow_run_id,
run_id=current_run_id,
workflow_permanent_id=workflow_run.workflow_permanent_id,
max_steps_override=max_steps_override,
@@ -353,7 +349,7 @@ class WorkflowService:
return workflow_run
# Check if there's a related workflow script that should be used instead
workflow_script, _ = await self._get_workflow_script(workflow, workflow_run, block_labels)
workflow_script, _ = await workflow_script_service.get_workflow_script(workflow, workflow_run, block_labels)
is_script = workflow_script is not None
if workflow_script is not None:
LOG.info(
@@ -365,9 +361,7 @@ class WorkflowService:
)
workflow_run = await self._execute_workflow_script(
script_id=workflow_script.script_id,
workflow=workflow,
workflow_run=workflow_run,
api_key=api_key,
organization=organization,
browser_session_id=browser_session_id,
)
@@ -375,9 +369,7 @@ class WorkflowService:
workflow_run = await self._execute_workflow_blocks(
workflow=workflow,
workflow_run=workflow_run,
api_key=api_key,
organization=organization,
close_browser_on_completion=close_browser_on_completion,
browser_session_id=browser_session_id,
block_labels=block_labels,
block_outputs=block_outputs,
@@ -422,9 +414,7 @@ class WorkflowService:
self,
workflow: Workflow,
workflow_run: WorkflowRun,
api_key: str,
organization: Organization,
close_browser_on_completion: bool,
browser_session_id: str | None = None,
block_labels: list[str] | None = None,
block_outputs: dict[str, Any] | None = None,
@@ -2457,66 +2447,10 @@ class WorkflowService:
return result
async def _get_workflow_script(
    self, workflow: Workflow, workflow_run: WorkflowRun, block_labels: list[str] | None = None
) -> tuple[Script | None, str]:
    """
    Check if there's a related workflow script that should be used instead of running the workflow.
    Returns the tuple of (script, rendered_cache_key_value).
    """
    cache_key = workflow.cache_key or ""
    rendered_cache_key_value = ""
    # Script caching only applies when the workflow opted in.
    if not workflow.generate_script:
        return None, rendered_cache_key_value
    if block_labels:
        # Do not generate script or run script if block_labels is provided
        return None, rendered_cache_key_value
    try:
        # Render the cache key template against this run's parameter values,
        # so runs with different inputs map to different cached scripts.
        parameter_tuples = await app.DATABASE.get_workflow_run_parameters(
            workflow_run_id=workflow_run.workflow_run_id,
        )
        parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}
        # Sandboxed environment: the template comes from user-authored
        # workflow config, so it must not reach arbitrary attributes.
        jinja_sandbox_env = SandboxedEnvironment()
        rendered_cache_key_value = jinja_sandbox_env.from_string(cache_key).render(parameters)
        # Check if there are existing cached scripts for this workflow + cache_key_value
        existing_scripts = await app.DATABASE.get_workflow_scripts_by_cache_key_value(
            organization_id=workflow.organization_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            cache_key_value=rendered_cache_key_value,
        )
        if existing_scripts:
            LOG.info(
                "Found cached script for workflow",
                workflow_id=workflow.workflow_id,
                cache_key_value=rendered_cache_key_value,
                workflow_run_id=workflow_run.workflow_run_id,
                script_count=len(existing_scripts),
            )
            # Multiple scripts may match; the first is used — presumably the
            # most relevant per the query's ordering. TODO confirm ordering.
            return existing_scripts[0], rendered_cache_key_value
        return None, rendered_cache_key_value
    except Exception as e:
        # Best-effort lookup: any failure (template error, DB error) falls
        # back to normal workflow execution rather than failing the run.
        LOG.warning(
            "Failed to check for workflow script, proceeding with normal workflow execution",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            error=str(e),
            exc_info=True,
        )
        return None, rendered_cache_key_value
async def _execute_workflow_script(
self,
script_id: str,
workflow: Workflow,
workflow_run: WorkflowRun,
api_key: str,
organization: Organization,
browser_session_id: str | None = None,
) -> WorkflowRun:
@@ -2584,7 +2518,7 @@ class WorkflowService:
# Do not generate script if block_labels is provided
return None
existing_script, rendered_cache_key_value = await self._get_workflow_script(
existing_script, rendered_cache_key_value = await workflow_script_service.get_workflow_script(
workflow,
workflow_run,
block_labels,
@@ -2605,62 +2539,9 @@ class WorkflowService:
run_id=workflow_run.workflow_run_id,
)
# 3) Generate script code from workflow run
try:
LOG.info(
"Generating script for workflow",
workflow_run_id=workflow_run.workflow_run_id,
workflow_id=workflow.workflow_id,
workflow_name=workflow.title,
cache_key_value=rendered_cache_key_value,
)
codegen_input = await transform_workflow_run_to_code_gen_input(
workflow_run_id=workflow_run.workflow_run_id,
organization_id=workflow.organization_id,
)
python_src = await generate_python_workflow_script(
file_name=codegen_input.file_name,
workflow_run_request=codegen_input.workflow_run,
workflow=codegen_input.workflow,
blocks=codegen_input.workflow_blocks,
actions_by_task=codegen_input.actions_by_task,
task_v2_child_blocks=codegen_input.task_v2_child_blocks,
organization_id=workflow.organization_id,
script_id=created_script.script_id,
script_revision_id=created_script.script_revision_id,
)
except Exception:
LOG.error("Failed to generate workflow script source", exc_info=True)
return
# 4) Persist script and files, then record mapping
content_bytes = python_src.encode("utf-8")
content_b64 = base64.b64encode(content_bytes).decode("utf-8")
files = [
ScriptFileCreate(
path="main.py",
content=content_b64,
encoding=FileEncoding.BASE64,
mime_type="text/x-python",
)
]
# Upload script file(s) as artifacts and create rows
await script_service.build_file_tree(
files=files,
organization_id=workflow.organization_id,
script_id=created_script.script_id,
script_version=created_script.version,
script_revision_id=created_script.script_revision_id,
)
# Record the workflow->script mapping for cache lookup
await app.DATABASE.create_workflow_script(
organization_id=workflow.organization_id,
script_id=created_script.script_id,
workflow_permanent_id=workflow.workflow_permanent_id,
cache_key=workflow.cache_key or "",
cache_key_value=rendered_cache_key_value,
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
await workflow_script_service.generate_workflow_script(
workflow_run=workflow_run,
workflow=workflow,
script=created_script,
rendered_cache_key_value=rendered_cache_key_value,
)