From a903170f1453f0b2440869abbcda140b4bed26f1 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Tue, 9 Sep 2025 22:33:59 -0700 Subject: [PATCH] script gen run code block using the block interface (#3401) --- skyvern/__init__.py | 2 + .../script_generations/generate_script.py | 47 +++++++++++++- skyvern/services/script_service.py | 65 +++++++++++++++---- 3 files changed, 99 insertions(+), 15 deletions(-) diff --git a/skyvern/__init__.py b/skyvern/__init__.py index 70c03417..2199a539 100644 --- a/skyvern/__init__.py +++ b/skyvern/__init__.py @@ -36,6 +36,7 @@ from skyvern.services.script_service import ( # noqa: E402 generate_text, # noqa: E402 login, # noqa: E402 render_template, # noqa: E402 + run_code, # noqa: E402 run_script, # noqa: E402 run_task, # noqa: E402 wait, # noqa: E402 @@ -53,6 +54,7 @@ __all__ = [ "generate_text", "login", "render_template", + "run_code", "run_script", "run_task", "setup", diff --git a/skyvern/core/script_generations/generate_script.py b/skyvern/core/script_generations/generate_script.py index 4c8266b9..06c4203e 100644 --- a/skyvern/core/script_generations/generate_script.py +++ b/skyvern/core/script_generations/generate_script.py @@ -127,7 +127,9 @@ def _value(value: Any) -> cst.BaseExpression: """Convert simple Python objects to CST expressions.""" if isinstance(value, str): if "\n" in value: - return cst.SimpleString('"""' + value.replace('"""', '\\"\\"\\"') + '"""') + # For multi-line strings, use repr() which handles all escaping properly + # This will use triple quotes when appropriate and escape them when needed + return cst.SimpleString(repr(value)) return cst.SimpleString(repr(value)) if isinstance(value, (int, float, bool)) or value is None: return cst.parse_expression(repr(value)) @@ -877,6 +879,47 @@ def _build_goto_statement(block: dict[str, Any]) -> cst.SimpleStatementLine: return cst.SimpleStatementLine([cst.Expr(cst.Await(call))]) +def _build_code_statement(block: dict[str, Any]) -> cst.SimpleStatementLine: + """Build a skyvern.run_code statement.""" + args = [ + cst.Arg( + keyword=cst.Name("code"), + value=_value(block.get("code", "")), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ), + cst.Arg( + keyword=cst.Name("label"), + value=_value(block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}"), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ), + cst.Arg( + keyword=cst.Name("parameters"), + value=_value(block.get("parameters", None)), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + ), + comma=cst.Comma(), + ), + ] + + call = cst.Call( + func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("run_code")), + args=args, + whitespace_before_args=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ) + + return cst.SimpleStatementLine([cst.Expr(cst.Await(call))]) + + def __build_base_task_statement(block_title: str, block: dict[str, Any]) -> list[cst.Arg]: args = [ cst.Arg( @@ -986,6 +1029,8 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct stmt = _build_for_loop_statement(block_title, block) elif block_type == "goto_url": stmt = _build_goto_statement(block) + elif block_type == "code": + stmt = _build_code_statement(block) else: # Default case for unknown block types stmt = cst.SimpleStatementLine([cst.Expr(cst.SimpleString(f"# Unknown block type: {block_type}"))]) diff --git a/skyvern/services/script_service.py b/skyvern/services/script_service.py index effd0522..8b4a7a87 100644 --- a/skyvern/services/script_service.py +++ b/skyvern/services/script_service.py @@ -4,6 +4,7 @@ import hashlib import importlib.util import json import os +import uuid from datetime import datetime from typing import Any, cast @@ -25,7 +26,8 @@ from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.models import Step, StepStatus from skyvern.forge.sdk.schemas.files import FileInfo from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus -from skyvern.forge.sdk.workflow.models.block import TaskBlock +from skyvern.forge.sdk.workflow.models.block import CodeBlock, TaskBlock +from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE from skyvern.forge.sdk.workflow.models.workflow import Workflow from skyvern.schemas.runs import RunEngine from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate @@ -451,21 +453,17 @@ async def _update_workflow_block( task_output = TaskOutput.from_task(updated_task, downloaded_files) final_output = task_output.model_dump() - await app.DATABASE.update_workflow_run_block( - workflow_run_block_id=workflow_run_block_id, - organization_id=context.organization_id if context else None, - status=status, - failure_reason=failure_reason, - output=final_output, - ) else: final_output = None - await app.DATABASE.update_workflow_run_block( - workflow_run_block_id=workflow_run_block_id, - organization_id=context.organization_id if context else None, - status=status, - failure_reason=failure_reason, - ) + + await app.DATABASE.update_workflow_run_block( + workflow_run_block_id=workflow_run_block_id, + organization_id=context.organization_id if context else None, + status=status, + failure_reason=failure_reason, + output=final_output, + ) + await _record_output_parameter_value( context.workflow_run_id, context.workflow_id, @@ -1375,3 +1373,42 @@ def render_template(template: str, data: dict[str, Any] | None = None) -> str: template_data.update(workflow_run_context.values) return jinja_template.render(template_data) + + +# Non-task-based blocks +async def run_code( + code: str, + label: str | None = None, + parameters: list[PARAMETER_TYPE] | None = None, +) -> dict[str, Any]: + context = skyvern_context.ensure_context() + workflow_id = context.workflow_id + workflow_run_id = context.workflow_run_id + organization_id = context.organization_id + browser_session_id = context.browser_session_id + if not workflow_id: + raise Exception("Workflow ID is required") + if not workflow_run_id: + raise Exception("Workflow run ID is required") + if not organization_id: + raise Exception("Organization ID is required") + workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id, organization_id=organization_id) + if not workflow: + raise Exception("Workflow not found") + label = label or f"block_{uuid.uuid4()}" + output_parameter = workflow.get_output_parameter(label) + if not output_parameter: + raise Exception("Output parameter not found") + + code_block = CodeBlock( + code=code, + label=label, + parameters=parameters or [], + output_parameter=output_parameter, + ) + block_result = await code_block.execute_safe( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + browser_session_id=browser_session_id, + ) + return cast(dict[str, Any], block_result.output_parameter_value)