generate GeneratedWorkflowParameters (#3264)
This commit is contained in:
@@ -33,6 +33,7 @@ from skyvern.services.script_service import ( # noqa: E402
|
||||
action, # noqa: E402
|
||||
download, # noqa: E402
|
||||
extract, # noqa: E402
|
||||
generate_text, # noqa: E402
|
||||
login, # noqa: E402
|
||||
run_script, # noqa: E402
|
||||
run_task, # noqa: E402
|
||||
@@ -48,6 +49,7 @@ __all__ = [
|
||||
"cached",
|
||||
"download",
|
||||
"extract",
|
||||
"generate_text",
|
||||
"login",
|
||||
"run_script",
|
||||
"run_task",
|
||||
|
||||
@@ -25,6 +25,10 @@ import structlog
|
||||
from libcst import Attribute, Call, Dict, DictElement, FunctionDef, Name, Param
|
||||
|
||||
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
|
||||
from skyvern.core.script_generations.generate_workflow_parameters import (
|
||||
generate_workflow_parameters_schema,
|
||||
hydrate_input_text_actions_with_field_names,
|
||||
)
|
||||
from skyvern.forge import app
|
||||
from skyvern.webeye.actions.action_types import ActionType
|
||||
|
||||
@@ -61,6 +65,7 @@ ACTIONS_WITH_XPATH = [
|
||||
]
|
||||
|
||||
INDENT = " " * 4
|
||||
DOUBLE_INDENT = " " * 8
|
||||
|
||||
|
||||
def _safe_name(label: str) -> str:
|
||||
@@ -97,6 +102,57 @@ def _value(value: Any) -> cst.BaseExpression:
|
||||
return cst.SimpleString(repr(str(value)))
|
||||
|
||||
|
||||
def _generate_text_call(text_value: str, intention: str, parameter_key: str) -> cst.BaseExpression:
|
||||
"""Create a generate_text function call CST expression."""
|
||||
return cst.Await(
|
||||
expression=cst.Call(
|
||||
func=cst.Name("generate_text"),
|
||||
whitespace_before_args=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
|
||||
),
|
||||
args=[
|
||||
# First positional argument: context.generated_parameters['parameter_key']
|
||||
cst.Arg(
|
||||
value=cst.Subscript(
|
||||
value=cst.Attribute(
|
||||
value=cst.Name("context"),
|
||||
attr=cst.Name("generated_parameters"),
|
||||
),
|
||||
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(parameter_key)))],
|
||||
),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
|
||||
),
|
||||
),
|
||||
# intention keyword argument
|
||||
cst.Arg(
|
||||
keyword=cst.Name("intention"),
|
||||
value=_value(intention),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
|
||||
),
|
||||
),
|
||||
# data keyword argument
|
||||
cst.Arg(
|
||||
keyword=cst.Name("data"),
|
||||
value=cst.Attribute(
|
||||
value=cst.Name("context"),
|
||||
attr=cst.Name("parameters"),
|
||||
),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
comma=cst.Comma(),
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# 2. utility builders #
|
||||
# --------------------------------------------------------------------- #
|
||||
@@ -177,10 +233,21 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
|
||||
)
|
||||
|
||||
if method in ["type", "fill"]:
|
||||
# Get intention from action
|
||||
intention = act.get("intention") or act.get("reasoning") or ""
|
||||
|
||||
# Use generate_text call if field_name is available, otherwise fallback to direct value
|
||||
if act.get("field_name"):
|
||||
text_value = _generate_text_call(
|
||||
text_value=act["text"], intention=intention, parameter_key=act["field_name"]
|
||||
)
|
||||
else:
|
||||
text_value = _value(act["text"])
|
||||
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("text"),
|
||||
value=_value(act["text"]),
|
||||
value=text_value,
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
@@ -212,7 +279,7 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
|
||||
elif method == "extract":
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("data_extraction_goal"),
|
||||
keyword=cst.Name("prompt"),
|
||||
value=_value(act["data_extraction_goal"]),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
@@ -309,8 +376,8 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
|
||||
def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
|
||||
"""
|
||||
class WorkflowParameters(BaseModel):
|
||||
ein_info: str
|
||||
company_name: str
|
||||
param1: str
|
||||
param2: str
|
||||
...
|
||||
"""
|
||||
ann_lines: list[cst.BaseStatement] = []
|
||||
@@ -319,7 +386,6 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
|
||||
if p["parameter_type"] != "workflow":
|
||||
continue
|
||||
|
||||
# ein_info: str
|
||||
ann = cst.AnnAssign(
|
||||
target=cst.Name(p["key"]),
|
||||
annotation=cst.Annotation(cst.Name("str")),
|
||||
@@ -337,21 +403,24 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
|
||||
)
|
||||
|
||||
|
||||
def _build_cached_params(values: dict[str, Any]) -> cst.SimpleStatementLine:
|
||||
def _build_generated_model_from_schema(schema_code: str) -> cst.ClassDef | None:
|
||||
"""
|
||||
Make a CST for:
|
||||
cached_parameters = WorkflowParameters(ein_info="...", ...)
|
||||
Parse the generated schema code and return a ClassDef, or None if parsing fails.
|
||||
"""
|
||||
call = cst.Call(
|
||||
func=cst.Name("WorkflowParameters"),
|
||||
args=[cst.Arg(keyword=cst.Name(k), value=_value(v)) for k, v in values.items()],
|
||||
)
|
||||
try:
|
||||
# Parse the schema code and extract just the class definition
|
||||
parsed_module = cst.parse_module(schema_code)
|
||||
|
||||
assign = cst.Assign(
|
||||
targets=[cst.AssignTarget(cst.Name("cached_parameters"))],
|
||||
value=call,
|
||||
)
|
||||
return cst.SimpleStatementLine([assign])
|
||||
# Find the GeneratedWorkflowParameters class in the parsed module
|
||||
for node in parsed_module.body:
|
||||
if isinstance(node, cst.ClassDef) and node.name.value == "GeneratedWorkflowParameters":
|
||||
return node
|
||||
|
||||
# If no class found, return None
|
||||
return None
|
||||
except Exception as e:
|
||||
LOG.warning("Failed to parse generated schema code", error=str(e))
|
||||
return None
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
@@ -804,7 +873,7 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
|
||||
cst.parse_statement(
|
||||
"parameters = parameters.model_dump() if isinstance(parameters, WorkflowParameters) else parameters"
|
||||
),
|
||||
cst.parse_statement("page, context = await skyvern.setup(parameters)"),
|
||||
cst.parse_statement("page, context = await skyvern.setup(parameters, GeneratedWorkflowParameters)"),
|
||||
]
|
||||
|
||||
for block in blocks:
|
||||
@@ -867,8 +936,27 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
|
||||
params=[
|
||||
Param(
|
||||
name=cst.Name("parameters"),
|
||||
annotation=cst.Annotation(cst.Name("WorkflowParameters")),
|
||||
default=cst.Name("cached_parameters"),
|
||||
annotation=cst.Annotation(
|
||||
cst.BinaryOperation(
|
||||
left=cst.Name("WorkflowParameters"),
|
||||
operator=cst.BitOr(
|
||||
whitespace_before=cst.SimpleWhitespace(" "),
|
||||
whitespace_after=cst.SimpleWhitespace(" "),
|
||||
),
|
||||
right=cst.Subscript(
|
||||
value=cst.Name("dict"),
|
||||
slice=[
|
||||
cst.SubscriptElement(
|
||||
slice=cst.Index(value=cst.Name("str")),
|
||||
comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
|
||||
),
|
||||
cst.SubscriptElement(
|
||||
slice=cst.Index(value=cst.Name("Any")),
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
),
|
||||
whitespace_after_param=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
@@ -948,11 +1036,24 @@ async def generate_workflow_script(
|
||||
imports: list[cst.BaseStatement] = [
|
||||
cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("asyncio"))])]),
|
||||
cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("pydantic"))])]),
|
||||
cst.SimpleStatementLine(
|
||||
[
|
||||
cst.ImportFrom(
|
||||
module=cst.Name("typing"),
|
||||
names=[
|
||||
cst.ImportAlias(cst.Name("Any")),
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
cst.SimpleStatementLine(
|
||||
[
|
||||
cst.ImportFrom(
|
||||
module=cst.Name("pydantic"),
|
||||
names=[cst.ImportAlias(cst.Name("BaseModel"))],
|
||||
names=[
|
||||
cst.ImportAlias(cst.Name("BaseModel")),
|
||||
cst.ImportAlias(cst.Name("Field")),
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
@@ -964,15 +1065,20 @@ async def generate_workflow_script(
|
||||
names=[
|
||||
cst.ImportAlias(cst.Name("RunContext")),
|
||||
cst.ImportAlias(cst.Name("SkyvernPage")),
|
||||
cst.ImportAlias(cst.Name("generate_text")),
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
# --- generate schema and hydrate actions ---------------------------
|
||||
generated_schema, field_mappings = await generate_workflow_parameters_schema(actions_by_task)
|
||||
actions_by_task = hydrate_input_text_actions_with_field_names(actions_by_task, field_mappings)
|
||||
|
||||
# --- class + cached params -----------------------------------------
|
||||
model_cls = _build_model(workflow)
|
||||
cached_params_stmt = _build_cached_params(workflow_run_request.get("parameters", {}))
|
||||
generated_model_cls = _build_generated_model_from_schema(generated_schema)
|
||||
|
||||
# --- blocks ---------------------------------------------------------
|
||||
block_fns = []
|
||||
@@ -1008,17 +1114,29 @@ async def generate_workflow_script(
|
||||
# --- runner ---------------------------------------------------------
|
||||
run_fn = _build_run_fn(blocks, workflow_run_request)
|
||||
|
||||
module = cst.Module(
|
||||
body=[
|
||||
*imports,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
model_cls,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
cached_params_stmt,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
# Build module body with optional generated model class
|
||||
module_body = [
|
||||
*imports,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
model_cls,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
]
|
||||
|
||||
# Add generated model class if available
|
||||
if generated_model_cls:
|
||||
module_body.extend(
|
||||
[
|
||||
generated_model_cls,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
]
|
||||
)
|
||||
|
||||
# Continue with the rest of the module
|
||||
module_body.extend(
|
||||
[
|
||||
*block_fns,
|
||||
cst.EmptyLine(),
|
||||
cst.EmptyLine(),
|
||||
@@ -1029,6 +1147,8 @@ async def generate_workflow_script(
|
||||
]
|
||||
)
|
||||
|
||||
module = cst.Module(body=module_body)
|
||||
|
||||
with open(file_name, "w") as f:
|
||||
f.write(module.code)
|
||||
return module.code
|
||||
|
||||
193
skyvern/core/script_generations/generate_workflow_parameters.py
Normal file
193
skyvern/core/script_generations/generate_workflow_parameters.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import structlog
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.prompting import PromptEngine
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
# Initialize prompt engine
|
||||
prompt_engine = PromptEngine("skyvern")
|
||||
|
||||
|
||||
class GeneratedFieldMapping(BaseModel):
|
||||
"""Mapping of action indices to field names."""
|
||||
|
||||
field_mappings: Dict[str, str]
|
||||
schema_fields: Dict[str, Dict[str, str]]
|
||||
|
||||
|
||||
async def generate_workflow_parameters_schema(
|
||||
actions_by_task: Dict[str, List[Dict[str, Any]]],
|
||||
) -> Tuple[str, Dict[str, str]]:
|
||||
"""
|
||||
Generate a GeneratedWorkflowParameters Pydantic schema based on input_text actions.
|
||||
|
||||
Args:
|
||||
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
|
||||
|
||||
Returns:
|
||||
Tuple of (schema_code, field_mappings) where:
|
||||
- schema_code: Python code for the GeneratedWorkflowParameters class
|
||||
- field_mappings: Dictionary mapping action indices to field names for hydration
|
||||
"""
|
||||
# Extract all input_text actions
|
||||
input_actions = []
|
||||
action_index_map = {}
|
||||
action_counter = 1
|
||||
|
||||
for task_id, actions in actions_by_task.items():
|
||||
for action in actions:
|
||||
if action.get("action_type") == ActionType.INPUT_TEXT:
|
||||
input_actions.append(
|
||||
{
|
||||
"text": action.get("text", ""),
|
||||
"intention": action.get("intention", ""),
|
||||
"task_id": task_id,
|
||||
"action_id": action.get("action_id", ""),
|
||||
}
|
||||
)
|
||||
action_index_map[f"action_index_{action_counter}"] = {
|
||||
"task_id": task_id,
|
||||
"action_id": action.get("action_id", ""),
|
||||
}
|
||||
action_counter += 1
|
||||
|
||||
if not input_actions:
|
||||
LOG.warning("No input_text actions found in workflow run")
|
||||
return _generate_empty_schema(), {}
|
||||
|
||||
# Generate field names using LLM
|
||||
try:
|
||||
field_mapping = await _generate_field_names_with_llm(input_actions)
|
||||
|
||||
# Generate the Pydantic schema code
|
||||
schema_code = _generate_pydantic_schema(field_mapping.schema_fields)
|
||||
|
||||
# Create field mappings for action hydration
|
||||
action_field_mappings = {}
|
||||
for action_idx, field_name in field_mapping.field_mappings.items():
|
||||
if action_idx in action_index_map:
|
||||
action_info = action_index_map[action_idx]
|
||||
key = f"{action_info['task_id']}:{action_info['action_id']}"
|
||||
action_field_mappings[key] = field_name
|
||||
|
||||
return schema_code, action_field_mappings
|
||||
|
||||
except Exception as e:
|
||||
LOG.error("Failed to generate workflow parameters schema", error=str(e), exc_info=True)
|
||||
return _generate_empty_schema(), {}
|
||||
|
||||
|
||||
async def _generate_field_names_with_llm(input_actions: List[Dict[str, Any]]) -> GeneratedFieldMapping:
|
||||
"""
|
||||
Use LLM to generate field names from input actions.
|
||||
|
||||
Args:
|
||||
input_actions: List of input_text action dictionaries
|
||||
|
||||
Returns:
|
||||
GeneratedFieldMapping with field mappings and schema definitions
|
||||
"""
|
||||
prompt = prompt_engine.load_prompt(template="generate-workflow-parameters", input_actions=input_actions)
|
||||
|
||||
response = await app.LLM_API_HANDLER(prompt=prompt, prompt_name="generate-workflow-parameters")
|
||||
|
||||
return GeneratedFieldMapping.model_validate(response)
|
||||
|
||||
|
||||
def _generate_pydantic_schema(schema_fields: Dict[str, Dict[str, str]]) -> str:
|
||||
"""
|
||||
Generate Pydantic schema code from field definitions.
|
||||
|
||||
Args:
|
||||
schema_fields: Dictionary of field names to their type and description
|
||||
|
||||
Returns:
|
||||
Python code string for the GeneratedWorkflowParameters class
|
||||
"""
|
||||
if not schema_fields:
|
||||
return _generate_empty_schema()
|
||||
|
||||
lines = [
|
||||
"from pydantic import BaseModel, Field",
|
||||
"",
|
||||
"",
|
||||
"class GeneratedWorkflowParameters(BaseModel):",
|
||||
' """Generated schema representing all input_text action values from the workflow run."""',
|
||||
"",
|
||||
]
|
||||
|
||||
for field_name, field_info in schema_fields.items():
|
||||
field_type = field_info.get("type", "str")
|
||||
description = field_info.get("description", f"Value for {field_name}")
|
||||
|
||||
# Escape quotes in description
|
||||
description = description.replace('"', '\\"')
|
||||
|
||||
lines.append(f' {field_name}: {field_type} = Field(description="{description}", default="")')
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _generate_empty_schema() -> str:
|
||||
"""Generate an empty schema when no input_text actions are found."""
|
||||
return '''from pydantic import BaseModel
|
||||
|
||||
|
||||
class GeneratedWorkflowParameters(BaseModel):
|
||||
"""Generated schema representing all input_text action values from the workflow run."""
|
||||
pass
|
||||
'''
|
||||
|
||||
|
||||
def hydrate_input_text_actions_with_field_names(
|
||||
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Add field_name to input_text actions based on generated mappings.
|
||||
|
||||
Args:
|
||||
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
|
||||
field_mappings: Dictionary mapping "task_id:action_id" to field names
|
||||
|
||||
Returns:
|
||||
Updated actions_by_task with field_name added to input_text actions
|
||||
"""
|
||||
updated_actions_by_task = {}
|
||||
|
||||
for task_id, actions in actions_by_task.items():
|
||||
updated_actions = []
|
||||
|
||||
for action in actions:
|
||||
action_copy = action.copy()
|
||||
|
||||
if action.get("action_type") == ActionType.INPUT_TEXT:
|
||||
action_id = action.get("action_id", "")
|
||||
mapping_key = f"{task_id}:{action_id}"
|
||||
|
||||
if mapping_key in field_mappings:
|
||||
action_copy["field_name"] = field_mappings[mapping_key]
|
||||
else:
|
||||
# Fallback field name if mapping not found
|
||||
intention = action.get("intention", "")
|
||||
if intention:
|
||||
# Simple field name generation from intention
|
||||
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
|
||||
field_name = "".join(c for c in field_name if c.isalnum() or c == "_")
|
||||
action_copy["field_name"] = field_name or "unknown_field"
|
||||
else:
|
||||
action_copy["field_name"] = "unknown_field"
|
||||
|
||||
updated_actions.append(action_copy)
|
||||
|
||||
updated_actions_by_task[task_id] = updated_actions
|
||||
|
||||
return updated_actions_by_task
|
||||
@@ -1,11 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
|
||||
from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage
|
||||
|
||||
|
||||
async def setup(parameters: dict[str, Any], run_id: str | None = None) -> tuple[SkyvernPage, RunContext]:
|
||||
async def setup(
|
||||
parameters: dict[str, Any], generated_parameter_cls: type[BaseModel] | None = None
|
||||
) -> tuple[SkyvernPage, RunContext]:
|
||||
skyvern_page = await SkyvernPage.create()
|
||||
run_context = RunContext(parameters=parameters, page=skyvern_page)
|
||||
run_context = RunContext(
|
||||
parameters=parameters,
|
||||
page=skyvern_page,
|
||||
# TODO: generate all parameters with llm here - then we can skip generating input text one by one in the fill/type methods
|
||||
generated_parameters=generated_parameter_cls().model_dump() if generated_parameter_cls else None,
|
||||
)
|
||||
script_run_context_manager.set_run_context(run_context)
|
||||
return skyvern_page, run_context
|
||||
|
||||
|
||||
# async def transform_parameters(parameters: dict[str, Any] | BaseModel | None = None, generated_parameter_cls: type[BaseModel] | None = None) -> dict[str, Any] | None:
|
||||
# if parameters is None:
|
||||
# return None
|
||||
|
||||
# if generated_parameter_cls:
|
||||
# if isinstance(parameters, dict):
|
||||
# # TODO: use llm to generate
|
||||
# return generated_parameter_cls.model_validate(parameters)
|
||||
# if isinstance(parameters, BaseModel):
|
||||
# return parameters
|
||||
# return generated_parameter_cls.model_validate(parameters)
|
||||
# else:
|
||||
# if isinstance(parameters, dict):
|
||||
# return parameters
|
||||
# if isinstance(parameters, BaseModel):
|
||||
# return parameters.model_dump()
|
||||
# return parameters
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
@@ -328,29 +329,8 @@ class SkyvernPage:
|
||||
If the prompt generation or parsing fails for any reason we fall back to
|
||||
inputting the originally supplied ``text``.
|
||||
"""
|
||||
new_text = text
|
||||
|
||||
if intention and data:
|
||||
try:
|
||||
# Build the element tree of the current page for the prompt
|
||||
skyvern_context.ensure_context()
|
||||
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
|
||||
script_generation_input_text_prompt = prompt_engine.load_prompt(
|
||||
template="script-generation-input-text-generatiion",
|
||||
intention=intention,
|
||||
data=payload_str,
|
||||
)
|
||||
json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER(
|
||||
prompt=script_generation_input_text_prompt,
|
||||
prompt_name="script-generation-input-text-generatiion",
|
||||
)
|
||||
new_text = json_response.get("answer", text) or text
|
||||
except Exception:
|
||||
# If anything goes wrong, fall back to the original text
|
||||
new_text = text
|
||||
|
||||
locator = self.page.locator(f"xpath={xpath}")
|
||||
await handler_utils.input_sequentially(locator, new_text, timeout=timeout)
|
||||
await handler_utils.input_sequentially(locator, text, timeout=timeout)
|
||||
|
||||
@action_wrap(ActionType.UPLOAD_FILE)
|
||||
async def upload_file(
|
||||
@@ -420,8 +400,8 @@ class SkyvernPage:
|
||||
@action_wrap(ActionType.EXTRACT)
|
||||
async def extract(
|
||||
self,
|
||||
data_extraction_goal: str,
|
||||
data_schema: dict[str, Any] | list | str | None = None,
|
||||
prompt: str,
|
||||
schema: dict[str, Any] | list | str | None = None,
|
||||
error_code_mapping: dict[str, str] | None = None,
|
||||
intention: str | None = None,
|
||||
data: str | dict[str, Any] | None = None,
|
||||
@@ -436,8 +416,8 @@ class SkyvernPage:
|
||||
prompt_engine=prompt_engine,
|
||||
template_name="extract-information",
|
||||
html_need_skyvern_attrs=False,
|
||||
data_extraction_goal=data_extraction_goal,
|
||||
extracted_information_schema=data_schema,
|
||||
data_extraction_goal=prompt,
|
||||
extracted_information_schema=schema,
|
||||
current_url=scraped_page_refreshed.url,
|
||||
extracted_text=scraped_page_refreshed.extracted_text,
|
||||
error_code_mapping_str=(json.dumps(error_code_mapping) if error_code_mapping else None),
|
||||
@@ -509,8 +489,14 @@ class SkyvernPage:
|
||||
|
||||
|
||||
class RunContext:
|
||||
def __init__(self, parameters: dict[str, Any], page: SkyvernPage) -> None:
|
||||
self.parameters = parameters
|
||||
def __init__(
|
||||
self, parameters: dict[str, Any], page: SkyvernPage, generated_parameters: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
self.original_parameters = parameters
|
||||
self.generated_parameters = generated_parameters
|
||||
self.parameters = copy.deepcopy(parameters)
|
||||
# if generated_parameters:
|
||||
# self.parameters.update(generated_parameters)
|
||||
self.page = page
|
||||
self.trace: list[ActionCall] = []
|
||||
self.prompt: str | None = None
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
You are an expert at analyzing user interface automation actions and generating meaningful field names for data structures.
|
||||
|
||||
Given a list of input_text actions with their intentions and text values, generate appropriate field names for a Pydantic BaseModel class called "GeneratedWorkflowParameters".
|
||||
|
||||
## Rules:
|
||||
1. Field names should be valid Python identifiers (snake_case, no spaces, no special characters except underscore)
|
||||
2. Field names should be descriptive and based on the intention of the action
|
||||
3. If multiple actions input the same text value, they should map to the same field name
|
||||
4. Field names should be concise but clear about what data they represent
|
||||
5. Avoid generic names like "field1", "input1" - use meaningful names based on the intention
|
||||
|
||||
## Input Actions:
|
||||
{% for action in input_actions %}
|
||||
Action {{ loop.index }}:
|
||||
- Text: "{{ action.text }}"
|
||||
- Intention: "{{ action.intention }}"
|
||||
{% endfor %}
|
||||
|
||||
## Expected Output:
|
||||
Return a JSON object with the following structure:
|
||||
```json
|
||||
{
|
||||
"field_mappings": {
|
||||
"action_index_1": "field_name_1",
|
||||
"action_index_2": "field_name_2",
|
||||
...
|
||||
},
|
||||
"schema_fields": {
|
||||
"field_name_1": {
|
||||
"type": "str",
|
||||
"description": "Description of what this field represents"
|
||||
},
|
||||
"field_name_2": {
|
||||
"type": "str",
|
||||
"description": "Description of what this field represents"
|
||||
},
|
||||
...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Where:
|
||||
- `field_mappings` maps each action index (1-based) to its corresponding field name
|
||||
- `schema_fields` defines each unique field with its type and description
|
||||
- Actions with the same text value should map to the same field name
|
||||
|
||||
Generate the field names now:
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, cast
|
||||
@@ -14,6 +15,7 @@ from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
|
||||
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
|
||||
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.schemas.files import FileInfo
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskOutput, TaskStatus
|
||||
@@ -417,6 +419,9 @@ async def run_task(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
)
|
||||
# set the prompt in the RunContext
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
run_context.prompt = prompt
|
||||
|
||||
if cache_key:
|
||||
try:
|
||||
@@ -427,6 +432,7 @@ async def run_task(
|
||||
await _update_workflow_block(workflow_run_block_id, BlockStatus.completed, task_id=task_id)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: fallback to AI run in case of error
|
||||
# Update block status to failed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -437,6 +443,9 @@ async def run_task(
|
||||
failure_reason=str(e),
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# clear the prompt in the RunContext
|
||||
run_context.prompt = None
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -446,6 +455,7 @@ async def run_task(
|
||||
task_status=TaskStatus.failed,
|
||||
failure_reason="Cache key is required",
|
||||
)
|
||||
run_context.prompt = None
|
||||
raise Exception("Cache key is required to run task block in a script")
|
||||
|
||||
|
||||
@@ -461,6 +471,9 @@ async def download(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
)
|
||||
# set the prompt in the RunContext
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
run_context.prompt = prompt
|
||||
|
||||
if cache_key:
|
||||
try:
|
||||
@@ -481,6 +494,8 @@ async def download(
|
||||
failure_reason=str(e),
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
run_context.prompt = None
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -490,6 +505,7 @@ async def download(
|
||||
task_status=TaskStatus.failed,
|
||||
failure_reason="Cache key is required",
|
||||
)
|
||||
run_context.prompt = None
|
||||
raise Exception("Cache key is required to run task block in a script")
|
||||
|
||||
|
||||
@@ -505,6 +521,9 @@ async def action(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
)
|
||||
# set the prompt in the RunContext
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
run_context.prompt = prompt
|
||||
|
||||
if cache_key:
|
||||
try:
|
||||
@@ -525,6 +544,8 @@ async def action(
|
||||
failure_reason=str(e),
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
run_context.prompt = None
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -534,6 +555,7 @@ async def action(
|
||||
task_status=TaskStatus.failed,
|
||||
failure_reason="Cache key is required",
|
||||
)
|
||||
run_context.prompt = None
|
||||
raise Exception("Cache key is required to run task block in a script")
|
||||
|
||||
|
||||
@@ -549,6 +571,9 @@ async def login(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
)
|
||||
# set the prompt in the RunContext
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
run_context.prompt = prompt
|
||||
|
||||
if cache_key:
|
||||
try:
|
||||
@@ -569,6 +594,8 @@ async def login(
|
||||
failure_reason=str(e),
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
run_context.prompt = None
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -578,6 +605,7 @@ async def login(
|
||||
task_status=TaskStatus.failed,
|
||||
failure_reason="Cache key is required",
|
||||
)
|
||||
run_context.prompt = None
|
||||
raise Exception("Cache key is required to run task block in a script")
|
||||
|
||||
|
||||
@@ -593,6 +621,9 @@ async def extract(
|
||||
prompt=prompt,
|
||||
url=url,
|
||||
)
|
||||
# set the prompt in the RunContext
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
run_context.prompt = prompt
|
||||
output: dict[str, Any] | list | str | None = None
|
||||
|
||||
if cache_key:
|
||||
@@ -608,7 +639,6 @@ async def extract(
|
||||
output=output,
|
||||
)
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
# Update block status to failed if workflow block was created
|
||||
if workflow_run_block_id:
|
||||
@@ -621,6 +651,8 @@ async def extract(
|
||||
output=output,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
run_context.prompt = None
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -630,6 +662,7 @@ async def extract(
|
||||
task_status=TaskStatus.failed,
|
||||
failure_reason="Cache key is required",
|
||||
)
|
||||
run_context.prompt = None
|
||||
raise Exception("Cache key is required to run task block in a script")
|
||||
|
||||
|
||||
@@ -688,3 +721,34 @@ async def run_script(
|
||||
await user_script.run_workflow()
|
||||
else:
|
||||
raise Exception(f"No 'run_workflow' function found in {path}")
|
||||
|
||||
|
||||
async def generate_text(
|
||||
text: str | None = None,
|
||||
intention: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
if text:
|
||||
return text
|
||||
new_text = text or ""
|
||||
if intention and data:
|
||||
try:
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
prompt = run_context.prompt
|
||||
# Build the element tree of the current page for the prompt
|
||||
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
|
||||
script_generation_input_text_prompt = prompt_engine.load_prompt(
|
||||
template="script-generation-input-text-generatiion",
|
||||
intention=intention,
|
||||
data=payload_str,
|
||||
goal=prompt,
|
||||
)
|
||||
json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER(
|
||||
prompt=script_generation_input_text_prompt,
|
||||
prompt_name="script-generation-input-text-generatiion",
|
||||
)
|
||||
new_text = json_response.get("answer", new_text)
|
||||
except Exception:
|
||||
# If anything goes wrong, fall back to the original text
|
||||
pass
|
||||
return new_text
|
||||
|
||||
Reference in New Issue
Block a user