Refactor script gen with block level code cache (#3910)

This commit is contained in:
Shuchang Zheng
2025-11-05 19:57:11 +08:00
committed by GitHub
parent 524513dd93
commit 2fa4d933cc
3 changed files with 250 additions and 83 deletions

View File

@@ -10,6 +10,7 @@ import asyncio
import hashlib
import keyword
import re
from dataclasses import dataclass
from typing import Any
import libcst as cst
@@ -31,6 +32,15 @@ GENERATE_CODE_AI_MODE_PROACTIVE = "proactive"
GENERATE_CODE_AI_MODE_FALLBACK = "fallback"
@dataclass
class ScriptBlockSource:
label: str
code: str
run_signature: str | None
workflow_run_id: str | None
workflow_run_block_id: str | None
# --------------------------------------------------------------------- #
# 1. helpers #
# --------------------------------------------------------------------- #
@@ -104,6 +114,7 @@ ACTIONS_WITH_XPATH = [
"upload_file",
"select_option",
]
ACTIONS_OPT_OUT_INTENTION_FOR_PROMPT = ["extract"]
INDENT = " " * 4
DOUBLE_INDENT = " " * 8
@@ -421,7 +432,7 @@ def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output:
)
)
intention = act.get("intention") or act.get("reasoning") or ""
if intention: if intention and method not in ACTIONS_OPT_OUT_INTENTION_FOR_PROMPT:
args.extend(
[
cst.Arg(
@@ -432,6 +443,7 @@ def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output:
),
]
)
_mark_last_arg_as_comma(args)
# Only use indented parentheses if we have arguments
if args:
@@ -1694,10 +1706,25 @@ async def generate_workflow_script_python_code(
script_id: str | None = None,
script_revision_id: str | None = None,
pending: bool = False,
cached_blocks: dict[str, ScriptBlockSource] | None = None,
updated_block_labels: set[str] | None = None,
) -> str:
"""
Build a LibCST Module and emit .code (PEP-8-formatted source).
Cached script blocks can be reused by providing them via `cached_blocks`. Any labels present in
`updated_block_labels` will be regenerated from the latest workflow run execution data.
"""
cached_blocks = cached_blocks or {}
updated_block_labels = set(updated_block_labels or [])
# Drop cached entries that do not have usable source
cached_blocks = {label: source for label, source in cached_blocks.items() if source.code}
# Always regenerate the orchestrator block so it stays aligned with the workflow definition
cached_blocks.pop(settings.WORKFLOW_START_BLOCK_LABEL, None)
if task_v2_child_blocks is None:
task_v2_child_blocks = {}
# --- imports --------------------------------------------------------
imports: list[cst.BaseStatement] = [
cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("asyncio"))])]),
@@ -1746,33 +1773,47 @@ async def generate_workflow_script_python_code(
generated_model_cls = _build_generated_model_from_schema(generated_schema)
# --- blocks ---------------------------------------------------------
block_fns = [] block_fns: list[cst.CSTNode] = []
task_v1_blocks = [block for block in blocks if block["block_type"] in SCRIPT_TASK_BLOCKS]
task_v2_blocks = [block for block in blocks if block["block_type"] == "task_v2"]
if task_v2_child_blocks is None: def append_block_code(block_code: str) -> None:
task_v2_child_blocks = {} nonlocal block_fns
parsed = cst.parse_module(block_code)
if block_fns:
block_fns.append(cst.EmptyLine())
block_fns.append(cst.EmptyLine())
block_fns.extend(parsed.body)
# Handle task v1 blocks (excluding child blocks of task_v2)
for idx, task in enumerate(task_v1_blocks):
# Skip if this is a child block of a task_v2 block
if task.get("parent_task_v2_label"):
continue
block_fn_def = _build_block_fn(task, actions_by_task.get(task.get("task_id", ""), [])) block_name = task.get("label") or task.get("title") or task.get("task_id") or f"task_{idx}"
cached_source = cached_blocks.get(block_name)
use_cached = cached_source is not None and block_name not in updated_block_labels
if use_cached:
assert cached_source is not None
block_code = cached_source.code
run_signature = cached_source.run_signature
block_workflow_run_id = cached_source.workflow_run_id
block_workflow_run_block_id = cached_source.workflow_run_block_id
else:
block_fn_def = _build_block_fn(task, actions_by_task.get(task.get("task_id", ""), []))
temp_module = cst.Module(body=[block_fn_def])
block_code = temp_module.code
block_stmt = _build_block_statement(task)
run_signature_module = cst.Module(body=[block_stmt])
run_signature = run_signature_module.code.strip()
block_workflow_run_id = task.get("workflow_run_id") or run_id
block_workflow_run_block_id = task.get("workflow_run_block_id")
# Create script block if we have script context
if script_id and script_revision_id and organization_id:
try:
block_name = task.get("label") or task.get("title") or task.get("task_id") or f"task_{idx}"
temp_module = cst.Module(body=[block_fn_def])
block_code = temp_module.code
# Extract the run signature (the statement that calls skyvern.action/extract/etc)
block_stmt = _build_block_statement(task)
run_signature_module = cst.Module(body=[block_stmt])
run_signature = run_signature_module.code.strip()
await create_or_update_script_block(
block_code=block_code,
script_revision_id=script_revision_id,
@@ -1781,84 +1822,67 @@ async def generate_workflow_script_python_code(
block_label=block_name,
update=pending,
run_signature=run_signature,
workflow_run_id=task.get("workflow_run_id"), workflow_run_id=block_workflow_run_id,
workflow_run_block_id=task.get("workflow_run_block_id"), workflow_run_block_id=block_workflow_run_block_id,
)
except Exception as e:
LOG.error("Failed to create script block", error=str(e), exc_info=True)
# Continue without script block creation if it fails
block_fns.append(block_fn_def) append_block_code(block_code)
if idx < len(task_v1_blocks) - 1:
block_fns.append(cst.EmptyLine())
block_fns.append(cst.EmptyLine())
# Handle task_v2 blocks
for idx, task_v2 in enumerate(task_v2_blocks): for task_v2 in task_v2_blocks:
task_v2_label = task_v2.get("label") or f"task_v2_{task_v2.get('workflow_run_block_id')}"
child_blocks = task_v2_child_blocks.get(task_v2_label, [])
# Create the task_v2 function cached_source = cached_blocks.get(task_v2_label)
task_v2_fn_def = _build_task_v2_block_fn(task_v2, child_blocks) use_cached = cached_source is not None and task_v2_label not in updated_block_labels
block_code = ""
run_signature = None
block_workflow_run_id = task_v2.get("workflow_run_id") or run_id
block_workflow_run_block_id = task_v2.get("workflow_run_block_id")
if use_cached:
assert cached_source is not None
block_code = cached_source.code
run_signature = cached_source.run_signature
block_workflow_run_id = cached_source.workflow_run_id
block_workflow_run_block_id = cached_source.workflow_run_block_id
else:
task_v2_fn_def = _build_task_v2_block_fn(task_v2, child_blocks)
task_v2_block_body: list[cst.CSTNode] = [task_v2_fn_def]
for child_block in child_blocks:
if child_block.get("block_type") in SCRIPT_TASK_BLOCKS and child_block.get("block_type") != "task_v2":
child_fn_def = _build_block_fn(child_block, actions_by_task.get(child_block.get("task_id", ""), []))
task_v2_block_body.append(cst.EmptyLine())
task_v2_block_body.append(cst.EmptyLine())
task_v2_block_body.append(child_fn_def)
temp_module = cst.Module(body=task_v2_block_body)
block_code = temp_module.code
task_v2_stmt = _build_block_statement(task_v2)
run_signature = cst.Module(body=[task_v2_stmt]).code.strip()
# Create script block for task_v2 that includes both the main function and child functions
if script_id and script_revision_id and organization_id:
try:
# Build the complete module for this task_v2 block
task_v2_block_body = [task_v2_fn_def]
# Add child block functions
for child_block in child_blocks:
if (
child_block.get("block_type") in SCRIPT_TASK_BLOCKS
and child_block.get("block_type") != "task_v2"
):
child_fn_def = _build_block_fn(
child_block, actions_by_task.get(child_block.get("task_id", ""), [])
)
task_v2_block_body.append(cst.EmptyLine())
task_v2_block_body.append(cst.EmptyLine())
task_v2_block_body.append(child_fn_def)
# Create the complete module for this task_v2 block
temp_module = cst.Module(body=task_v2_block_body)
task_v2_block_code = temp_module.code
block_name = task_v2.get("label") or task_v2.get("title") or f"task_v2_{idx}"
# Extract the run signature for task_v2 block
task_v2_stmt = _build_block_statement(task_v2)
run_signature_module = cst.Module(body=[task_v2_stmt])
run_signature = run_signature_module.code.strip()
await create_or_update_script_block(
block_code=task_v2_block_code, block_code=block_code,
script_revision_id=script_revision_id,
script_id=script_id,
organization_id=organization_id,
block_label=block_name, block_label=task_v2_label,
update=pending,
run_signature=run_signature,
workflow_run_id=task_v2.get("workflow_run_id"), workflow_run_id=block_workflow_run_id,
workflow_run_block_id=task_v2.get("workflow_run_block_id"), workflow_run_block_id=block_workflow_run_block_id,
)
except Exception as e:
LOG.error("Failed to create task_v2 script block", error=str(e), exc_info=True)
# Continue without script block creation if it fails
block_fns.append(task_v2_fn_def) append_block_code(block_code)
# Create individual functions for child blocks
for child_block in child_blocks:
if child_block.get("block_type") in SCRIPT_TASK_BLOCKS and child_block.get("block_type") != "task_v2":
child_fn_def = _build_block_fn(child_block, actions_by_task.get(child_block.get("task_id", ""), []))
block_fns.append(cst.EmptyLine())
block_fns.append(cst.EmptyLine())
block_fns.append(child_fn_def)
if idx < len(task_v2_blocks) - 1:
block_fns.append(cst.EmptyLine())
block_fns.append(cst.EmptyLine())
# --- runner ---------------------------------------------------------
run_fn = _build_run_fn(blocks, workflow_run_request)

View File

@@ -635,7 +635,7 @@ class WorkflowService:
# Unified execution: execute blocks one by one, using script code when available
if is_script_run is False:
workflow_script = None
workflow_run = await self._execute_workflow_blocks( workflow_run, blocks_to_update = await self._execute_workflow_blocks(
workflow=workflow,
workflow_run=workflow_run,
organization=organization,
@@ -663,6 +663,7 @@ class WorkflowService:
workflow=workflow,
workflow_run=workflow_run,
block_labels=block_labels,
blocks_to_update=blocks_to_update,
)
else:
LOG.info(
@@ -690,7 +691,7 @@ class WorkflowService:
block_labels: list[str] | None = None,
block_outputs: dict[str, Any] | None = None,
workflow_script: WorkflowScript | None = None,
) -> WorkflowRun: ) -> tuple[WorkflowRun, set[str]]:
organization_id = organization.organization_id
workflow_run_id = workflow_run.workflow_run_id
top_level_blocks = workflow.workflow_definition.blocks
@@ -699,6 +700,7 @@ class WorkflowService:
# Load script blocks if workflow_script is provided
script_blocks_by_label: dict[str, Any] = {}
loaded_script_module = None
blocks_to_update: set[str] = set()
if workflow_script:
LOG.info(
@@ -946,6 +948,13 @@ class WorkflowService:
workflow_run_id=workflow_run_id, failure_reason="Block result is None"
)
break
if (
not block_executed_with_code
and block.label
and block_result.status == BlockStatus.completed
and not getattr(block, "disable_cache", False)
):
blocks_to_update.add(block.label)
if block_result.status == BlockStatus.canceled:
LOG.info(
f"Block with type {block.block_type} at index {block_idx}/{blocks_cnt - 1} was canceled for workflow run {workflow_run_id}, cancelling workflow run", f"Block with type {block.block_type} at index {block_idx}/{blocks_cnt - 1} was canceled for workflow run {workflow_run_id}, cancelling workflow run",
@@ -1064,7 +1073,7 @@ class WorkflowService:
workflow_run_id=workflow_run_id, failure_reason=failure_reason
)
break
return workflow_run return workflow_run, blocks_to_update
async def create_workflow(
self,
@@ -3289,13 +3298,17 @@ class WorkflowService:
workflow: Workflow,
workflow_run: WorkflowRun,
block_labels: list[str] | None = None,
blocks_to_update: set[str] | None = None,
) -> None:
code_gen = workflow_run.code_gen
blocks_to_update = set(blocks_to_update or [])
LOG.info(
"Generate script?",
block_labels=block_labels,
code_gen=code_gen,
workflow_run_id=workflow_run.workflow_run_id,
blocks_to_update=list(blocks_to_update),
)
if block_labels and not code_gen:
@@ -3308,16 +3321,74 @@ class WorkflowService:
workflow_run,
block_labels,
)
if existing_script:
LOG.info( cached_block_labels: set[str] = set()
"Found cached script for workflow. Skipping script generation", script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
cache_key_value=rendered_cache_key_value,
script_id=existing_script.script_id,
script_revision_id=existing_script.script_revision_id, script_revision_id=existing_script.script_revision_id,
run_with=workflow_run.run_with, organization_id=workflow.organization_id,
) )
for script_block in script_blocks:
if script_block.script_block_label:
cached_block_labels.add(script_block.script_block_label)
definition_labels = {block.label for block in workflow.workflow_definition.blocks if block.label}
definition_labels.add(settings.WORKFLOW_START_BLOCK_LABEL)
cached_block_labels.add(settings.WORKFLOW_START_BLOCK_LABEL)
if cached_block_labels != definition_labels:
missing_labels = definition_labels - cached_block_labels
if missing_labels:
blocks_to_update.update(missing_labels)
# Always rebuild the orchestrator if the definition changed
blocks_to_update.add(settings.WORKFLOW_START_BLOCK_LABEL)
should_regenerate = bool(blocks_to_update) or bool(code_gen)
if not should_regenerate:
LOG.info(
"Workflow script already up to date; skipping regeneration",
workflow_id=workflow.workflow_id,
workflow_run_id=workflow_run.workflow_run_id,
cache_key_value=rendered_cache_key_value,
script_id=existing_script.script_id,
script_revision_id=existing_script.script_revision_id,
run_with=workflow_run.run_with,
)
return
# delete the existing workflow scripts if any
await app.DATABASE.delete_workflow_scripts_by_permanent_id(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
script_ids=[existing_script.script_id],
)
# create a new script
regenerated_script = await app.DATABASE.create_script(
organization_id=workflow.organization_id,
run_id=workflow_run.workflow_run_id,
)
await workflow_script_service.generate_workflow_script(
workflow_run=workflow_run,
workflow=workflow,
script=regenerated_script,
rendered_cache_key_value=rendered_cache_key_value,
cached_script=existing_script,
updated_block_labels=blocks_to_update,
)
aio_task_primary_key = f"{regenerated_script.script_id}_{regenerated_script.version}"
if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map:
aio_tasks = app.ARTIFACT_MANAGER.upload_aiotasks_map[aio_task_primary_key]
if aio_tasks:
await asyncio.gather(*aio_tasks)
else:
LOG.warning(
"No upload aio tasks found for regenerated script",
script_id=regenerated_script.script_id,
version=regenerated_script.version,
)
return
created_script = await app.DATABASE.create_script(
@@ -3330,6 +3401,8 @@ class WorkflowService:
workflow=workflow,
script=created_script,
rendered_cache_key_value=rendered_cache_key_value,
cached_script=None,
updated_block_labels=None,
)
aio_task_primary_key = f"{created_script.script_id}_{created_script.version}"
if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map:

