script gen - support skyvern.loop & cleaner interfaces for generated code (no need to pass context.parameters, implicit template rendering) (#3542)

This commit is contained in:
Shuchang Zheng
2025-09-26 23:27:29 -07:00
committed by GitHub
parent 8c54475fda
commit 90096bc453
7 changed files with 336 additions and 161 deletions

View File

@@ -2,12 +2,11 @@ import asyncio
import base64
import hashlib
import importlib.util
import json
import os
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, cast
from typing import Any, AsyncGenerator, Callable, Sequence, cast
import libcst as cst
import structlog
@@ -21,7 +20,6 @@ from skyvern.core.script_generations.generate_script import _build_block_fn, cre
from skyvern.core.script_generations.skyvern_page import script_run_context_manager
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.models import Step, StepStatus
@@ -35,6 +33,7 @@ from skyvern.forge.sdk.workflow.models.block import (
FileDownloadBlock,
FileParserBlock,
FileUploadBlock,
ForLoopBlock,
HttpRequestBlock,
LoginBlock,
SendEmailBlock,
@@ -52,6 +51,20 @@ LOG = structlog.get_logger()
jinja_sandbox_env = SandboxedEnvironment()
class SkyvernLoopItem:
def __init__(
self,
index: int,
value: Any,
):
self.current_index = index
self.current_value = value
self.current_item = value
def __repr__(self) -> str:
return f"SkyvernLoopItem(current_value={self.current_value}, current_index={self.current_index})"
async def build_file_tree(
files: list[ScriptFileCreate],
organization_id: str,
@@ -363,6 +376,7 @@ async def _create_workflow_block_run_and_task(
prompt: str | None = None,
schema: dict[str, Any] | list | str | None = None,
url: str | None = None,
label: str | None = None,
) -> tuple[str | None, str | None, str | None]:
"""
Create a workflow block run and optionally a task if workflow_run_id is available in context.
@@ -374,24 +388,34 @@ async def _create_workflow_block_run_and_task(
workflow_run_id = context.workflow_run_id
organization_id = context.organization_id
# if there's a parent_workflow_run_block_id and loop_metadata, update_block_metadata
if context.parent_workflow_run_block_id and context.loop_metadata and label:
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
workflow_run_context.update_block_metadata(label, context.loop_metadata)
workflow_run_block = await app.DATABASE.create_workflow_run_block(
workflow_run_id=workflow_run_id,
parent_workflow_run_block_id=context.parent_workflow_run_block_id,
organization_id=organization_id,
block_type=block_type,
label=label,
)
workflow_run_block_id = workflow_run_block.workflow_run_block_id
try:
# Create workflow run block with appropriate parameters based on block type
# TODO: support engine in the future
engine = None
workflow_run_block = await app.DATABASE.create_workflow_run_block(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
block_type=block_type,
engine=engine,
)
workflow_run_block_id = workflow_run_block.workflow_run_block_id
task_id = None
step_id = None
# Create task for task-based blocks
if block_type in SCRIPT_TASK_BLOCKS:
# Create task
if prompt:
prompt = _render_template_with_label(prompt, label)
if url:
url = _render_template_with_label(url, label)
task = await app.DATABASE.create_task(
# fix HACK: changed the type of url to str | None to support None url. url is not used in the script right now.
url=url or "",
@@ -1107,6 +1131,7 @@ async def run_task(
block_type=BlockType.TASK,
prompt=prompt,
url=url,
label=cache_key,
)
# set the prompt in the RunContext
context = skyvern_context.ensure_context()
@@ -1155,6 +1180,7 @@ async def run_task(
)
await task_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
@@ -1180,6 +1206,7 @@ async def download(
block_type=BlockType.FILE_DOWNLOAD,
prompt=prompt,
url=url,
label=cache_key,
)
# set the prompt in the RunContext
context = skyvern_context.ensure_context()
@@ -1228,6 +1255,7 @@ async def download(
)
await file_download_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
@@ -1251,6 +1279,7 @@ async def action(
block_type=BlockType.ACTION,
prompt=prompt,
url=url,
label=cache_key,
)
# set the prompt in the RunContext
context = skyvern_context.ensure_context()
@@ -1297,6 +1326,7 @@ async def action(
)
await action_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
@@ -1320,6 +1350,7 @@ async def login(
block_type=BlockType.LOGIN,
prompt=prompt,
url=url,
label=cache_key,
)
# set the prompt in the RunContext
context = skyvern_context.ensure_context()
@@ -1365,6 +1396,7 @@ async def login(
)
await login_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
@@ -1390,6 +1422,7 @@ async def extract(
prompt=prompt,
schema=schema,
url=url,
label=cache_key,
)
# set the prompt in the RunContext
context = skyvern_context.ensure_context()
@@ -1437,15 +1470,16 @@ async def extract(
)
block_result = await extraction_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
return block_result.output_parameter_value
async def wait(seconds: int) -> None:
async def wait(seconds: int, label: str | None = None) -> None:
# Auto-create workflow block run if workflow_run_id is available (wait block doesn't create tasks)
workflow_run_block_id, _, _ = await _create_workflow_block_run_and_task(block_type=BlockType.WAIT)
workflow_run_block_id, _, _ = await _create_workflow_block_run_and_task(block_type=BlockType.WAIT, label=label)
try:
await asyncio.sleep(seconds)
@@ -1507,36 +1541,32 @@ async def run_script(
raise Exception(f"No 'run_workflow' function found in {path}")
async def generate_text(
text: str | None = None,
intention: str | None = None,
data: dict[str, Any] | None = None,
) -> str:
if text:
return text
new_text = text or ""
if intention and data:
try:
context = skyvern_context.ensure_context()
prompt = context.prompt
# Build the element tree of the current page for the prompt
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
script_generation_input_text_prompt = prompt_engine.load_prompt(
template="script-generation-input-text-generatiion",
intention=intention,
data=payload_str,
goal=prompt,
)
json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER(
prompt=script_generation_input_text_prompt,
prompt_name="script-generation-input-text-generatiion",
organization_id=context.organization_id,
)
new_text = json_response.get("answer", new_text)
except Exception:
LOG.exception("Failed to generate text for script")
raise
return new_text
def _render_template_with_label(template: str, label: str | None = None) -> str:
template_data = {}
context = skyvern_context.current()
if context and context.workflow_run_id and label:
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(context.workflow_run_id)
block_reference_data: dict[str, Any] = workflow_run_context.get_block_metadata(label)
template_data = workflow_run_context.values.copy()
if label in template_data:
current_value = template_data[label]
if isinstance(current_value, dict):
block_reference_data.update(current_value)
else:
LOG.warning(
f"Script service: Parameter {label} has a registered reference value, going to overwrite it by block metadata"
)
template_data[label] = block_reference_data
# inject the forloop metadata as global variables
if "current_index" in block_reference_data:
template_data["current_index"] = block_reference_data["current_index"]
if "current_item" in block_reference_data:
template_data["current_item"] = block_reference_data["current_item"]
if "current_value" in block_reference_data:
template_data["current_value"] = block_reference_data["current_value"]
return render_template(template, data=template_data)
def render_template(template: str, data: dict[str, Any] | None = None) -> str:
@@ -1545,16 +1575,17 @@ def render_template(template: str, data: dict[str, Any] | None = None) -> str:
TODO: complete this function so that block code shares the same template rendering logic
"""
template_data = data or {}
template_data = data.copy() if data else {}
jinja_template = jinja_sandbox_env.from_string(template)
context = skyvern_context.current()
if context and context.workflow_run_id:
workflow_run_id = context.workflow_run_id
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
template_data.update(workflow_run_context.values)
if template in template_data:
return template_data[template]
if context:
template_data.update(context.script_run_parameters)
if context.workflow_run_id:
workflow_run_id = context.workflow_run_id
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
template_data.update(workflow_run_context.values)
if template in template_data:
return template_data[template]
return jinja_template.render(template_data)
@@ -1571,6 +1602,7 @@ def render_list(template: str, data: dict[str, Any] | None = None) -> list[str]:
## Non-task-based block helpers
@dataclass
class BlockValidationOutput:
context: skyvern_context.SkyvernContext
label: str
output_parameter: OutputParameter
workflow: Workflow
@@ -1596,6 +1628,9 @@ async def _validate_and_get_output_parameter(label: str | None = None) -> BlockV
if not workflow:
raise Exception("Workflow not found")
label = label or f"block_{uuid.uuid4()}"
if context.loop_metadata:
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
workflow_run_context.update_block_metadata(label, context.loop_metadata)
output_parameter = workflow.get_output_parameter(label)
if not output_parameter:
# NOT sure if this is legit hack to create output parameter like this
@@ -1608,6 +1643,7 @@ async def _validate_and_get_output_parameter(label: str | None = None) -> BlockV
parameter_type=ParameterType.OUTPUT,
)
return BlockValidationOutput(
context=context,
label=label,
output_parameter=output_parameter,
workflow=workflow,
@@ -1632,6 +1668,7 @@ async def run_code(
)
block_result = await code_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
@@ -1652,6 +1689,22 @@ async def upload_file(
path: str | None = None,
) -> None:
block_validation_output = await _validate_and_get_output_parameter(label)
if s3_bucket:
s3_bucket = _render_template_with_label(s3_bucket, label)
if aws_access_key_id:
aws_access_key_id = _render_template_with_label(aws_access_key_id, label)
if aws_secret_access_key:
aws_secret_access_key = _render_template_with_label(aws_secret_access_key, label)
if region_name:
region_name = _render_template_with_label(region_name, label)
if azure_storage_account_name:
azure_storage_account_name = _render_template_with_label(azure_storage_account_name, label)
if azure_storage_account_key:
azure_storage_account_key = _render_template_with_label(azure_storage_account_key, label)
if azure_blob_container_name:
azure_blob_container_name = _render_template_with_label(azure_blob_container_name, label)
if path:
path = _render_template_with_label(path, label)
file_upload_block = FileUploadBlock(
label=block_validation_output.label,
output_parameter=block_validation_output.output_parameter,
@@ -1668,6 +1721,7 @@ async def upload_file(
)
await file_upload_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
@@ -1675,7 +1729,7 @@ async def upload_file(
async def send_email(
sender: str,
recipients: list[str],
recipients: list[str] | str,
subject: str,
body: str,
file_attachments: list[str] = [],
@@ -1683,6 +1737,11 @@ async def send_email(
parameters: list[PARAMETER_TYPE] | None = None,
) -> None:
block_validation_output = await _validate_and_get_output_parameter(label)
sender = _render_template_with_label(sender, label)
if isinstance(recipients, str):
recipients = render_list(_render_template_with_label(recipients, label))
subject = _render_template_with_label(subject, label)
body = _render_template_with_label(body, label)
workflow = block_validation_output.workflow
smtp_host_parameter = workflow.get_parameter("smtp_host")
smtp_port_parameter = workflow.get_parameter("smtp_port")
@@ -1706,6 +1765,7 @@ async def send_email(
)
await send_email_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
@@ -1719,6 +1779,7 @@ async def parse_file(
parameters: list[PARAMETER_TYPE] | None = None,
) -> None:
block_validation_output = await _validate_and_get_output_parameter(label)
file_url = _render_template_with_label(file_url, label)
file_parser_block = FileParserBlock(
file_url=file_url,
file_type=file_type,
@@ -1729,6 +1790,7 @@ async def parse_file(
)
await file_parser_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
@@ -1745,6 +1807,8 @@ async def http_request(
parameters: list[PARAMETER_TYPE] | None = None,
) -> None:
block_validation_output = await _validate_and_get_output_parameter(label)
method = _render_template_with_label(method, label)
url = _render_template_with_label(url, label)
http_request_block = HttpRequestBlock(
method=method,
url=url,
@@ -1758,6 +1822,7 @@ async def http_request(
)
await http_request_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
@@ -1769,6 +1834,7 @@ async def goto(
parameters: list[PARAMETER_TYPE] | None = None,
) -> None:
block_validation_output = await _validate_and_get_output_parameter(label)
url = _render_template_with_label(url, label)
goto_url_block = UrlBlock(
url=url,
label=block_validation_output.label,
@@ -1777,6 +1843,7 @@ async def goto(
)
await goto_url_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
@@ -1789,6 +1856,7 @@ async def prompt(
parameters: list[PARAMETER_TYPE] | None = None,
) -> dict[str, Any] | list | str | None:
block_validation_output = await _validate_and_get_output_parameter(label)
prompt = _render_template_with_label(prompt, label)
prompt_block = TextPromptBlock(
prompt=prompt,
json_schema=schema,
@@ -1798,7 +1866,119 @@ async def prompt(
)
result = await prompt_block.execute_safe(
workflow_run_id=block_validation_output.workflow_run_id,
parent_workflow_run_block_id=block_validation_output.context.parent_workflow_run_block_id,
organization_id=block_validation_output.organization_id,
browser_session_id=block_validation_output.browser_session_id,
)
return result.output_parameter_value
async def loop(
values: Sequence[Any] | str,
complete_if_empty: bool = False,
label: str | None = None,
) -> AsyncGenerator[SkyvernLoopItem, None]:
workflow_run_block_id, _, _ = await _create_workflow_block_run_and_task(block_type=BlockType.FOR_LOOP, label=label)
# process values:
loop_variable_reference = None
loop_values = None
if isinstance(values, list):
loop_values = values
elif isinstance(values, str):
loop_variable_reference = values
else:
raise ValueError(f"Invalid values type: {type(values)}")
# step. build the ForLoopBlock instance
block_validation_output = await _validate_and_get_output_parameter(label)
loop_block = ForLoopBlock(
label=block_validation_output.label,
output_parameter=block_validation_output.output_parameter,
loop_variable_reference=loop_variable_reference,
loop_blocks=[],
complete_if_empty=complete_if_empty,
)
workflow_run_id = block_validation_output.workflow_run_id
organization_id = block_validation_output.organization_id
if not loop_values:
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
if workflow_run_block_id:
loop_values = await loop_block.get_values_from_loop_variable_reference(
workflow_run_context=workflow_run_context,
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
if not loop_values:
# step 3. if loop_values is empty, record empty output parameter value
LOG.info(
"script service: No loop values found, terminating block",
block_type=BlockType.FOR_LOOP,
workflow_run_id=workflow_run_id,
complete_if_empty=complete_if_empty,
)
await loop_block.record_output_parameter_value(workflow_run_context, workflow_run_id, [])
# step 4. build response (success/failure) given the complete_if_empty value
if complete_if_empty:
await loop_block.build_block_result(
success=True,
failure_reason=None,
output_parameter_value=[],
status=BlockStatus.completed,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
return
else:
await loop_block.build_block_result(
success=False,
failure_reason="No iterable value found for the loop block",
status=BlockStatus.terminated,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
raise Exception("No iterable value found for the loop block")
# register the loop in the global context
block_validation_output.context.parent_workflow_run_block_id = workflow_run_block_id
block_validation_output.context.loop_output_values = []
# step 5. start the loop
try:
for index, value in enumerate(loop_values):
# register current_value, current_item and current_index in workflow run context
loop_metadata = {
"current_index": index,
"current_value": value,
"current_item": value,
}
block_validation_output.context.loop_metadata = loop_metadata
workflow_run_context.update_block_metadata(block_validation_output.label, loop_metadata)
# Build the SkyvernLoopItem for this loop
yield SkyvernLoopItem(index, value)
# build success output
if workflow_run_block_id:
await _update_workflow_block(
workflow_run_block_id,
BlockStatus.completed,
output=block_validation_output.context.loop_output_values,
label=label,
)
except Exception as e:
# build failure output
if workflow_run_block_id:
await _update_workflow_block(
workflow_run_block_id,
BlockStatus.failed,
failure_reason=str(e),
output=block_validation_output.context.loop_output_values,
label=label,
)
raise e
finally:
block_validation_output.context.parent_workflow_run_block_id = None
block_validation_output.context.loop_metadata = None
block_validation_output.context.loop_output_values = None