script gen run code block using the block interface (#3401)
This commit is contained in:
@@ -36,6 +36,7 @@ from skyvern.services.script_service import ( # noqa: E402
|
|||||||
generate_text, # noqa: E402
|
generate_text, # noqa: E402
|
||||||
login, # noqa: E402
|
login, # noqa: E402
|
||||||
render_template, # noqa: E402
|
render_template, # noqa: E402
|
||||||
|
run_code, # noqa: E402
|
||||||
run_script, # noqa: E402
|
run_script, # noqa: E402
|
||||||
run_task, # noqa: E402
|
run_task, # noqa: E402
|
||||||
wait, # noqa: E402
|
wait, # noqa: E402
|
||||||
@@ -53,6 +54,7 @@ __all__ = [
|
|||||||
"generate_text",
|
"generate_text",
|
||||||
"login",
|
"login",
|
||||||
"render_template",
|
"render_template",
|
||||||
|
"run_code",
|
||||||
"run_script",
|
"run_script",
|
||||||
"run_task",
|
"run_task",
|
||||||
"setup",
|
"setup",
|
||||||
|
|||||||
@@ -127,7 +127,9 @@ def _value(value: Any) -> cst.BaseExpression:
|
|||||||
"""Convert simple Python objects to CST expressions."""
|
"""Convert simple Python objects to CST expressions."""
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
if "\n" in value:
|
if "\n" in value:
|
||||||
return cst.SimpleString('"""' + value.replace('"""', '\\"\\"\\"') + '"""')
|
# For multi-line strings, use repr() which handles all escaping properly
|
||||||
|
# This will use triple quotes when appropriate and escape them when needed
|
||||||
|
return cst.SimpleString(repr(value))
|
||||||
return cst.SimpleString(repr(value))
|
return cst.SimpleString(repr(value))
|
||||||
if isinstance(value, (int, float, bool)) or value is None:
|
if isinstance(value, (int, float, bool)) or value is None:
|
||||||
return cst.parse_expression(repr(value))
|
return cst.parse_expression(repr(value))
|
||||||
@@ -877,6 +879,47 @@ def _build_goto_statement(block: dict[str, Any]) -> cst.SimpleStatementLine:
|
|||||||
return cst.SimpleStatementLine([cst.Expr(cst.Await(call))])
|
return cst.SimpleStatementLine([cst.Expr(cst.Await(call))])
|
||||||
|
|
||||||
|
|
||||||
|
def _build_code_statement(block: dict[str, Any]) -> cst.SimpleStatementLine:
|
||||||
|
"""Build a skyvern.run_code statement."""
|
||||||
|
args = [
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("code"),
|
||||||
|
value=_value(block.get("code", "")),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("label"),
|
||||||
|
value=_value(block.get("label") or block.get("title") or f"block_{block.get('workflow_run_block_id')}"),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
cst.Arg(
|
||||||
|
keyword=cst.Name("parameters"),
|
||||||
|
value=_value(block.get("parameters", None)),
|
||||||
|
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
),
|
||||||
|
comma=cst.Comma(),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
call = cst.Call(
|
||||||
|
func=cst.Attribute(value=cst.Name("skyvern"), attr=cst.Name("run_code")),
|
||||||
|
args=args,
|
||||||
|
whitespace_before_args=cst.ParenthesizedWhitespace(
|
||||||
|
indent=True,
|
||||||
|
last_line=cst.SimpleWhitespace(INDENT),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cst.SimpleStatementLine([cst.Expr(cst.Await(call))])
|
||||||
|
|
||||||
|
|
||||||
def __build_base_task_statement(block_title: str, block: dict[str, Any]) -> list[cst.Arg]:
|
def __build_base_task_statement(block_title: str, block: dict[str, Any]) -> list[cst.Arg]:
|
||||||
args = [
|
args = [
|
||||||
cst.Arg(
|
cst.Arg(
|
||||||
@@ -986,6 +1029,8 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
|
|||||||
stmt = _build_for_loop_statement(block_title, block)
|
stmt = _build_for_loop_statement(block_title, block)
|
||||||
elif block_type == "goto_url":
|
elif block_type == "goto_url":
|
||||||
stmt = _build_goto_statement(block)
|
stmt = _build_goto_statement(block)
|
||||||
|
elif block_type == "code":
|
||||||
|
stmt = _build_code_statement(block)
|
||||||
else:
|
else:
|
||||||
# Default case for unknown block types
|
# Default case for unknown block types
|
||||||
stmt = cst.SimpleStatementLine([cst.Expr(cst.SimpleString(f"# Unknown block type: {block_type}"))])
|
stmt = cst.SimpleStatementLine([cst.Expr(cst.SimpleString(f"# Unknown block type: {block_type}"))])
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import hashlib
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
@@ -25,7 +26,8 @@ from skyvern.forge.sdk.core import skyvern_context
|
|||||||
from skyvern.forge.sdk.models import Step, StepStatus
|
from skyvern.forge.sdk.models import Step, StepStatus
|
||||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
|
from skyvern.forge.sdk.schemas.tasks import Task, TaskOutput, TaskStatus
|
||||||
from skyvern.forge.sdk.workflow.models.block import TaskBlock
|
from skyvern.forge.sdk.workflow.models.block import CodeBlock, TaskBlock
|
||||||
|
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE
|
||||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow
|
from skyvern.forge.sdk.workflow.models.workflow import Workflow
|
||||||
from skyvern.schemas.runs import RunEngine
|
from skyvern.schemas.runs import RunEngine
|
||||||
from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate
|
from skyvern.schemas.scripts import CreateScriptResponse, FileEncoding, FileNode, ScriptFileCreate
|
||||||
@@ -451,21 +453,17 @@ async def _update_workflow_block(
|
|||||||
|
|
||||||
task_output = TaskOutput.from_task(updated_task, downloaded_files)
|
task_output = TaskOutput.from_task(updated_task, downloaded_files)
|
||||||
final_output = task_output.model_dump()
|
final_output = task_output.model_dump()
|
||||||
await app.DATABASE.update_workflow_run_block(
|
|
||||||
workflow_run_block_id=workflow_run_block_id,
|
|
||||||
organization_id=context.organization_id if context else None,
|
|
||||||
status=status,
|
|
||||||
failure_reason=failure_reason,
|
|
||||||
output=final_output,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
final_output = None
|
final_output = None
|
||||||
await app.DATABASE.update_workflow_run_block(
|
|
||||||
workflow_run_block_id=workflow_run_block_id,
|
await app.DATABASE.update_workflow_run_block(
|
||||||
organization_id=context.organization_id if context else None,
|
workflow_run_block_id=workflow_run_block_id,
|
||||||
status=status,
|
organization_id=context.organization_id if context else None,
|
||||||
failure_reason=failure_reason,
|
status=status,
|
||||||
)
|
failure_reason=failure_reason,
|
||||||
|
output=final_output,
|
||||||
|
)
|
||||||
|
|
||||||
await _record_output_parameter_value(
|
await _record_output_parameter_value(
|
||||||
context.workflow_run_id,
|
context.workflow_run_id,
|
||||||
context.workflow_id,
|
context.workflow_id,
|
||||||
@@ -1375,3 +1373,42 @@ def render_template(template: str, data: dict[str, Any] | None = None) -> str:
|
|||||||
template_data.update(workflow_run_context.values)
|
template_data.update(workflow_run_context.values)
|
||||||
|
|
||||||
return jinja_template.render(template_data)
|
return jinja_template.render(template_data)
|
||||||
|
|
||||||
|
|
||||||
|
# Non-task-based blocks
|
||||||
|
async def run_code(
|
||||||
|
code: str,
|
||||||
|
label: str | None = None,
|
||||||
|
parameters: list[PARAMETER_TYPE] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
context = skyvern_context.ensure_context()
|
||||||
|
workflow_id = context.workflow_id
|
||||||
|
workflow_run_id = context.workflow_run_id
|
||||||
|
organization_id = context.organization_id
|
||||||
|
browser_session_id = context.browser_session_id
|
||||||
|
if not workflow_id:
|
||||||
|
raise Exception("Workflow ID is required")
|
||||||
|
if not workflow_run_id:
|
||||||
|
raise Exception("Workflow run ID is required")
|
||||||
|
if not organization_id:
|
||||||
|
raise Exception("Organization ID is required")
|
||||||
|
workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id, organization_id=organization_id)
|
||||||
|
if not workflow:
|
||||||
|
raise Exception("Workflow not found")
|
||||||
|
label = label or f"block_{uuid.uuid4()}"
|
||||||
|
output_parameter = workflow.get_output_parameter(label)
|
||||||
|
if not output_parameter:
|
||||||
|
raise Exception("Output parameter not found")
|
||||||
|
|
||||||
|
code_block = CodeBlock(
|
||||||
|
code=code,
|
||||||
|
label=label,
|
||||||
|
parameters=parameters or [],
|
||||||
|
output_parameter=output_parameter,
|
||||||
|
)
|
||||||
|
block_result = await code_block.execute_safe(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
)
|
||||||
|
return cast(dict[str, Any], block_result.output_parameter_value)
|
||||||
|
|||||||
Reference in New Issue
Block a user