generate GeneratedWorkflowParameters (#3264)

This commit is contained in:
Shuchang Zheng
2025-08-21 15:42:34 -07:00
committed by GitHub
parent 988416829f
commit 2a62dc08aa
7 changed files with 504 additions and 64 deletions

View File

@@ -25,6 +25,10 @@ import structlog
from libcst import Attribute, Call, Dict, DictElement, FunctionDef, Name, Param
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
from skyvern.core.script_generations.generate_workflow_parameters import (
generate_workflow_parameters_schema,
hydrate_input_text_actions_with_field_names,
)
from skyvern.forge import app
from skyvern.webeye.actions.action_types import ActionType
@@ -61,6 +65,7 @@ ACTIONS_WITH_XPATH = [
]
INDENT = " " * 4
DOUBLE_INDENT = " " * 8
def _safe_name(label: str) -> str:
@@ -97,6 +102,57 @@ def _value(value: Any) -> cst.BaseExpression:
return cst.SimpleString(repr(str(value)))
def _generate_text_call(text_value: str, intention: str, parameter_key: str) -> cst.BaseExpression:
    """Create a generate_text function call CST expression.

    Builds the CST for an expression of the shape::

        await generate_text(
            context.generated_parameters['<parameter_key>'],
            intention=<intention>,
            data=context.parameters,
        )

    The ParenthesizedWhitespace nodes force the rendered call onto multiple
    lines with INDENT/DOUBLE_INDENT spacing so the generated script stays
    readable.

    NOTE(review): ``text_value`` is accepted but never used in the emitted
    expression — the runtime value comes from
    ``context.generated_parameters[parameter_key]`` instead. Presumably kept
    for interface symmetry with the non-generated fallback path; confirm
    whether it can be dropped.
    """
    return cst.Await(
        expression=cst.Call(
            func=cst.Name("generate_text"),
            # Break the argument list onto a new line after the opening paren.
            whitespace_before_args=cst.ParenthesizedWhitespace(
                indent=True,
                last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
            ),
            args=[
                # First positional argument: context.generated_parameters['parameter_key']
                cst.Arg(
                    value=cst.Subscript(
                        value=cst.Attribute(
                            value=cst.Name("context"),
                            attr=cst.Name("generated_parameters"),
                        ),
                        slice=[cst.SubscriptElement(slice=cst.Index(value=_value(parameter_key)))],
                    ),
                    whitespace_after_arg=cst.ParenthesizedWhitespace(
                        indent=True,
                        last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
                    ),
                ),
                # intention keyword argument
                cst.Arg(
                    keyword=cst.Name("intention"),
                    value=_value(intention),
                    whitespace_after_arg=cst.ParenthesizedWhitespace(
                        indent=True,
                        last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
                    ),
                ),
                # data keyword argument
                cst.Arg(
                    keyword=cst.Name("data"),
                    value=cst.Attribute(
                        value=cst.Name("context"),
                        attr=cst.Name("parameters"),
                    ),
                    # Last argument: dedent back to INDENT and keep a trailing comma.
                    whitespace_after_arg=cst.ParenthesizedWhitespace(
                        indent=True,
                        last_line=cst.SimpleWhitespace(INDENT),
                    ),
                    comma=cst.Comma(),
                ),
            ],
        )
    )
# --------------------------------------------------------------------- #
# 2. utility builders #
# --------------------------------------------------------------------- #
@@ -177,10 +233,21 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
)
if method in ["type", "fill"]:
# Get intention from action
intention = act.get("intention") or act.get("reasoning") or ""
# Use generate_text call if field_name is available, otherwise fallback to direct value
if act.get("field_name"):
text_value = _generate_text_call(
text_value=act["text"], intention=intention, parameter_key=act["field_name"]
)
else:
text_value = _value(act["text"])
args.append(
cst.Arg(
keyword=cst.Name("text"),
value=_value(act["text"]),
value=text_value,
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
@@ -212,7 +279,7 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
elif method == "extract":
args.append(
cst.Arg(
keyword=cst.Name("data_extraction_goal"),
keyword=cst.Name("prompt"),
value=_value(act["data_extraction_goal"]),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
@@ -309,8 +376,8 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
"""
class WorkflowParameters(BaseModel):
ein_info: str
company_name: str
param1: str
param2: str
...
"""
ann_lines: list[cst.BaseStatement] = []
@@ -319,7 +386,6 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
if p["parameter_type"] != "workflow":
continue
# ein_info: str
ann = cst.AnnAssign(
target=cst.Name(p["key"]),
annotation=cst.Annotation(cst.Name("str")),
@@ -337,21 +403,24 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
)
def _build_cached_params(values: dict[str, Any]) -> cst.SimpleStatementLine:
def _build_generated_model_from_schema(schema_code: str) -> cst.ClassDef | None:
"""
Make a CST for:
cached_parameters = WorkflowParameters(ein_info="...", ...)
Parse the generated schema code and return a ClassDef, or None if parsing fails.
"""
call = cst.Call(
func=cst.Name("WorkflowParameters"),
args=[cst.Arg(keyword=cst.Name(k), value=_value(v)) for k, v in values.items()],
)
try:
# Parse the schema code and extract just the class definition
parsed_module = cst.parse_module(schema_code)
assign = cst.Assign(
targets=[cst.AssignTarget(cst.Name("cached_parameters"))],
value=call,
)
return cst.SimpleStatementLine([assign])
# Find the GeneratedWorkflowParameters class in the parsed module
for node in parsed_module.body:
if isinstance(node, cst.ClassDef) and node.name.value == "GeneratedWorkflowParameters":
return node
# If no class found, return None
return None
except Exception as e:
LOG.warning("Failed to parse generated schema code", error=str(e))
return None
# --------------------------------------------------------------------- #
@@ -804,7 +873,7 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
cst.parse_statement(
"parameters = parameters.model_dump() if isinstance(parameters, WorkflowParameters) else parameters"
),
cst.parse_statement("page, context = await skyvern.setup(parameters)"),
cst.parse_statement("page, context = await skyvern.setup(parameters, GeneratedWorkflowParameters)"),
]
for block in blocks:
@@ -867,8 +936,27 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
params=[
Param(
name=cst.Name("parameters"),
annotation=cst.Annotation(cst.Name("WorkflowParameters")),
default=cst.Name("cached_parameters"),
annotation=cst.Annotation(
cst.BinaryOperation(
left=cst.Name("WorkflowParameters"),
operator=cst.BitOr(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.SimpleWhitespace(" "),
),
right=cst.Subscript(
value=cst.Name("dict"),
slice=[
cst.SubscriptElement(
slice=cst.Index(value=cst.Name("str")),
comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
),
cst.SubscriptElement(
slice=cst.Index(value=cst.Name("Any")),
),
],
),
)
),
whitespace_after_param=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
@@ -948,11 +1036,24 @@ async def generate_workflow_script(
imports: list[cst.BaseStatement] = [
cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("asyncio"))])]),
cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("pydantic"))])]),
cst.SimpleStatementLine(
[
cst.ImportFrom(
module=cst.Name("typing"),
names=[
cst.ImportAlias(cst.Name("Any")),
],
)
]
),
cst.SimpleStatementLine(
[
cst.ImportFrom(
module=cst.Name("pydantic"),
names=[cst.ImportAlias(cst.Name("BaseModel"))],
names=[
cst.ImportAlias(cst.Name("BaseModel")),
cst.ImportAlias(cst.Name("Field")),
],
)
]
),
@@ -964,15 +1065,20 @@ async def generate_workflow_script(
names=[
cst.ImportAlias(cst.Name("RunContext")),
cst.ImportAlias(cst.Name("SkyvernPage")),
cst.ImportAlias(cst.Name("generate_text")),
],
)
]
),
]
# --- generate schema and hydrate actions ---------------------------
generated_schema, field_mappings = await generate_workflow_parameters_schema(actions_by_task)
actions_by_task = hydrate_input_text_actions_with_field_names(actions_by_task, field_mappings)
# --- class + cached params -----------------------------------------
model_cls = _build_model(workflow)
cached_params_stmt = _build_cached_params(workflow_run_request.get("parameters", {}))
generated_model_cls = _build_generated_model_from_schema(generated_schema)
# --- blocks ---------------------------------------------------------
block_fns = []
@@ -1008,17 +1114,29 @@ async def generate_workflow_script(
# --- runner ---------------------------------------------------------
run_fn = _build_run_fn(blocks, workflow_run_request)
module = cst.Module(
body=[
*imports,
cst.EmptyLine(),
cst.EmptyLine(),
model_cls,
cst.EmptyLine(),
cst.EmptyLine(),
cached_params_stmt,
cst.EmptyLine(),
cst.EmptyLine(),
# Build module body with optional generated model class
module_body = [
*imports,
cst.EmptyLine(),
cst.EmptyLine(),
model_cls,
cst.EmptyLine(),
cst.EmptyLine(),
]
# Add generated model class if available
if generated_model_cls:
module_body.extend(
[
generated_model_cls,
cst.EmptyLine(),
cst.EmptyLine(),
]
)
# Continue with the rest of the module
module_body.extend(
[
*block_fns,
cst.EmptyLine(),
cst.EmptyLine(),
@@ -1029,6 +1147,8 @@ async def generate_workflow_script(
]
)
module = cst.Module(body=module_body)
with open(file_name, "w") as f:
f.write(module.code)
return module.code

View File

@@ -0,0 +1,193 @@
"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""
from typing import Any, Dict, List, Tuple
import structlog
from pydantic import BaseModel
from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType
LOG = structlog.get_logger(__name__)
# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")
class GeneratedFieldMapping(BaseModel):
    """LLM response model mapping input_text actions to generated field names.

    Populated via ``model_validate`` on the raw LLM response in
    ``_generate_field_names_with_llm``.
    """

    # Maps "action_index_<n>" keys (1-based, assigned in discovery order) to
    # the generated parameter field name for that action.
    field_mappings: Dict[str, str]
    # Maps each generated field name to its metadata dict (keys "type" and
    # "description"), consumed by _generate_pydantic_schema.
    schema_fields: Dict[str, Dict[str, str]]
async def generate_workflow_parameters_schema(
    actions_by_task: Dict[str, List[Dict[str, Any]]],
) -> Tuple[str, Dict[str, str]]:
    """
    Generate a GeneratedWorkflowParameters Pydantic schema based on input_text actions.

    Args:
        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries

    Returns:
        Tuple of (schema_code, field_mappings) where:
        - schema_code: Python source for the GeneratedWorkflowParameters class
        - field_mappings: Dictionary mapping "task_id:action_id" keys to field
          names, for hydrating actions later

    Never raises: any LLM or validation failure degrades to an empty schema
    with no field mappings.
    """
    # Extract all input_text actions. Remember which task/action each
    # positional index ("action_index_<n>", 1-based) refers to, so the LLM's
    # index-keyed answer can be mapped back to concrete actions.
    input_actions = []
    action_index_map = {}
    action_counter = 1
    for task_id, actions in actions_by_task.items():
        for action in actions:
            if action.get("action_type") == ActionType.INPUT_TEXT:
                input_actions.append(
                    {
                        "text": action.get("text", ""),
                        "intention": action.get("intention", ""),
                        "task_id": task_id,
                        "action_id": action.get("action_id", ""),
                    }
                )
                action_index_map[f"action_index_{action_counter}"] = {
                    "task_id": task_id,
                    "action_id": action.get("action_id", ""),
                }
                action_counter += 1

    if not input_actions:
        LOG.warning("No input_text actions found in workflow run")
        return _generate_empty_schema(), {}

    # Generate field names using LLM
    try:
        field_mapping = await _generate_field_names_with_llm(input_actions)

        # Generate the Pydantic schema code
        schema_code = _generate_pydantic_schema(field_mapping.schema_fields)

        # Create field mappings for action hydration, keyed "task_id:action_id".
        # Indices the LLM invented that we never handed out are silently dropped.
        action_field_mappings = {}
        for action_idx, field_name in field_mapping.field_mappings.items():
            if action_idx in action_index_map:
                action_info = action_index_map[action_idx]
                key = f"{action_info['task_id']}:{action_info['action_id']}"
                action_field_mappings[key] = field_name

        return schema_code, action_field_mappings
    except Exception as e:
        # Best-effort: schema generation must never break script generation.
        LOG.error("Failed to generate workflow parameters schema", error=str(e), exc_info=True)
        return _generate_empty_schema(), {}
async def _generate_field_names_with_llm(input_actions: List[Dict[str, Any]]) -> GeneratedFieldMapping:
    """
    Use LLM to generate field names from input actions.

    Args:
        input_actions: List of input_text action dictionaries (text, intention,
            task_id, action_id)

    Returns:
        GeneratedFieldMapping with field mappings and schema definitions

    Raises:
        pydantic.ValidationError: if the LLM response does not match the
            GeneratedFieldMapping shape; propagates to the caller, which
            falls back to an empty schema.
    """
    prompt = prompt_engine.load_prompt(template="generate-workflow-parameters", input_actions=input_actions)
    response = await app.LLM_API_HANDLER(prompt=prompt, prompt_name="generate-workflow-parameters")
    # assumes the handler returns a JSON-compatible dict — validated here.
    return GeneratedFieldMapping.model_validate(response)
def _generate_pydantic_schema(schema_fields: Dict[str, Dict[str, str]]) -> str:
"""
Generate Pydantic schema code from field definitions.
Args:
schema_fields: Dictionary of field names to their type and description
Returns:
Python code string for the GeneratedWorkflowParameters class
"""
if not schema_fields:
return _generate_empty_schema()
lines = [
"from pydantic import BaseModel, Field",
"",
"",
"class GeneratedWorkflowParameters(BaseModel):",
' """Generated schema representing all input_text action values from the workflow run."""',
"",
]
for field_name, field_info in schema_fields.items():
field_type = field_info.get("type", "str")
description = field_info.get("description", f"Value for {field_name}")
# Escape quotes in description
description = description.replace('"', '\\"')
lines.append(f' {field_name}: {field_type} = Field(description="{description}", default="")')
return "\n".join(lines)
def _generate_empty_schema() -> str:
    """Generate an empty schema when no input_text actions are found.

    Returns Python source for a GeneratedWorkflowParameters class with no
    fields, so downstream parsing of the schema code always has a valid
    class definition to extract.
    """
    return '''from pydantic import BaseModel


class GeneratedWorkflowParameters(BaseModel):
    """Generated schema representing all input_text action values from the workflow run."""

    pass
'''
def hydrate_input_text_actions_with_field_names(
    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Add field_name to input_text actions based on generated mappings.

    Args:
        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
        field_mappings: Dictionary mapping "task_id:action_id" to field names

    Returns:
        A new actions_by_task structure with field_name added to every
        input_text action; input dicts are shallow-copied, not mutated.
    """
    hydrated: Dict[str, List[Dict[str, Any]]] = {}
    for task_id, task_actions in actions_by_task.items():
        rewritten: List[Dict[str, Any]] = []
        for original in task_actions:
            action = original.copy()
            if original.get("action_type") == ActionType.INPUT_TEXT:
                lookup_key = f"{task_id}:{original.get('action_id', '')}"
                if lookup_key in field_mappings:
                    action["field_name"] = field_mappings[lookup_key]
                else:
                    # No LLM mapping: derive a snake_case fallback from the
                    # action's intention, keeping only alphanumerics and "_".
                    slug = original.get("intention", "").lower().replace(" ", "_").replace("?", "").replace("'", "")
                    action["field_name"] = "".join(ch for ch in slug if ch.isalnum() or ch == "_") or "unknown_field"
            rewritten.append(action)
        hydrated[task_id] = rewritten
    return hydrated

View File

@@ -1,11 +1,39 @@
from typing import Any
from pydantic import BaseModel
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage
async def setup(parameters: dict[str, Any], run_id: str | None = None) -> tuple[SkyvernPage, RunContext]:
async def setup(
parameters: dict[str, Any], generated_parameter_cls: type[BaseModel] | None = None
) -> tuple[SkyvernPage, RunContext]:
skyvern_page = await SkyvernPage.create()
run_context = RunContext(parameters=parameters, page=skyvern_page)
run_context = RunContext(
parameters=parameters,
page=skyvern_page,
# TODO: generate all parameters with llm here - then we can skip generating input text one by one in the fill/type methods
generated_parameters=generated_parameter_cls().model_dump() if generated_parameter_cls else None,
)
script_run_context_manager.set_run_context(run_context)
return skyvern_page, run_context
# async def transform_parameters(parameters: dict[str, Any] | BaseModel | None = None, generated_parameter_cls: type[BaseModel] | None = None) -> dict[str, Any] | None:
# if parameters is None:
# return None
# if generated_parameter_cls:
# if isinstance(parameters, dict):
# # TODO: use llm to generate
# return generated_parameter_cls.model_validate(parameters)
# if isinstance(parameters, BaseModel):
# return parameters
# return generated_parameter_cls.model_validate(parameters)
# else:
# if isinstance(parameters, dict):
# return parameters
# if isinstance(parameters, BaseModel):
# return parameters.model_dump()
# return parameters

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import copy
import json
from dataclasses import dataclass
from datetime import datetime, timezone
@@ -328,29 +329,8 @@ class SkyvernPage:
If the prompt generation or parsing fails for any reason we fall back to
inputting the originally supplied ``text``.
"""
new_text = text
if intention and data:
try:
# Build the element tree of the current page for the prompt
skyvern_context.ensure_context()
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
script_generation_input_text_prompt = prompt_engine.load_prompt(
template="script-generation-input-text-generatiion",
intention=intention,
data=payload_str,
)
json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER(
prompt=script_generation_input_text_prompt,
prompt_name="script-generation-input-text-generatiion",
)
new_text = json_response.get("answer", text) or text
except Exception:
# If anything goes wrong, fall back to the original text
new_text = text
locator = self.page.locator(f"xpath={xpath}")
await handler_utils.input_sequentially(locator, new_text, timeout=timeout)
await handler_utils.input_sequentially(locator, text, timeout=timeout)
@action_wrap(ActionType.UPLOAD_FILE)
async def upload_file(
@@ -420,8 +400,8 @@ class SkyvernPage:
@action_wrap(ActionType.EXTRACT)
async def extract(
self,
data_extraction_goal: str,
data_schema: dict[str, Any] | list | str | None = None,
prompt: str,
schema: dict[str, Any] | list | str | None = None,
error_code_mapping: dict[str, str] | None = None,
intention: str | None = None,
data: str | dict[str, Any] | None = None,
@@ -436,8 +416,8 @@ class SkyvernPage:
prompt_engine=prompt_engine,
template_name="extract-information",
html_need_skyvern_attrs=False,
data_extraction_goal=data_extraction_goal,
extracted_information_schema=data_schema,
data_extraction_goal=prompt,
extracted_information_schema=schema,
current_url=scraped_page_refreshed.url,
extracted_text=scraped_page_refreshed.extracted_text,
error_code_mapping_str=(json.dumps(error_code_mapping) if error_code_mapping else None),
@@ -509,8 +489,14 @@ class SkyvernPage:
class RunContext:
def __init__(self, parameters: dict[str, Any], page: SkyvernPage) -> None:
self.parameters = parameters
def __init__(
self, parameters: dict[str, Any], page: SkyvernPage, generated_parameters: dict[str, Any] | None = None
) -> None:
self.original_parameters = parameters
self.generated_parameters = generated_parameters
self.parameters = copy.deepcopy(parameters)
# if generated_parameters:
# self.parameters.update(generated_parameters)
self.page = page
self.trace: list[ActionCall] = []
self.prompt: str | None = None