From 2fa4d933ccde9fc792af6b0dfd9bc8298c420fab Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 5 Nov 2025 19:57:11 +0800 Subject: [PATCH] Refactor script gen with block level code cache (#3910) --- .../script_generations/generate_script.py | 168 ++++++++++-------- skyvern/forge/sdk/workflow/service.py | 93 ++++++++-- skyvern/services/workflow_script_service.py | 72 +++++++- 3 files changed, 250 insertions(+), 83 deletions(-) diff --git a/skyvern/core/script_generations/generate_script.py b/skyvern/core/script_generations/generate_script.py index a16e8a55..740ca760 100644 --- a/skyvern/core/script_generations/generate_script.py +++ b/skyvern/core/script_generations/generate_script.py @@ -10,6 +10,7 @@ import asyncio import hashlib import keyword import re +from dataclasses import dataclass from typing import Any import libcst as cst @@ -31,6 +32,15 @@ GENERATE_CODE_AI_MODE_PROACTIVE = "proactive" GENERATE_CODE_AI_MODE_FALLBACK = "fallback" +@dataclass +class ScriptBlockSource: + label: str + code: str + run_signature: str | None + workflow_run_id: str | None + workflow_run_block_id: str | None + + # --------------------------------------------------------------------- # # 1. helpers # # --------------------------------------------------------------------- # @@ -104,6 +114,7 @@ ACTIONS_WITH_XPATH = [ "upload_file", "select_option", ] +ACTIONS_OPT_OUT_INTENTION_FOR_PROMPT = ["extract"] INDENT = " " * 4 DOUBLE_INDENT = " " * 8 @@ -421,7 +432,7 @@ def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output: ) ) intention = act.get("intention") or act.get("reasoning") or "" - if intention: + if intention and method not in ACTIONS_OPT_OUT_INTENTION_FOR_PROMPT: args.extend( [ cst.Arg( @@ -432,6 +443,7 @@ def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output: ), ] ) + _mark_last_arg_as_comma(args) # Only use indented parentheses if we have arguments if args: @@ -1694,10 +1706,25 @@ async def generate_workflow_script_python_code( script_id: str | None = None, script_revision_id: str | None = None, pending: bool = False, + cached_blocks: dict[str, ScriptBlockSource] | None = None, + updated_block_labels: set[str] | None = None, ) -> str: """ Build a LibCST Module and emit .code (PEP-8-formatted source). + + Cached script blocks can be reused by providing them via `cached_blocks`. Any labels present in + `updated_block_labels` will be regenerated from the latest workflow run execution data. """ + cached_blocks = cached_blocks or {} + updated_block_labels = set(updated_block_labels or []) + + # Drop cached entries that do not have usable source + cached_blocks = {label: source for label, source in cached_blocks.items() if source.code} + # Always regenerate the orchestrator block so it stays aligned with the workflow definition + cached_blocks.pop(settings.WORKFLOW_START_BLOCK_LABEL, None) + + if task_v2_child_blocks is None: + task_v2_child_blocks = {} # --- imports -------------------------------------------------------- imports: list[cst.BaseStatement] = [ cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("asyncio"))])]), @@ -1746,33 +1773,47 @@ async def generate_workflow_script_python_code( generated_model_cls = _build_generated_model_from_schema(generated_schema) # --- blocks --------------------------------------------------------- - block_fns = [] + block_fns: list[cst.CSTNode] = [] task_v1_blocks = [block for block in blocks if block["block_type"] in SCRIPT_TASK_BLOCKS] task_v2_blocks = [block for block in blocks if block["block_type"] == "task_v2"] - if task_v2_child_blocks is None: - task_v2_child_blocks = {} + def append_block_code(block_code: str) -> None: + nonlocal block_fns + parsed = cst.parse_module(block_code) + if block_fns: + block_fns.append(cst.EmptyLine()) + block_fns.append(cst.EmptyLine()) + block_fns.extend(parsed.body) # Handle task v1 blocks (excluding child blocks of task_v2) for idx, task in enumerate(task_v1_blocks): - # Skip if this is a child block of a task_v2 block if task.get("parent_task_v2_label"): continue - block_fn_def = _build_block_fn(task, actions_by_task.get(task.get("task_id", ""), [])) + block_name = task.get("label") or task.get("title") or task.get("task_id") or f"task_{idx}" + cached_source = cached_blocks.get(block_name) + use_cached = cached_source is not None and block_name not in updated_block_labels + + if use_cached: + assert cached_source is not None + block_code = cached_source.code + run_signature = cached_source.run_signature + block_workflow_run_id = cached_source.workflow_run_id + block_workflow_run_block_id = cached_source.workflow_run_block_id + else: + block_fn_def = _build_block_fn(task, actions_by_task.get(task.get("task_id", ""), [])) + temp_module = cst.Module(body=[block_fn_def]) + block_code = temp_module.code + + block_stmt = _build_block_statement(task) + run_signature_module = cst.Module(body=[block_stmt]) + run_signature = run_signature_module.code.strip() + + block_workflow_run_id = task.get("workflow_run_id") or run_id + block_workflow_run_block_id = task.get("workflow_run_block_id") - # Create script block if we have script context if script_id and script_revision_id and organization_id: try: - block_name = task.get("label") or task.get("title") or task.get("task_id") or f"task_{idx}" - temp_module = cst.Module(body=[block_fn_def]) - block_code = temp_module.code - - # Extract the run signature (the statement that calls skyvern.action/extract/etc) - block_stmt = _build_block_statement(task) - run_signature_module = cst.Module(body=[block_stmt]) - run_signature = run_signature_module.code.strip() - await create_or_update_script_block( block_code=block_code, script_revision_id=script_revision_id, @@ -1781,84 +1822,67 @@ async def generate_workflow_script_python_code( block_label=block_name, update=pending, run_signature=run_signature, - workflow_run_id=task.get("workflow_run_id"), - workflow_run_block_id=task.get("workflow_run_block_id"), + workflow_run_id=block_workflow_run_id, + workflow_run_block_id=block_workflow_run_block_id, ) except Exception as e: LOG.error("Failed to create script block", error=str(e), exc_info=True) - # Continue without script block creation if it fails - block_fns.append(block_fn_def) - if idx < len(task_v1_blocks) - 1: - block_fns.append(cst.EmptyLine()) - block_fns.append(cst.EmptyLine()) + append_block_code(block_code) # Handle task_v2 blocks - for idx, task_v2 in enumerate(task_v2_blocks): + for task_v2 in task_v2_blocks: task_v2_label = task_v2.get("label") or f"task_v2_{task_v2.get('workflow_run_block_id')}" child_blocks = task_v2_child_blocks.get(task_v2_label, []) - # Create the task_v2 function - task_v2_fn_def = _build_task_v2_block_fn(task_v2, child_blocks) + cached_source = cached_blocks.get(task_v2_label) + use_cached = cached_source is not None and task_v2_label not in updated_block_labels + + block_code = "" + run_signature = None + block_workflow_run_id = task_v2.get("workflow_run_id") or run_id + block_workflow_run_block_id = task_v2.get("workflow_run_block_id") + + if use_cached: + assert cached_source is not None + block_code = cached_source.code + run_signature = cached_source.run_signature + block_workflow_run_id = cached_source.workflow_run_id + block_workflow_run_block_id = cached_source.workflow_run_block_id + else: + task_v2_fn_def = _build_task_v2_block_fn(task_v2, child_blocks) + task_v2_block_body: list[cst.CSTNode] = [task_v2_fn_def] + + for child_block in child_blocks: + if child_block.get("block_type") in SCRIPT_TASK_BLOCKS and child_block.get("block_type") != "task_v2": + child_fn_def = _build_block_fn(child_block, actions_by_task.get(child_block.get("task_id", ""), [])) + task_v2_block_body.append(cst.EmptyLine()) + task_v2_block_body.append(cst.EmptyLine()) + task_v2_block_body.append(child_fn_def) + + temp_module = cst.Module(body=task_v2_block_body) + block_code = temp_module.code + + task_v2_stmt = _build_block_statement(task_v2) + run_signature = cst.Module(body=[task_v2_stmt]).code.strip() - # Create script block for task_v2 that includes both the main function and child functions if script_id and script_revision_id and organization_id: try: - # Build the complete module for this task_v2 block - task_v2_block_body = [task_v2_fn_def] - - # Add child block functions - for child_block in child_blocks: - if ( - child_block.get("block_type") in SCRIPT_TASK_BLOCKS - and child_block.get("block_type") != "task_v2" - ): - child_fn_def = _build_block_fn( - child_block, actions_by_task.get(child_block.get("task_id", ""), []) - ) - task_v2_block_body.append(cst.EmptyLine()) - task_v2_block_body.append(cst.EmptyLine()) - task_v2_block_body.append(child_fn_def) - - # Create the complete module for this task_v2 block - temp_module = cst.Module(body=task_v2_block_body) - task_v2_block_code = temp_module.code - - block_name = task_v2.get("label") or task_v2.get("title") or f"task_v2_{idx}" - - # Extract the run signature for task_v2 block - task_v2_stmt = _build_block_statement(task_v2) - run_signature_module = cst.Module(body=[task_v2_stmt]) - run_signature = run_signature_module.code.strip() - await create_or_update_script_block( - block_code=task_v2_block_code, + block_code=block_code, script_revision_id=script_revision_id, script_id=script_id, organization_id=organization_id, - block_label=block_name, + block_label=task_v2_label, update=pending, run_signature=run_signature, - workflow_run_id=task_v2.get("workflow_run_id"), - workflow_run_block_id=task_v2.get("workflow_run_block_id"), + workflow_run_id=block_workflow_run_id, + workflow_run_block_id=block_workflow_run_block_id, ) except Exception as e: LOG.error("Failed to create task_v2 script block", error=str(e), exc_info=True) - # Continue without script block creation if it fails - block_fns.append(task_v2_fn_def) - - # Create individual functions for child blocks - for child_block in child_blocks: - if child_block.get("block_type") in SCRIPT_TASK_BLOCKS and child_block.get("block_type") != "task_v2": - child_fn_def = _build_block_fn(child_block, actions_by_task.get(child_block.get("task_id", ""), [])) - block_fns.append(cst.EmptyLine()) - block_fns.append(cst.EmptyLine()) - block_fns.append(child_fn_def) - - if idx < len(task_v2_blocks) - 1: - block_fns.append(cst.EmptyLine()) - block_fns.append(cst.EmptyLine()) + append_block_code(block_code) # --- runner --------------------------------------------------------- run_fn = _build_run_fn(blocks, workflow_run_request) diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index f2b8f39d..a78178d2 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -635,7 +635,7 @@ class WorkflowService: # Unified execution: execute blocks one by one, using script code when available if is_script_run is False: workflow_script = None - workflow_run = await self._execute_workflow_blocks( + workflow_run, blocks_to_update = await self._execute_workflow_blocks( workflow=workflow, workflow_run=workflow_run, organization=organization, @@ -663,6 +663,7 @@ class WorkflowService: workflow=workflow, workflow_run=workflow_run, block_labels=block_labels, + blocks_to_update=blocks_to_update, ) else: LOG.info( @@ -690,7 +691,7 @@ class WorkflowService: block_labels: list[str] | None = None, block_outputs: dict[str, Any] | None = None, workflow_script: WorkflowScript | None = None, - ) -> WorkflowRun: + ) -> tuple[WorkflowRun, set[str]]: organization_id = organization.organization_id workflow_run_id = workflow_run.workflow_run_id top_level_blocks = workflow.workflow_definition.blocks @@ -699,6 +700,7 @@ class WorkflowService: # Load script blocks if workflow_script is provided script_blocks_by_label: dict[str, Any] = {} loaded_script_module = None + blocks_to_update: set[str] = set() if workflow_script: LOG.info( @@ -946,6 +948,13 @@ class WorkflowService: workflow_run_id=workflow_run_id, failure_reason="Block result is None" ) break + if ( + not block_executed_with_code + and block.label + and block_result.status == BlockStatus.completed + and not getattr(block, "disable_cache", False) + ): + blocks_to_update.add(block.label) if block_result.status == BlockStatus.canceled: LOG.info( f"Block with type {block.block_type} at index {block_idx}/{blocks_cnt - 1} was canceled for workflow run {workflow_run_id}, cancelling workflow run", @@ -1064,7 +1073,7 @@ class WorkflowService: workflow_run_id=workflow_run_id, failure_reason=failure_reason ) break - return workflow_run + return workflow_run, blocks_to_update async def create_workflow( self, @@ -3289,13 +3298,17 @@ class WorkflowService: workflow: Workflow, workflow_run: WorkflowRun, block_labels: list[str] | None = None, + blocks_to_update: set[str] | None = None, ) -> None: code_gen = workflow_run.code_gen + blocks_to_update = set(blocks_to_update or []) + LOG.info( "Generate script?", block_labels=block_labels, code_gen=code_gen, workflow_run_id=workflow_run.workflow_run_id, + blocks_to_update=list(blocks_to_update), ) if block_labels and not code_gen: @@ -3308,16 +3321,74 @@ class WorkflowService: workflow_run, block_labels, ) + if existing_script: - LOG.info( - "Found cached script for workflow. Skipping script generation", - workflow_id=workflow.workflow_id, - workflow_run_id=workflow_run.workflow_run_id, - cache_key_value=rendered_cache_key_value, - script_id=existing_script.script_id, + cached_block_labels: set[str] = set() + script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id( script_revision_id=existing_script.script_revision_id, - run_with=workflow_run.run_with, + organization_id=workflow.organization_id, ) + for script_block in script_blocks: + if script_block.script_block_label: + cached_block_labels.add(script_block.script_block_label) + + definition_labels = {block.label for block in workflow.workflow_definition.blocks if block.label} + definition_labels.add(settings.WORKFLOW_START_BLOCK_LABEL) + cached_block_labels.add(settings.WORKFLOW_START_BLOCK_LABEL) + + if cached_block_labels != definition_labels: + missing_labels = definition_labels - cached_block_labels + if missing_labels: + blocks_to_update.update(missing_labels) + # Always rebuild the orchestrator if the definition changed + blocks_to_update.add(settings.WORKFLOW_START_BLOCK_LABEL) + + should_regenerate = bool(blocks_to_update) or bool(code_gen) + + if not should_regenerate: + LOG.info( + "Workflow script already up to date; skipping regeneration", + workflow_id=workflow.workflow_id, + workflow_run_id=workflow_run.workflow_run_id, + cache_key_value=rendered_cache_key_value, + script_id=existing_script.script_id, + script_revision_id=existing_script.script_revision_id, + run_with=workflow_run.run_with, + ) + return + + # delete the existing workflow scripts if any + await app.DATABASE.delete_workflow_scripts_by_permanent_id( + organization_id=workflow.organization_id, + workflow_permanent_id=workflow.workflow_permanent_id, + script_ids=[existing_script.script_id], + ) + + # create a new script + regenerated_script = await app.DATABASE.create_script( + organization_id=workflow.organization_id, + run_id=workflow_run.workflow_run_id, + ) + + await workflow_script_service.generate_workflow_script( + workflow_run=workflow_run, + workflow=workflow, + script=regenerated_script, + rendered_cache_key_value=rendered_cache_key_value, + cached_script=existing_script, + updated_block_labels=blocks_to_update, + ) + aio_task_primary_key = f"{regenerated_script.script_id}_{regenerated_script.version}" + if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map: + aio_tasks = app.ARTIFACT_MANAGER.upload_aiotasks_map[aio_task_primary_key] + if aio_tasks: + await asyncio.gather(*aio_tasks) + else: + LOG.warning( + "No upload aio tasks found for regenerated script", + script_id=regenerated_script.script_id, + version=regenerated_script.version, + ) return created_script = await app.DATABASE.create_script( @@ -3330,6 +3401,8 @@ class WorkflowService: workflow=workflow, script=created_script, rendered_cache_key_value=rendered_cache_key_value, + cached_script=None, + updated_block_labels=None, ) aio_task_primary_key = f"{created_script.script_id}_{created_script.version}" if aio_task_primary_key in app.ARTIFACT_MANAGER.upload_aiotasks_map: diff --git a/skyvern/services/workflow_script_service.py b/skyvern/services/workflow_script_service.py index 0dd052dc..4c9f5b84 100644 --- a/skyvern/services/workflow_script_service.py +++ b/skyvern/services/workflow_script_service.py @@ -3,7 +3,8 @@ import base64 import structlog from jinja2.sandbox import SandboxedEnvironment -from skyvern.core.script_generations.generate_script import generate_workflow_script_python_code +from skyvern.config import settings +from skyvern.core.script_generations.generate_script import ScriptBlockSource, generate_workflow_script_python_code from skyvern.core.script_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input from skyvern.forge import app from skyvern.forge.sdk.core import skyvern_context @@ -45,6 +46,7 @@ async def generate_or_update_pending_workflow_script( script=script, rendered_cache_key_value=rendered_cache_key_value, pending=True, + cached_script=script, ) @@ -104,12 +106,62 @@ async def get_workflow_script( return None, rendered_cache_key_value +async def _load_cached_script_block_sources( + script: Script, + organization_id: str, +) -> dict[str, ScriptBlockSource]: + """ + Load existing script block sources (code + metadata) for a script revision so they can be reused. + """ + cached_blocks: dict[str, ScriptBlockSource] = {} + + script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id( + script_revision_id=script.script_revision_id, + organization_id=organization_id, + ) + + for script_block in script_blocks: + if not script_block.script_block_label: + continue + + code_str: str | None = None + if script_block.script_file_id: + script_file = await app.DATABASE.get_script_file_by_id( + script_revision_id=script.script_revision_id, + file_id=script_block.script_file_id, + organization_id=organization_id, + ) + if script_file and script_file.artifact_id: + artifact = await app.DATABASE.get_artifact_by_id(script_file.artifact_id, organization_id) + if artifact: + file_content = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact) + if isinstance(file_content, bytes): + code_str = file_content.decode("utf-8") + elif isinstance(file_content, str): + code_str = file_content + + if not code_str: + continue + + cached_blocks[script_block.script_block_label] = ScriptBlockSource( + label=script_block.script_block_label, + code=code_str, + run_signature=script_block.run_signature, + workflow_run_id=script_block.workflow_run_id, + workflow_run_block_id=script_block.workflow_run_block_id, + ) + + return cached_blocks + + async def generate_workflow_script( workflow_run: WorkflowRun, workflow: Workflow, script: Script, rendered_cache_key_value: str, pending: bool = False, + cached_script: Script | None = None, + updated_block_labels: set[str] | None = None, ) -> None: try: LOG.info( @@ -119,10 +171,26 @@ async def generate_workflow_script( workflow_name=workflow.title, cache_key_value=rendered_cache_key_value, ) + cached_block_sources: dict[str, ScriptBlockSource] = {} + if cached_script: + cached_block_sources = await _load_cached_script_block_sources(cached_script, workflow.organization_id) + codegen_input = await transform_workflow_run_to_code_gen_input( workflow_run_id=workflow_run.workflow_run_id, organization_id=workflow.organization_id, ) + + block_labels = [block.get("label") for block in codegen_input.workflow_blocks if block.get("label")] + + if updated_block_labels is None: + updated_block_labels = {label for label in block_labels if label} + else: + updated_block_labels = set(updated_block_labels) + + missing_labels = {label for label in block_labels if label and label not in cached_block_sources} + updated_block_labels.update(missing_labels) + updated_block_labels.add(settings.WORKFLOW_START_BLOCK_LABEL) + python_src = await generate_workflow_script_python_code( file_name=codegen_input.file_name, workflow_run_request=codegen_input.workflow_run, @@ -134,6 +202,8 @@ async def generate_workflow_script( script_id=script.script_id, script_revision_id=script.script_revision_id, pending=pending, + cached_blocks=cached_block_sources, + updated_block_labels=updated_block_labels, ) except Exception: LOG.error("Failed to generate workflow script source", exc_info=True)