diff --git a/skyvern/__init__.py b/skyvern/__init__.py index 3eb37ac1..6626be75 100644 --- a/skyvern/__init__.py +++ b/skyvern/__init__.py @@ -33,6 +33,7 @@ from skyvern.services.script_service import ( # noqa: E402 action, # noqa: E402 download, # noqa: E402 extract, # noqa: E402 + generate_text, # noqa: E402 login, # noqa: E402 run_script, # noqa: E402 run_task, # noqa: E402 @@ -48,6 +49,7 @@ __all__ = [ "cached", "download", "extract", + "generate_text", "login", "run_script", "run_task", diff --git a/skyvern/core/script_generations/generate_script.py b/skyvern/core/script_generations/generate_script.py index 78a50b22..0a489156 100644 --- a/skyvern/core/script_generations/generate_script.py +++ b/skyvern/core/script_generations/generate_script.py @@ -25,6 +25,10 @@ import structlog from libcst import Attribute, Call, Dict, DictElement, FunctionDef, Name, Param from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS +from skyvern.core.script_generations.generate_workflow_parameters import ( + generate_workflow_parameters_schema, + hydrate_input_text_actions_with_field_names, +) from skyvern.forge import app from skyvern.webeye.actions.action_types import ActionType @@ -61,6 +65,7 @@ ACTIONS_WITH_XPATH = [ ] INDENT = " " * 4 +DOUBLE_INDENT = " " * 8 def _safe_name(label: str) -> str: @@ -97,6 +102,57 @@ def _value(value: Any) -> cst.BaseExpression: return cst.SimpleString(repr(str(value))) +def _generate_text_call(text_value: str, intention: str, parameter_key: str) -> cst.BaseExpression: + """Create a generate_text function call CST expression.""" + return cst.Await( + expression=cst.Call( + func=cst.Name("generate_text"), + whitespace_before_args=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(DOUBLE_INDENT), + ), + args=[ + # First positional argument: context.generated_parameters['parameter_key'] + cst.Arg( + value=cst.Subscript( + value=cst.Attribute( + value=cst.Name("context"), + attr=cst.Name("generated_parameters"), + ), + slice=[cst.SubscriptElement(slice=cst.Index(value=_value(parameter_key)))], + ), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(DOUBLE_INDENT), + ), + ), + # intention keyword argument + cst.Arg( + keyword=cst.Name("intention"), + value=_value(intention), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(DOUBLE_INDENT), + ), + ), + # data keyword argument + cst.Arg( + keyword=cst.Name("data"), + value=cst.Attribute( + value=cst.Name("context"), + attr=cst.Name("parameters"), + ), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + comma=cst.Comma(), + ), + ], + ) + ) + + # --------------------------------------------------------------------- # # 2. utility builders # # --------------------------------------------------------------------- # @@ -177,10 +233,21 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst. ) if method in ["type", "fill"]: + # Get intention from action + intention = act.get("intention") or act.get("reasoning") or "" + + # Use generate_text call if field_name is available, otherwise fallback to direct value + if act.get("field_name"): + text_value = _generate_text_call( + text_value=act["text"], intention=intention, parameter_key=act["field_name"] + ) + else: + text_value = _value(act["text"]) + args.append( cst.Arg( keyword=cst.Name("text"), - value=_value(act["text"]), + value=text_value, whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, last_line=cst.SimpleWhitespace(INDENT), @@ -212,7 +279,7 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst. elif method == "extract": args.append( cst.Arg( - keyword=cst.Name("data_extraction_goal"), + keyword=cst.Name("prompt"), value=_value(act["data_extraction_goal"]), whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, @@ -309,8 +376,8 @@ def _build_block_fn(block: dict[str, Any], actions: list[dict[str, Any]]) -> Fun def _build_model(workflow: dict[str, Any]) -> cst.ClassDef: """ class WorkflowParameters(BaseModel): - ein_info: str - company_name: str + param1: str + param2: str ... """ ann_lines: list[cst.BaseStatement] = [] @@ -319,7 +386,6 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef: if p["parameter_type"] != "workflow": continue - # ein_info: str ann = cst.AnnAssign( target=cst.Name(p["key"]), annotation=cst.Annotation(cst.Name("str")), @@ -337,21 +403,24 @@ def _build_model(workflow: dict[str, Any]) -> cst.ClassDef: ) -def _build_cached_params(values: dict[str, Any]) -> cst.SimpleStatementLine: +def _build_generated_model_from_schema(schema_code: str) -> cst.ClassDef | None: """ - Make a CST for: - cached_parameters = WorkflowParameters(ein_info="...", ...) + Parse the generated schema code and return a ClassDef, or None if parsing fails. """ - call = cst.Call( - func=cst.Name("WorkflowParameters"), - args=[cst.Arg(keyword=cst.Name(k), value=_value(v)) for k, v in values.items()], - ) + try: + # Parse the schema code and extract just the class definition + parsed_module = cst.parse_module(schema_code) - assign = cst.Assign( - targets=[cst.AssignTarget(cst.Name("cached_parameters"))], - value=call, - ) - return cst.SimpleStatementLine([assign]) + # Find the GeneratedWorkflowParameters class in the parsed module + for node in parsed_module.body: + if isinstance(node, cst.ClassDef) and node.name.value == "GeneratedWorkflowParameters": + return node + + # If no class found, return None + return None + except Exception as e: + LOG.warning("Failed to parse generated schema code", error=str(e)) + return None # --------------------------------------------------------------------- # @@ -804,7 +873,7 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct cst.parse_statement( "parameters = parameters.model_dump() if isinstance(parameters, WorkflowParameters) else parameters" ), - cst.parse_statement("page, context = await skyvern.setup(parameters)"), + cst.parse_statement("page, context = await skyvern.setup(parameters, GeneratedWorkflowParameters)"), ] for block in blocks: @@ -867,8 +936,27 @@ def _build_run_fn(blocks: list[dict[str, Any]], wf_req: dict[str, Any]) -> Funct params=[ Param( name=cst.Name("parameters"), - annotation=cst.Annotation(cst.Name("WorkflowParameters")), - default=cst.Name("cached_parameters"), + annotation=cst.Annotation( + cst.BinaryOperation( + left=cst.Name("WorkflowParameters"), + operator=cst.BitOr( + whitespace_before=cst.SimpleWhitespace(" "), + whitespace_after=cst.SimpleWhitespace(" "), + ), + right=cst.Subscript( + value=cst.Name("dict"), + slice=[ + cst.SubscriptElement( + slice=cst.Index(value=cst.Name("str")), + comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), + ), + cst.SubscriptElement( + slice=cst.Index(value=cst.Name("Any")), + ), + ], + ), + ) + ), whitespace_after_param=cst.ParenthesizedWhitespace( indent=True, last_line=cst.SimpleWhitespace(INDENT), @@ -948,11 +1036,24 @@ async def generate_workflow_script( imports: list[cst.BaseStatement] = [ cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("asyncio"))])]), cst.SimpleStatementLine([cst.Import(names=[cst.ImportAlias(cst.Name("pydantic"))])]), + cst.SimpleStatementLine( + [ + cst.ImportFrom( + module=cst.Name("typing"), + names=[ + cst.ImportAlias(cst.Name("Any")), + ], + ) + ] + ), cst.SimpleStatementLine( [ cst.ImportFrom( module=cst.Name("pydantic"), - names=[cst.ImportAlias(cst.Name("BaseModel"))], + names=[ + cst.ImportAlias(cst.Name("BaseModel")), + cst.ImportAlias(cst.Name("Field")), + ], ) ] ), @@ -964,15 +1065,20 @@ async def generate_workflow_script( names=[ cst.ImportAlias(cst.Name("RunContext")), cst.ImportAlias(cst.Name("SkyvernPage")), + cst.ImportAlias(cst.Name("generate_text")), ], ) ] ), ] + # --- generate schema and hydrate actions --------------------------- + generated_schema, field_mappings = await generate_workflow_parameters_schema(actions_by_task) + actions_by_task = hydrate_input_text_actions_with_field_names(actions_by_task, field_mappings) + # --- class + cached params ----------------------------------------- model_cls = _build_model(workflow) - cached_params_stmt = _build_cached_params(workflow_run_request.get("parameters", {})) + generated_model_cls = _build_generated_model_from_schema(generated_schema) # --- blocks --------------------------------------------------------- block_fns = [] @@ -1008,17 +1114,29 @@ async def generate_workflow_script( # --- runner --------------------------------------------------------- run_fn = _build_run_fn(blocks, workflow_run_request) - module = cst.Module( - body=[ - *imports, - cst.EmptyLine(), - cst.EmptyLine(), - model_cls, - cst.EmptyLine(), - cst.EmptyLine(), - cached_params_stmt, - cst.EmptyLine(), - cst.EmptyLine(), + # Build module body with optional generated model class + module_body = [ + *imports, + cst.EmptyLine(), + cst.EmptyLine(), + model_cls, + cst.EmptyLine(), + cst.EmptyLine(), + ] + + # Add generated model class if available + if generated_model_cls: + module_body.extend( + [ + generated_model_cls, + cst.EmptyLine(), + cst.EmptyLine(), + ] + ) + + # Continue with the rest of the module + module_body.extend( + [ *block_fns, cst.EmptyLine(), cst.EmptyLine(), @@ -1029,6 +1147,8 @@ async def generate_workflow_script( ] ) + module = cst.Module(body=module_body) + with open(file_name, "w") as f: f.write(module.code) return module.code diff --git a/skyvern/core/script_generations/generate_workflow_parameters.py b/skyvern/core/script_generations/generate_workflow_parameters.py new file mode 100644 index 00000000..b5c9b345 --- /dev/null +++ b/skyvern/core/script_generations/generate_workflow_parameters.py @@ -0,0 +1,193 @@ +""" +Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions. +""" + +from typing import Any, Dict, List, Tuple + +import structlog +from pydantic import BaseModel + +from skyvern.forge import app +from skyvern.forge.sdk.prompting import PromptEngine +from skyvern.webeye.actions.actions import ActionType + +LOG = structlog.get_logger(__name__) + +# Initialize prompt engine +prompt_engine = PromptEngine("skyvern") + + +class GeneratedFieldMapping(BaseModel): + """Mapping of action indices to field names.""" + + field_mappings: Dict[str, str] + schema_fields: Dict[str, Dict[str, str]] + + +async def generate_workflow_parameters_schema( + actions_by_task: Dict[str, List[Dict[str, Any]]], +) -> Tuple[str, Dict[str, str]]: + """ + Generate a GeneratedWorkflowParameters Pydantic schema based on input_text actions. + + Args: + actions_by_task: Dictionary mapping task IDs to lists of action dictionaries + + Returns: + Tuple of (schema_code, field_mappings) where: + - schema_code: Python code for the GeneratedWorkflowParameters class + - field_mappings: Dictionary mapping action indices to field names for hydration + """ + # Extract all input_text actions + input_actions = [] + action_index_map = {} + action_counter = 1 + + for task_id, actions in actions_by_task.items(): + for action in actions: + if action.get("action_type") == ActionType.INPUT_TEXT: + input_actions.append( + { + "text": action.get("text", ""), + "intention": action.get("intention", ""), + "task_id": task_id, + "action_id": action.get("action_id", ""), + } + ) + action_index_map[f"action_index_{action_counter}"] = { + "task_id": task_id, + "action_id": action.get("action_id", ""), + } + action_counter += 1 + + if not input_actions: + LOG.warning("No input_text actions found in workflow run") + return _generate_empty_schema(), {} + + # Generate field names using LLM + try: + field_mapping = await _generate_field_names_with_llm(input_actions) + + # Generate the Pydantic schema code + schema_code = _generate_pydantic_schema(field_mapping.schema_fields) + + # Create field mappings for action hydration + action_field_mappings = {} + for action_idx, field_name in field_mapping.field_mappings.items(): + if action_idx in action_index_map: + action_info = action_index_map[action_idx] + key = f"{action_info['task_id']}:{action_info['action_id']}" + action_field_mappings[key] = field_name + + return schema_code, action_field_mappings + + except Exception as e: + LOG.error("Failed to generate workflow parameters schema", error=str(e), exc_info=True) + return _generate_empty_schema(), {} + + +async def _generate_field_names_with_llm(input_actions: List[Dict[str, Any]]) -> GeneratedFieldMapping: + """ + Use LLM to generate field names from input actions. + + Args: + input_actions: List of input_text action dictionaries + + Returns: + GeneratedFieldMapping with field mappings and schema definitions + """ + prompt = prompt_engine.load_prompt(template="generate-workflow-parameters", input_actions=input_actions) + + response = await app.LLM_API_HANDLER(prompt=prompt, prompt_name="generate-workflow-parameters") + + return GeneratedFieldMapping.model_validate(response) + + +def _generate_pydantic_schema(schema_fields: Dict[str, Dict[str, str]]) -> str: + """ + Generate Pydantic schema code from field definitions. + + Args: + schema_fields: Dictionary of field names to their type and description + + Returns: + Python code string for the GeneratedWorkflowParameters class + """ + if not schema_fields: + return _generate_empty_schema() + + lines = [ + "from pydantic import BaseModel, Field", + "", + "", + "class GeneratedWorkflowParameters(BaseModel):", + ' """Generated schema representing all input_text action values from the workflow run."""', + "", + ] + + for field_name, field_info in schema_fields.items(): + field_type = field_info.get("type", "str") + description = field_info.get("description", f"Value for {field_name}") + + # Escape quotes in description + description = description.replace('"', '\\"') + + lines.append(f' {field_name}: {field_type} = Field(description="{description}", default="")') + + return "\n".join(lines) + + +def _generate_empty_schema() -> str: + """Generate an empty schema when no input_text actions are found.""" + return '''from pydantic import BaseModel + + +class GeneratedWorkflowParameters(BaseModel): + """Generated schema representing all input_text action values from the workflow run.""" + pass +''' + + +def hydrate_input_text_actions_with_field_names( + actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str] +) -> Dict[str, List[Dict[str, Any]]]: + """ + Add field_name to input_text actions based on generated mappings. + + Args: + actions_by_task: Dictionary mapping task IDs to lists of action dictionaries + field_mappings: Dictionary mapping "task_id:action_id" to field names + + Returns: + Updated actions_by_task with field_name added to input_text actions + """ + updated_actions_by_task = {} + + for task_id, actions in actions_by_task.items(): + updated_actions = [] + + for action in actions: + action_copy = action.copy() + + if action.get("action_type") == ActionType.INPUT_TEXT: + action_id = action.get("action_id", "") + mapping_key = f"{task_id}:{action_id}" + + if mapping_key in field_mappings: + action_copy["field_name"] = field_mappings[mapping_key] + else: + # Fallback field name if mapping not found + intention = action.get("intention", "") + if intention: + # Simple field name generation from intention + field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "") + field_name = "".join(c for c in field_name if c.isalnum() or c == "_") + action_copy["field_name"] = field_name or "unknown_field" + else: + action_copy["field_name"] = "unknown_field" + + updated_actions.append(action_copy) + + updated_actions_by_task[task_id] = updated_actions + + return updated_actions_by_task diff --git a/skyvern/core/script_generations/run_initializer.py b/skyvern/core/script_generations/run_initializer.py index 73af1a70..80b44e9c 100644 --- a/skyvern/core/script_generations/run_initializer.py +++ b/skyvern/core/script_generations/run_initializer.py @@ -1,11 +1,39 @@ from typing import Any +from pydantic import BaseModel + from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage -async def setup(parameters: dict[str, Any], run_id: str | None = None) -> tuple[SkyvernPage, RunContext]: +async def setup( + parameters: dict[str, Any], generated_parameter_cls: type[BaseModel] | None = None +) -> tuple[SkyvernPage, RunContext]: skyvern_page = await SkyvernPage.create() - run_context = RunContext(parameters=parameters, page=skyvern_page) + run_context = RunContext( + parameters=parameters, + page=skyvern_page, + # TODO: generate all parameters with llm here - then we can skip generating input text one by one in the fill/type methods + generated_parameters=generated_parameter_cls().model_dump() if generated_parameter_cls else None, + ) script_run_context_manager.set_run_context(run_context) return skyvern_page, run_context + + +# async def transform_parameters(parameters: dict[str, Any] | BaseModel | None = None, generated_parameter_cls: type[BaseModel] | None = None) -> dict[str, Any] | None: +# if parameters is None: +# return None + +# if generated_parameter_cls: +# if isinstance(parameters, dict): +# # TODO: use llm to generate +# return generated_parameter_cls.model_validate(parameters) +# if isinstance(parameters, BaseModel): +# return parameters +# return generated_parameter_cls.model_validate(parameters) +# else: +# if isinstance(parameters, dict): +# return parameters +# if isinstance(parameters, BaseModel): +# return parameters.model_dump() +# return parameters diff --git a/skyvern/core/script_generations/skyvern_page.py b/skyvern/core/script_generations/skyvern_page.py index 98d1be03..6488732c 100644 --- a/skyvern/core/script_generations/skyvern_page.py +++ b/skyvern/core/script_generations/skyvern_page.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import copy import json from dataclasses import dataclass from datetime import datetime, timezone @@ -328,29 +329,8 @@ class SkyvernPage: If the prompt generation or parsing fails for any reason we fall back to inputting the originally supplied ``text``. """ - new_text = text - - if intention and data: - try: - # Build the element tree of the current page for the prompt - skyvern_context.ensure_context() - payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "") - script_generation_input_text_prompt = prompt_engine.load_prompt( - template="script-generation-input-text-generatiion", - intention=intention, - data=payload_str, - ) - json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER( - prompt=script_generation_input_text_prompt, - prompt_name="script-generation-input-text-generatiion", - ) - new_text = json_response.get("answer", text) or text - except Exception: - # If anything goes wrong, fall back to the original text - new_text = text - locator = self.page.locator(f"xpath={xpath}") - await handler_utils.input_sequentially(locator, new_text, timeout=timeout) + await handler_utils.input_sequentially(locator, text, timeout=timeout) @action_wrap(ActionType.UPLOAD_FILE) async def upload_file( @@ -420,8 +400,8 @@ class SkyvernPage: @action_wrap(ActionType.EXTRACT) async def extract( self, - data_extraction_goal: str, - data_schema: dict[str, Any] | list | str | None = None, + prompt: str, + schema: dict[str, Any] | list | str | None = None, error_code_mapping: dict[str, str] | None = None, intention: str | None = None, data: str | dict[str, Any] | None = None, @@ -436,8 +416,8 @@ class SkyvernPage: prompt_engine=prompt_engine, template_name="extract-information", html_need_skyvern_attrs=False, - data_extraction_goal=data_extraction_goal, - extracted_information_schema=data_schema, + data_extraction_goal=prompt, + extracted_information_schema=schema, current_url=scraped_page_refreshed.url, extracted_text=scraped_page_refreshed.extracted_text, error_code_mapping_str=(json.dumps(error_code_mapping) if error_code_mapping else None), @@ -509,8 +489,14 @@ class SkyvernPage: class RunContext: - def __init__(self, parameters: dict[str, Any], page: SkyvernPage) -> None: - self.parameters = parameters + def __init__( + self, parameters: dict[str, Any], page: SkyvernPage, generated_parameters: dict[str, Any] | None = None + ) -> None: + self.original_parameters = parameters + self.generated_parameters = generated_parameters + self.parameters = copy.deepcopy(parameters) + # if generated_parameters: + # self.parameters.update(generated_parameters) self.page = page self.trace: list[ActionCall] = [] self.prompt: str | None = None diff --git a/skyvern/forge/prompts/skyvern/generate-workflow-parameters.j2 b/skyvern/forge/prompts/skyvern/generate-workflow-parameters.j2 new file mode 100644 index 00000000..b86b8170 --- /dev/null +++ b/skyvern/forge/prompts/skyvern/generate-workflow-parameters.j2 @@ -0,0 +1,47 @@ +You are an expert at analyzing user interface automation actions and generating meaningful field names for data structures. + +Given a list of input_text actions with their intentions and text values, generate appropriate field names for a Pydantic BaseModel class called "GeneratedWorkflowParameters". + +## Rules: +1. Field names should be valid Python identifiers (snake_case, no spaces, no special characters except underscore) +2. Field names should be descriptive and based on the intention of the action +3. If multiple actions input the same text value, they should map to the same field name +4. Field names should be concise but clear about what data they represent +5. Avoid generic names like "field1", "input1" - use meaningful names based on the intention + +## Input Actions: +{% for action in input_actions %} +Action {{ loop.index }}: +- Text: "{{ action.text }}" +- Intention: "{{ action.intention }}" +{% endfor %} + +## Expected Output: +Return a JSON object with the following structure: +```json +{ + "field_mappings": { + "action_index_1": "field_name_1", + "action_index_2": "field_name_2", + ... + }, + "schema_fields": { + "field_name_1": { + "type": "str", + "description": "Description of what this field represents" + }, + "field_name_2": { + "type": "str", + "description": "Description of what this field represents" + }, + ... + } +} +``` + +Where: +- `field_mappings` maps each action index (1-based) to its corresponding field name +- `schema_fields` defines each unique field with its type and description +- Actions with the same text value should map to the same field name + +Generate the field names now: diff --git a/skyvern/services/script_service.py b/skyvern/services/script_service.py index 85d4fd0d..ed6ac6d0 100644 --- a/skyvern/services/script_service.py +++ b/skyvern/services/script_service.py @@ -2,6 +2,7 @@ import asyncio import base64 import hashlib import importlib.util +import json import os from datetime import datetime from typing import Any, cast @@ -14,6 +15,7 @@ from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound from skyvern.forge import app +from skyvern.forge.prompts import prompt_engine from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.schemas.files import FileInfo from skyvern.forge.sdk.schemas.tasks import TaskOutput, TaskStatus @@ -417,6 +419,9 @@ async def run_task( prompt=prompt, url=url, ) + # set the prompt in the RunContext + run_context = script_run_context_manager.ensure_run_context() + run_context.prompt = prompt if cache_key: try: @@ -427,6 +432,7 @@ async def run_task( await _update_workflow_block(workflow_run_block_id, BlockStatus.completed, task_id=task_id) except Exception as e: + # TODO: fallback to AI run in case of error # Update block status to failed if workflow block was created if workflow_run_block_id: await _update_workflow_block( @@ -437,6 +443,9 @@ async def run_task( failure_reason=str(e), ) raise + finally: + # clear the prompt in the RunContext + run_context.prompt = None else: if workflow_run_block_id: await _update_workflow_block( @@ -446,6 +455,7 @@ async def run_task( task_status=TaskStatus.failed, failure_reason="Cache key is required", ) + run_context.prompt = None raise Exception("Cache key is required to run task block in a script") @@ -461,6 +471,9 @@ async def download( prompt=prompt, url=url, ) + # set the prompt in the RunContext + run_context = script_run_context_manager.ensure_run_context() + run_context.prompt = prompt if cache_key: try: @@ -481,6 +494,8 @@ async def download( failure_reason=str(e), ) raise + finally: + run_context.prompt = None else: if workflow_run_block_id: await _update_workflow_block( @@ -490,6 +505,7 @@ async def download( task_status=TaskStatus.failed, failure_reason="Cache key is required", ) + run_context.prompt = None raise Exception("Cache key is required to run task block in a script") @@ -505,6 +521,9 @@ async def action( prompt=prompt, url=url, ) + # set the prompt in the RunContext + run_context = script_run_context_manager.ensure_run_context() + run_context.prompt = prompt if cache_key: try: @@ -525,6 +544,8 @@ async def action( failure_reason=str(e), ) raise + finally: + run_context.prompt = None else: if workflow_run_block_id: await _update_workflow_block( @@ -534,6 +555,7 @@ async def action( task_status=TaskStatus.failed, failure_reason="Cache key is required", ) + run_context.prompt = None raise Exception("Cache key is required to run task block in a script") @@ -549,6 +571,9 @@ async def login( prompt=prompt, url=url, ) + # set the prompt in the RunContext + run_context = script_run_context_manager.ensure_run_context() + run_context.prompt = prompt if cache_key: try: @@ -569,6 +594,8 @@ async def login( failure_reason=str(e), ) raise + finally: + run_context.prompt = None else: if workflow_run_block_id: await _update_workflow_block( @@ -578,6 +605,7 @@ async def login( task_status=TaskStatus.failed, failure_reason="Cache key is required", ) + run_context.prompt = None raise Exception("Cache key is required to run task block in a script") @@ -593,6 +621,9 @@ async def extract( prompt=prompt, url=url, ) + # set the prompt in the RunContext + run_context = script_run_context_manager.ensure_run_context() + run_context.prompt = prompt output: dict[str, Any] | list | str | None = None if cache_key: @@ -608,7 +639,6 @@ async def extract( output=output, ) return output - except Exception as e: # Update block status to failed if workflow block was created if workflow_run_block_id: @@ -621,6 +651,8 @@ async def extract( output=output, ) raise + finally: + run_context.prompt = None else: if workflow_run_block_id: await _update_workflow_block( @@ -630,6 +662,7 @@ async def extract( task_status=TaskStatus.failed, failure_reason="Cache key is required", ) + run_context.prompt = None raise Exception("Cache key is required to run task block in a script") @@ -688,3 +721,34 @@ async def run_script( await user_script.run_workflow() else: raise Exception(f"No 'run_workflow' function found in {path}") + + +async def generate_text( + text: str | None = None, + intention: str | None = None, + data: dict[str, Any] | None = None, +) -> str: + if text: + return text + new_text = text or "" + if intention and data: + try: + run_context = script_run_context_manager.ensure_run_context() + prompt = run_context.prompt + # Build the element tree of the current page for the prompt + payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "") + script_generation_input_text_prompt = prompt_engine.load_prompt( + template="script-generation-input-text-generatiion", + intention=intention, + data=payload_str, + goal=prompt, + ) + json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER( + prompt=script_generation_input_text_prompt, + prompt_name="script-generation-input-text-generatiion", + ) + new_text = json_response.get("answer", new_text) + except Exception: + # If anything goes wrong, fall back to the original text + pass + return new_text