generate GeneratedWorkflowParameters (#3264)

This commit is contained in:
Shuchang Zheng
2025-08-21 15:42:34 -07:00
committed by GitHub
parent 988416829f
commit 2a62dc08aa
7 changed files with 504 additions and 64 deletions

View File

@@ -33,6 +33,7 @@ from skyvern.services.script_service import ( # noqa: E402
action, # noqa: E402
download, # noqa: E402
extract, # noqa: E402
generate_text, # noqa: E402
login, # noqa: E402
run_script, # noqa: E402
run_task, # noqa: E402
@@ -48,6 +49,7 @@ __all__ = [
"cached",
"download",
"extract",
"generate_text",
"login",
"run_script",
"run_task",

View File

@@ -25,6 +25,10 @@ import structlog
from libcst import Attribute, Call, Dict, DictElement, FunctionDef, Name, Param
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
from skyvern.core.script_generations.generate_workflow_parameters import (
generate_workflow_parameters_schema,
hydrate_input_text_actions_with_field_names,
)
from skyvern.forge import app
from skyvern.webeye.actions.action_types import ActionType
@@ -61,6 +65,7 @@ ACTIONS_WITH_XPATH = [
]
INDENT = " " * 4
DOUBLE_INDENT = " " * 8
def _safe_name(label: str) -> str:
@@ -97,6 +102,57 @@ def _value(value: Any) -> cst.BaseExpression:
return cst.SimpleString(repr(str(value)))
def _generate_text_call(text_value: str, intention: str, parameter_key: str) -> cst.BaseExpression:
    """Create a generate_text function call CST expression.

    Emits (as CST) the expression::

        await generate_text(
            context.generated_parameters['<parameter_key>'],
            intention='<intention>',
            data=context.parameters,
        )

    NOTE(review): ``text_value`` is accepted but never used in the emitted
    call — the runtime value comes from ``context.generated_parameters``
    instead. Confirm with the call site whether the parameter can be dropped.
    """
    return cst.Await(
        expression=cst.Call(
            func=cst.Name("generate_text"),
            # Break the argument list onto its own lines, indented two levels.
            whitespace_before_args=cst.ParenthesizedWhitespace(
                indent=True,
                last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
            ),
            args=[
                # First positional argument: context.generated_parameters['parameter_key']
                cst.Arg(
                    value=cst.Subscript(
                        value=cst.Attribute(
                            value=cst.Name("context"),
                            attr=cst.Name("generated_parameters"),
                        ),
                        slice=[cst.SubscriptElement(slice=cst.Index(value=_value(parameter_key)))],
                    ),
                    whitespace_after_arg=cst.ParenthesizedWhitespace(
                        indent=True,
                        last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
                    ),
                ),
                # intention keyword argument
                cst.Arg(
                    keyword=cst.Name("intention"),
                    value=_value(intention),
                    whitespace_after_arg=cst.ParenthesizedWhitespace(
                        indent=True,
                        last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
                    ),
                ),
                # data keyword argument
                cst.Arg(
                    keyword=cst.Name("data"),
                    value=cst.Attribute(
                        value=cst.Name("context"),
                        attr=cst.Name("parameters"),
                    ),
                    # Last argument dedents one level so the closing paren lines up.
                    whitespace_after_arg=cst.ParenthesizedWhitespace(
                        indent=True,
                        last_line=cst.SimpleWhitespace(INDENT),
                    ),
                    comma=cst.Comma(),
                ),
            ],
        )
    )
# --------------------------------------------------------------------- #
# 2. utility builders #
# --------------------------------------------------------------------- #
@@ -177,10 +233,21 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
)
if method in ["type", "fill"]:
# Get intention from action
intention = act.get("intention") or act.get("reasoning") or ""
# Use generate_text call if field_name is available, otherwise fallback to direct value
if act.get("field_name"):
text_value = _generate_text_call(
text_value=act["text"], intention=intention, parameter_key=act["field_name"]
)
else:
text_value = _value(act["text"])
args.append(
cst.Arg(
keyword=cst.Name("text"),
value=_value(act["text"]),
value=text_value,
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
@@ -212,7 +279,7 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
elif method == "extract":
args.append(
cst.Arg(
keyword=cst.Name("data_extraction_goal"),
keyword=cst.Name("prompt"),
value=_value(act["data_extraction_goal"]),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
@@ -309,8 +376,8 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun
def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
"""
class WorkflowParameters(BaseModel):
ein_info: str
company_name: str
param1: str
param2: str
...
"""
ann_lines: list[cst.BaseStatement] = []
@@ -319,7 +386,6 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
if p["parameter_type"] != "workflow":
continue
# ein_info: str
ann = cst.AnnAssign(
target=cst.Name(p["key"]),
annotation=cst.Annotation(cst.Name("str")),
@@ -337,21 +403,24 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef:
)
def _build_cached_params(values: dict[str, Any]) -> cst.SimpleStatementLine:
def _build_generated_model_from_schema(schema_code: str) -> cst.ClassDef | None:
"""
Make a CST for:
cached_parameters = WorkflowParameters(ein_info="...", ...)
Parse the generated schema code and return a ClassDef, or None if parsing fails.
"""
call = cst.Call(
func=cst.Name("WorkflowParameters"),
args=[cst.Arg(keyword=cst.Name(k), value=_value(v)) for k, v in values.items()],
)
try:
# Parse the schema code and extract just the class definition
parsed_module = cst.parse_module(schema_code)
assign = cst.Assign(
targets=[cst.AssignTarget(cst.Name("cached_parameters"))],
value=call,
)
return cst.SimpleStatementLine([assign])
# Find the GeneratedWorkflowParameters class in the parsed module
for node in parsed_module.body:
if isinstance(node, cst.ClassDef) and node.name.value == "GeneratedWorkflowParameters":
return node
# If no class found, return None
return None
except Exception as e:
LOG.warning("Failed to parse generated schema code", error=str(e))
return None
# --------------------------------------------------------------------- #
@@ -804,7 +873,7 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
cst.parse_statement(
"parameters = parameters.model_dump() if isinstance(parameters, WorkflowParameters) else parameters"
),
cst.parse_statement("page, context = await skyvern.setup(parameters)"),
cst.parse_statement("page, context = await skyvern.setup(parameters, GeneratedWorkflowParameters)"),
]
for block in blocks:
@@ -867,8 +936,27 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct
params=[
Param(
name=cst.Name("parameters"),
annotation=cst.Annotation(cst.Name("WorkflowParameters")),
default=cst.Name("cached_parameters"),
annotation=cst.Annotation(
cst.BinaryOperation(
left=cst.Name("WorkflowParameters"),
operator=cst.BitOr(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.SimpleWhitespace(" "),
),
right=cst.Subscript(
value=cst.Name("dict"),
slice=[
cst.SubscriptElement(
slice=cst.Index(value=cst.Name("str")),
comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
),
cst.SubscriptElement(
slice=cst.Index(value=cst.Name("Any")),
),
],
),
)
),
whitespace_after_param=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
@@ -948,11 +1036,24 @@ async def generate_workflow_script(
imports: list[cst.BaseStatement] = [
cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("asyncio"))])]),
cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("pydantic"))])]),
cst.SimpleStatementLine(
[
cst.ImportFrom(
module=cst.Name("typing"),
names=[
cst.ImportAlias(cst.Name("Any")),
],
)
]
),
cst.SimpleStatementLine(
[
cst.ImportFrom(
module=cst.Name("pydantic"),
names=[cst.ImportAlias(cst.Name("BaseModel"))],
names=[
cst.ImportAlias(cst.Name("BaseModel")),
cst.ImportAlias(cst.Name("Field")),
],
)
]
),
@@ -964,15 +1065,20 @@ async def generate_workflow_script(
names=[
cst.ImportAlias(cst.Name("RunContext")),
cst.ImportAlias(cst.Name("SkyvernPage")),
cst.ImportAlias(cst.Name("generate_text")),
],
)
]
),
]
# --- generate schema and hydrate actions ---------------------------
generated_schema, field_mappings = await generate_workflow_parameters_schema(actions_by_task)
actions_by_task = hydrate_input_text_actions_with_field_names(actions_by_task, field_mappings)
# --- class + cached params -----------------------------------------
model_cls = _build_model(workflow)
cached_params_stmt = _build_cached_params(workflow_run_request.get("parameters", {}))
generated_model_cls = _build_generated_model_from_schema(generated_schema)
# --- blocks ---------------------------------------------------------
block_fns = []
@@ -1008,17 +1114,29 @@ async def generate_workflow_script(
# --- runner ---------------------------------------------------------
run_fn = _build_run_fn(blocks, workflow_run_request)
module = cst.Module(
body=[
*imports,
cst.EmptyLine(),
cst.EmptyLine(),
model_cls,
cst.EmptyLine(),
cst.EmptyLine(),
cached_params_stmt,
cst.EmptyLine(),
cst.EmptyLine(),
# Build module body with optional generated model class
module_body = [
*imports,
cst.EmptyLine(),
cst.EmptyLine(),
model_cls,
cst.EmptyLine(),
cst.EmptyLine(),
]
# Add generated model class if available
if generated_model_cls:
module_body.extend(
[
generated_model_cls,
cst.EmptyLine(),
cst.EmptyLine(),
]
)
# Continue with the rest of the module
module_body.extend(
[
*block_fns,
cst.EmptyLine(),
cst.EmptyLine(),
@@ -1029,6 +1147,8 @@ async def generate_workflow_script(
]
)
module = cst.Module(body=module_body)
with open(file_name, "w") as f:
f.write(module.code)
return module.code

View File

@@ -0,0 +1,193 @@
"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""
from typing import Any, Dict, List, Tuple
import structlog
from pydantic import BaseModel
from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType
LOG = structlog.get_logger(__name__)
# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")
class GeneratedFieldMapping(BaseModel):
    """Mapping of action indices to field names.

    Validated from the LLM JSON response in ``_generate_field_names_with_llm``;
    the expected structure is defined by the ``generate-workflow-parameters``
    prompt template.
    """

    # "action_index_N" (1-based, in collection order) -> generated field name.
    field_mappings: Dict[str, str]
    # field name -> {"type": ..., "description": ...}, used to emit schema code.
    schema_fields: Dict[str, Dict[str, str]]
async def generate_workflow_parameters_schema(
    actions_by_task: Dict[str, List[Dict[str, Any]]],
) -> Tuple[str, Dict[str, str]]:
    """
    Generate a GeneratedWorkflowParameters Pydantic schema based on input_text actions.

    Args:
        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries

    Returns:
        Tuple of (schema_code, field_mappings) where:
        - schema_code: Python code for the GeneratedWorkflowParameters class
        - field_mappings: Dictionary mapping "task_id:action_id" keys to field
          names, consumed by hydrate_input_text_actions_with_field_names
    """
    # Collect every input_text action, remembering which task/action each
    # 1-based "action_index_N" key refers to so the LLM's answer can be
    # mapped back to concrete actions.
    collected: List[Dict[str, Any]] = []
    index_lookup: Dict[str, Dict[str, str]] = {}
    for task_id, actions in actions_by_task.items():
        for action in actions:
            if action.get("action_type") != ActionType.INPUT_TEXT:
                continue
            collected.append(
                {
                    "text": action.get("text", ""),
                    "intention": action.get("intention", ""),
                    "task_id": task_id,
                    "action_id": action.get("action_id", ""),
                }
            )
            # len(collected) is the 1-based index of the action just added.
            index_lookup[f"action_index_{len(collected)}"] = {
                "task_id": task_id,
                "action_id": action.get("action_id", ""),
            }

    if not collected:
        LOG.warning("No input_text actions found in workflow run")
        return _generate_empty_schema(), {}

    try:
        # Ask the LLM for field names, then render the Pydantic class source.
        mapping = await _generate_field_names_with_llm(collected)
        schema_code = _generate_pydantic_schema(mapping.schema_fields)
        # Re-key the LLM's index-based mapping as "task_id:action_id".
        hydration_map = {
            f"{index_lookup[idx]['task_id']}:{index_lookup[idx]['action_id']}": field_name
            for idx, field_name in mapping.field_mappings.items()
            if idx in index_lookup
        }
        return schema_code, hydration_map
    except Exception as e:
        # Any failure (LLM error, validation error) degrades to an empty schema.
        LOG.error("Failed to generate workflow parameters schema", error=str(e), exc_info=True)
        return _generate_empty_schema(), {}
async def _generate_field_names_with_llm(input_actions: List[Dict[str, Any]]) -> GeneratedFieldMapping:
    """
    Use LLM to generate field names from input actions.

    Args:
        input_actions: List of input_text action dictionaries

    Returns:
        GeneratedFieldMapping with field mappings and schema definitions

    Raises:
        pydantic.ValidationError: if the LLM response does not match the
            GeneratedFieldMapping structure (callers catch this broadly).
    """
    # Render the "generate-workflow-parameters" template with the collected actions.
    prompt = prompt_engine.load_prompt(template="generate-workflow-parameters", input_actions=input_actions)
    response = await app.LLM_API_HANDLER(prompt=prompt, prompt_name="generate-workflow-parameters")
    return GeneratedFieldMapping.model_validate(response)
def _generate_pydantic_schema(schema_fields: Dict[str, Dict[str, str]]) -> str:
"""
Generate Pydantic schema code from field definitions.
Args:
schema_fields: Dictionary of field names to their type and description
Returns:
Python code string for the GeneratedWorkflowParameters class
"""
if not schema_fields:
return _generate_empty_schema()
lines = [
"from pydantic import BaseModel, Field",
"",
"",
"class GeneratedWorkflowParameters(BaseModel):",
' """Generated schema representing all input_text action values from the workflow run."""',
"",
]
for field_name, field_info in schema_fields.items():
field_type = field_info.get("type", "str")
description = field_info.get("description", f"Value for {field_name}")
# Escape quotes in description
description = description.replace('"', '\\"')
lines.append(f' {field_name}: {field_type} = Field(description="{description}", default="")')
return "\n".join(lines)
def _generate_empty_schema() -> str:
    """Generate an empty schema when no input_text actions are found.

    The returned source must parse cleanly, since downstream code runs it
    through cst.parse_module to extract the class definition.
    """
    return '''from pydantic import BaseModel


class GeneratedWorkflowParameters(BaseModel):
    """Generated schema representing all input_text action values from the workflow run."""

    pass
'''
def hydrate_input_text_actions_with_field_names(
    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Add field_name to input_text actions based on generated mappings.

    Args:
        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
        field_mappings: Dictionary mapping "task_id:action_id" to field names

    Returns:
        Updated actions_by_task with field_name added to input_text actions
        (input dictionaries are shallow-copied, not mutated).
    """
    hydrated: Dict[str, List[Dict[str, Any]]] = {}
    for task_id, actions in actions_by_task.items():
        rewritten: List[Dict[str, Any]] = []
        for original in actions:
            entry = original.copy()
            if original.get("action_type") == ActionType.INPUT_TEXT:
                lookup_key = f"{task_id}:{original.get('action_id', '')}"
                if lookup_key in field_mappings:
                    entry["field_name"] = field_mappings[lookup_key]
                else:
                    # No mapping was generated: derive a fallback name from the
                    # action's intention, keeping only identifier-safe characters.
                    intent = original.get("intention", "")
                    if intent:
                        candidate = intent.lower().replace(" ", "_").replace("?", "").replace("'", "")
                        candidate = "".join(ch for ch in candidate if ch.isalnum() or ch == "_")
                        entry["field_name"] = candidate or "unknown_field"
                    else:
                        entry["field_name"] = "unknown_field"
            rewritten.append(entry)
        hydrated[task_id] = rewritten
    return hydrated

View File

@@ -1,11 +1,39 @@
from typing import Any
from pydantic import BaseModel
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage
async def setup(parameters: dict[str, Any], run_id: str | None = None) -> tuple[SkyvernPage, RunContext]:
async def setup(
parameters: dict[str, Any], generated_parameter_cls: type[BaseModel] | None = None
) -> tuple[SkyvernPage, RunContext]:
skyvern_page = await SkyvernPage.create()
run_context = RunContext(parameters=parameters, page=skyvern_page)
run_context = RunContext(
parameters=parameters,
page=skyvern_page,
# TODO: generate all parameters with llm here - then we can skip generating input text one by one in the fill/type methods
generated_parameters=generated_parameter_cls().model_dump() if generated_parameter_cls else None,
)
script_run_context_manager.set_run_context(run_context)
return skyvern_page, run_context
# async def transform_parameters(parameters: dict[str, Any] | BaseModel | None = None, generated_parameter_cls: type[BaseModel] | None = None) -> dict[str, Any] | None:
# if parameters is None:
# return None
# if generated_parameter_cls:
# if isinstance(parameters, dict):
# # TODO: use llm to generate
# return generated_parameter_cls.model_validate(parameters)
# if isinstance(parameters, BaseModel):
# return parameters
# return generated_parameter_cls.model_validate(parameters)
# else:
# if isinstance(parameters, dict):
# return parameters
# if isinstance(parameters, BaseModel):
# return parameters.model_dump()
# return parameters

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import copy
import json
from dataclasses import dataclass
from datetime import datetime, timezone
@@ -328,29 +329,8 @@ class SkyvernPage:
If the prompt generation or parsing fails for any reason we fall back to
inputting the originally supplied ``text``.
"""
new_text = text
if intention and data:
try:
# Build the element tree of the current page for the prompt
skyvern_context.ensure_context()
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
script_generation_input_text_prompt = prompt_engine.load_prompt(
template="script-generation-input-text-generatiion",
intention=intention,
data=payload_str,
)
json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER(
prompt=script_generation_input_text_prompt,
prompt_name="script-generation-input-text-generatiion",
)
new_text = json_response.get("answer", text) or text
except Exception:
# If anything goes wrong, fall back to the original text
new_text = text
locator = self.page.locator(f"xpath={xpath}")
await handler_utils.input_sequentially(locator, new_text, timeout=timeout)
await handler_utils.input_sequentially(locator, text, timeout=timeout)
@action_wrap(ActionType.UPLOAD_FILE)
async def upload_file(
@@ -420,8 +400,8 @@ class SkyvernPage:
@action_wrap(ActionType.EXTRACT)
async def extract(
self,
data_extraction_goal: str,
data_schema: dict[str, Any] | list | str | None = None,
prompt: str,
schema: dict[str, Any] | list | str | None = None,
error_code_mapping: dict[str, str] | None = None,
intention: str | None = None,
data: str | dict[str, Any] | None = None,
@@ -436,8 +416,8 @@ class SkyvernPage:
prompt_engine=prompt_engine,
template_name="extract-information",
html_need_skyvern_attrs=False,
data_extraction_goal=data_extraction_goal,
extracted_information_schema=data_schema,
data_extraction_goal=prompt,
extracted_information_schema=schema,
current_url=scraped_page_refreshed.url,
extracted_text=scraped_page_refreshed.extracted_text,
error_code_mapping_str=(json.dumps(error_code_mapping) if error_code_mapping else None),
@@ -509,8 +489,14 @@ class SkyvernPage:
class RunContext:
def __init__(self, parameters: dict[str, Any], page: SkyvernPage) -> None:
self.parameters = parameters
def __init__(
self, parameters: dict[str, Any], page: SkyvernPage, generated_parameters: dict[str, Any] | None = None
) -> None:
self.original_parameters = parameters
self.generated_parameters = generated_parameters
self.parameters = copy.deepcopy(parameters)
# if generated_parameters:
# self.parameters.update(generated_parameters)
self.page = page
self.trace: list[ActionCall] = []
self.prompt: str | None = None

View File

@@ -0,0 +1,47 @@
You are an expert at analyzing user interface automation actions and generating meaningful field names for data structures.
Given a list of input_text actions with their intentions and text values, generate appropriate field names for a Pydantic BaseModel class called "GeneratedWorkflowParameters".
## Rules:
1. Field names should be valid Python identifiers (snake_case, no spaces, no special characters except underscore)
2. Field names should be descriptive and based on the intention of the action
3. If multiple actions input the same text value, they should map to the same field name
4. Field names should be concise but clear about what data they represent
5. Avoid generic names like "field1", "input1" - use meaningful names based on the intention
## Input Actions:
{% for action in input_actions %}
Action {{ loop.index }}:
- Text: "{{ action.text }}"
- Intention: "{{ action.intention }}"
{% endfor %}
## Expected Output:
Return a JSON object with the following structure:
```json
{
"field_mappings": {
"action_index_1": "field_name_1",
"action_index_2": "field_name_2",
...
},
"schema_fields": {
"field_name_1": {
"type": "str",
"description": "Description of what this field represents"
},
"field_name_2": {
"type": "str",
"description": "Description of what this field represents"
},
...
}
}
```
Where:
- `field_mappings` maps each action index (1-based) to its corresponding field name
- `schema_fields` defines each unique field with its type and description
- Actions with the same text value should map to the same field name
Generate the field names now:

View File

@@ -2,6 +2,7 @@ import asyncio
import base64
import hashlib
import importlib.util
import json
import os
from datetime import datetime
from typing import Any, cast
@@ -14,6 +15,7 @@ from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.tasks import TaskOutput, TaskStatus
@@ -417,6 +419,9 @@ async def run_task(
prompt=prompt,
url=url,
)
# set the prompt in the RunContext
run_context = script_run_context_manager.ensure_run_context()
run_context.prompt = prompt
if cache_key:
try:
@@ -427,6 +432,7 @@ async def run_task(
await _update_workflow_block(workflow_run_block_id, BlockStatus.completed, task_id=task_id)
except Exception as e:
# TODO: fallback to AI run in case of error
# Update block status to failed if workflow block was created
if workflow_run_block_id:
await _update_workflow_block(
@@ -437,6 +443,9 @@ async def run_task(
failure_reason=str(e),
)
raise
finally:
# clear the prompt in the RunContext
run_context.prompt = None
else:
if workflow_run_block_id:
await _update_workflow_block(
@@ -446,6 +455,7 @@ async def run_task(
task_status=TaskStatus.failed,
failure_reason="Cache key is required",
)
run_context.prompt = None
raise Exception("Cache key is required to run task block in a script")
@@ -461,6 +471,9 @@ async def download(
prompt=prompt,
url=url,
)
# set the prompt in the RunContext
run_context = script_run_context_manager.ensure_run_context()
run_context.prompt = prompt
if cache_key:
try:
@@ -481,6 +494,8 @@ async def download(
failure_reason=str(e),
)
raise
finally:
run_context.prompt = None
else:
if workflow_run_block_id:
await _update_workflow_block(
@@ -490,6 +505,7 @@ async def download(
task_status=TaskStatus.failed,
failure_reason="Cache key is required",
)
run_context.prompt = None
raise Exception("Cache key is required to run task block in a script")
@@ -505,6 +521,9 @@ async def action(
prompt=prompt,
url=url,
)
# set the prompt in the RunContext
run_context = script_run_context_manager.ensure_run_context()
run_context.prompt = prompt
if cache_key:
try:
@@ -525,6 +544,8 @@ async def action(
failure_reason=str(e),
)
raise
finally:
run_context.prompt = None
else:
if workflow_run_block_id:
await _update_workflow_block(
@@ -534,6 +555,7 @@ async def action(
task_status=TaskStatus.failed,
failure_reason="Cache key is required",
)
run_context.prompt = None
raise Exception("Cache key is required to run task block in a script")
@@ -549,6 +571,9 @@ async def login(
prompt=prompt,
url=url,
)
# set the prompt in the RunContext
run_context = script_run_context_manager.ensure_run_context()
run_context.prompt = prompt
if cache_key:
try:
@@ -569,6 +594,8 @@ async def login(
failure_reason=str(e),
)
raise
finally:
run_context.prompt = None
else:
if workflow_run_block_id:
await _update_workflow_block(
@@ -578,6 +605,7 @@ async def login(
task_status=TaskStatus.failed,
failure_reason="Cache key is required",
)
run_context.prompt = None
raise Exception("Cache key is required to run task block in a script")
@@ -593,6 +621,9 @@ async def extract(
prompt=prompt,
url=url,
)
# set the prompt in the RunContext
run_context = script_run_context_manager.ensure_run_context()
run_context.prompt = prompt
output: dict[str, Any] | list | str | None = None
if cache_key:
@@ -608,7 +639,6 @@ async def extract(
output=output,
)
return output
except Exception as e:
# Update block status to failed if workflow block was created
if workflow_run_block_id:
@@ -621,6 +651,8 @@ async def extract(
output=output,
)
raise
finally:
run_context.prompt = None
else:
if workflow_run_block_id:
await _update_workflow_block(
@@ -630,6 +662,7 @@ async def extract(
task_status=TaskStatus.failed,
failure_reason="Cache key is required",
)
run_context.prompt = None
raise Exception("Cache key is required to run task block in a script")
@@ -688,3 +721,34 @@ async def run_script(
await user_script.run_workflow()
else:
raise Exception(f"No 'run_workflow' function found in {path}")
async def generate_text(
text: str | None = None,
intention: str | None = None,
data: dict[str, Any] | None = None,
) -> str:
if text:
return text
new_text = text or ""
if intention and data:
try:
run_context = script_run_context_manager.ensure_run_context()
prompt = run_context.prompt
# Build the element tree of the current page for the prompt
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
script_generation_input_text_prompt = prompt_engine.load_prompt(
template="script-generation-input-text-generatiion",
intention=intention,
data=payload_str,
goal=prompt,
)
json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER(
prompt=script_generation_input_text_prompt,
prompt_name="script-generation-input-text-generatiion",
)
new_text = json_response.get("answer", new_text)
except Exception:
# If anything goes wrong, fall back to the original text
pass
return new_text