View File

@@ -3,7 +3,8 @@ import base64
import structlog
from jinja2.sandbox import SandboxedEnvironment
from skyvern.core.script_generations.generate_script import generate_workflow_script_python_code from skyvern.config import settings
from skyvern.core.script_generations.generate_script import ScriptBlockSource, generate_workflow_script_python_code
from skyvern.core.script_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input
from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context
@@ -45,6 +46,7 @@ async def generate_or_update_pending_workflow_script(
script=script,
rendered_cache_key_value=rendered_cache_key_value,
pending=True,
cached_script=script,
)
@@ -104,12 +106,62 @@ async def get_workflow_script(
return None, rendered_cache_key_value
async def _load_cached_script_block_sources(
script: Script,
organization_id: str,
) -> dict[str, ScriptBlockSource]:
"""
Load existing script block sources (code + metadata) for a script revision so they can be reused.
"""
cached_blocks: dict[str, ScriptBlockSource] = {}
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_revision_id=script.script_revision_id,
organization_id=organization_id,
)
for script_block in script_blocks:
if not script_block.script_block_label:
continue
code_str: str | None = None
if script_block.script_file_id:
script_file = await app.DATABASE.get_script_file_by_id(
script_revision_id=script.script_revision_id,
file_id=script_block.script_file_id,
organization_id=organization_id,
)
if script_file and script_file.artifact_id:
artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id)
if artifact:
file_content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
if isinstance(file_content, bytes):
code_str = file_content.decode("utf-8")
elif isinstance(file_content, str):
code_str = file_content
if not code_str:
continue
cached_blocks[script_block.script_block_label] = ScriptBlockSource(
label=script_block.script_block_label,
code=code_str,
run_signature=script_block.run_signature,
workflow_run_id=script_block.workflow_run_id,
workflow_run_block_id=script_block.workflow_run_block_id,
)
return cached_blocks
async def generate_workflow_script(
workflow_run: WorkflowRun,
workflow: Workflow,
script: Script,
rendered_cache_key_value: str,
pending: bool = False,
cached_script: Script | None = None,
updated_block_labels: set[str] | None = None,
) -> None:
try:
LOG.info(
@@ -119,10 +171,26 @@ async def generate_workflow_script(
workflow_name=workflow.title,
cache_key_value=rendered_cache_key_value,
)
cached_block_sources: dict[str, ScriptBlockSource] = {}
if cached_script:
cached_block_sources = await _load_cached_script_block_sources(cached_script, workflow.organization_id)
codegen_input = await transform_workflow_run_to_code_gen_input(
workflow_run_id=workflow_run.workflow_run_id,
organization_id=workflow.organization_id,
)
block_labels = [block.get("label") for block in codegen_input.workflow_blocks if block.get("label")]
if updated_block_labels is None:
updated_block_labels = {label for label in block_labels if label}
else:
updated_block_labels = set(updated_block_labels)
missing_labels = {label for label in block_labels if label and label not in cached_block_sources}
updated_block_labels.update(missing_labels)
updated_block_labels.add(settings.WORKFLOW_START_BLOCK_LABEL)
python_src = await generate_workflow_script_python_code(
file_name=codegen_input.file_name,
workflow_run_request=codegen_input.workflow_run,
@@ -134,6 +202,8 @@ async def generate_workflow_script(
script_id=script.script_id,
script_revision_id=script.script_revision_id,
pending=pending,
cached_blocks=cached_block_sources,
updated_block_labels=updated_block_labels,
)
except Exception:
LOG.error("Failed to generate workflow script source", exc_info=True)