From 55d847461ecb12a2809f23852f80cc157ea3b319 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Wed, 3 Sep 2025 16:44:52 -0700 Subject: [PATCH] ai_adapt_value for text input (#3354) --- .../script_generations/generate_script.py | 29 +++++-- .../script_generations/run_initializer.py | 3 +- .../script_run_context_manager.py | 34 -------- .../core/script_generations/skyvern_page.py | 86 ++++++++++++++++--- .../script_generations/workflow_wrappers.py | 3 +- skyvern/forge/sdk/core/skyvern_context.py | 1 + skyvern/services/script_service.py | 46 +++++----- 7 files changed, 120 insertions(+), 82 deletions(-) delete mode 100644 skyvern/core/script_generations/script_run_context_manager.py diff --git a/skyvern/core/script_generations/generate_script.py b/skyvern/core/script_generations/generate_script.py index ab14134c..f74ad07d 100644 --- a/skyvern/core/script_generations/generate_script.py +++ b/skyvern/core/script_generations/generate_script.py @@ -126,12 +126,12 @@ def _generate_text_call(text_value: str, intention: str, parameter_key: str) -> last_line=cst.SimpleWhitespace(DOUBLE_INDENT), ), args=[ - # First positional argument: context.generated_parameters['parameter_key'] + # First positional argument: context.parameters['parameter_key'] cst.Arg( value=cst.Subscript( value=cst.Attribute( value=cst.Name("context"), - attr=cst.Name("generated_parameters"), + attr=cst.Name("parameters"), ), slice=[cst.SubscriptElement(slice=cst.Index(value=_value(parameter_key)))], ), @@ -247,20 +247,21 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst. ) if method in ["type", "fill"]: - # Get intention from action - intention = act.get("intention") or act.get("reasoning") or "" - - # Use generate_text call if field_name is available, otherwise fallback to direct value + # Use context.parameters if field_name is available, otherwise fallback to direct value if act.get("field_name"): - text_value = _generate_text_call( - text_value=act["text"], intention=intention, parameter_key=act["field_name"] + text_value = cst.Subscript( + value=cst.Attribute( + value=cst.Name("context"), + attr=cst.Name("parameters"), + ), + slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))], ) else: text_value = _value(act["text"]) args.append( cst.Arg( - keyword=cst.Name("text"), + keyword=cst.Name("value"), value=text_value, whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, @@ -268,6 +269,16 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst. ), ) ) + args.append( + cst.Arg( + keyword=cst.Name("ai_adapt_value"), + value=cst.Name("True"), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ) + ) elif method == "select_option": args.append( cst.Arg( diff --git a/skyvern/core/script_generations/run_initializer.py b/skyvern/core/script_generations/run_initializer.py index 724d6a2f..e8d2fb18 100644 --- a/skyvern/core/script_generations/run_initializer.py +++ b/skyvern/core/script_generations/run_initializer.py @@ -2,8 +2,7 @@ from typing import Any from pydantic import BaseModel -from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager -from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage +from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage, script_run_context_manager from skyvern.forge import app from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameterType diff --git a/skyvern/core/script_generations/script_run_context_manager.py b/skyvern/core/script_generations/script_run_context_manager.py deleted file mode 100644 index 3c96dcbd..00000000 --- a/skyvern/core/script_generations/script_run_context_manager.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Callable - -from skyvern.core.script_generations.skyvern_page import RunContext - - -class ScriptRunContextManager: - """ - Manages the run context for code runs. - """ - - def __init__(self) -> None: - # self.run_contexts: dict[str, RunContext] = {} - self.run_context: RunContext | None = None - self.cached_fns: dict[str, Callable] = {} - - def get_run_context(self) -> RunContext | None: - return self.run_context - - def set_run_context(self, run_context: RunContext) -> None: - self.run_context = run_context - - def ensure_run_context(self) -> RunContext: - if not self.run_context: - raise Exception("Run context not found") - return self.run_context - - def set_cached_fn(self, cache_key: str, fn: Callable) -> None: - self.cached_fns[cache_key] = fn - - def get_cached_fn(self, cache_key: str) -> Callable | None: - return self.cached_fns.get(cache_key) - - -script_run_context_manager = ScriptRunContextManager() diff --git a/skyvern/core/script_generations/skyvern_page.py b/skyvern/core/script_generations/skyvern_page.py index 25c87a52..b587d6da 100644 --- a/skyvern/core/script_generations/skyvern_page.py +++ b/skyvern/core/script_generations/skyvern_page.py @@ -8,6 +8,7 @@ from datetime import datetime, timezone from enum import StrEnum from typing import Any, Callable, Literal +import structlog from playwright.async_api import Page from skyvern.config import settings @@ -24,6 +25,8 @@ from skyvern.webeye.actions.actions import Action, ActionStatus, ExtractAction, from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website +LOG = structlog.get_logger() + class Driver(StrEnum): PLAYWRIGHT = "playwright" @@ -196,7 +199,8 @@ class SkyvernPage: # Create action record. TODO: store more action fields kwargs = kwargs or {} - text = kwargs.get("text") + # we're using "value" instead of "text" for input text actions interface + text = kwargs.get("value", "") option_value = kwargs.get("option") select_option = SelectOption(value=option_value) if option_value else None response: str | None = kwargs.get("response") @@ -314,7 +318,7 @@ class SkyvernPage: current_url=self.page.url, elements=element_tree, local_datetime=datetime.now(context.tz_info or datetime.now().astimezone().tzinfo).isoformat(), - user_context=getattr(context, "prompt", None), + # user_context=getattr(context, "prompt", None), ) json_response = await app.SINGLE_CLICK_AGENT_LLM_API_HANDLER( prompt=single_click_prompt, @@ -334,28 +338,31 @@ class SkyvernPage: async def fill( self, xpath: str, - text: str, + value: str, + ai_adapt_value: bool = False, intention: str | None = None, data: str | dict[str, Any] | None = None, timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS, ) -> None: - await self._input_text(xpath, text, intention, data, timeout) + await self._input_text(xpath, value, ai_adapt_value, intention, data, timeout) @action_wrap(ActionType.INPUT_TEXT) async def type( self, xpath: str, - text: str, + value: str, + ai_adapt_value: bool = False, intention: str | None = None, data: str | dict[str, Any] | None = None, timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS, ) -> None: - await self._input_text(xpath, text, intention, data, timeout) + await self._input_text(xpath, value, ai_adapt_value, intention, data, timeout) async def _input_text( self, xpath: str, - text: str, + value: str, + ai_adapt_value: bool = False, intention: str | None = None, data: str | dict[str, Any] | None = None, timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS, @@ -372,11 +379,33 @@ class SkyvernPage: """ # format the text with the actual value of the parameter if it's a secret when running a workflow context = skyvern_context.current() + value = value or "" if context and context.workflow_run_id: - text = await _get_actual_value_of_parameter_if_secret(context.workflow_run_id, text) + value = await _get_actual_value_of_parameter_if_secret(context.workflow_run_id, value) + + if ai_adapt_value and intention: + try: + prompt = context.prompt if context else None + # Build the element tree of the current page for the prompt + # clean up empty data values + data = {k: v for k, v in data.items() if v} if isinstance(data, dict) else (data or "") + payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "") + script_generation_input_text_prompt = prompt_engine.load_prompt( + template="script-generation-input-text-generatiion", + intention=intention, + data=payload_str, + goal=prompt, + ) + json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER( + prompt=script_generation_input_text_prompt, + prompt_name="script-generation-input-text-generatiion", + ) + value = json_response.get("answer", value) + except Exception: + LOG.exception(f"Failed to adapt value for input text action on xpath={xpath}, value={value}") locator = self.page.locator(f"xpath={xpath}") - await handler_utils.input_sequentially(locator, text, timeout=timeout) + await handler_utils.input_sequentially(locator, value, timeout=timeout) @action_wrap(ActionType.UPLOAD_FILE) async def upload_file( @@ -542,11 +571,13 @@ class RunContext: self.original_parameters = parameters self.generated_parameters = generated_parameters self.parameters = copy.deepcopy(parameters) - # if generated_parameters: - # self.parameters.update(generated_parameters) + if generated_parameters: + # hydrate the generated parameter fields in the run context parameters + for key, value in generated_parameters.items(): + if key not in self.parameters: + self.parameters[key] = value self.page = page self.trace: list[ActionCall] = [] - self.prompt: str | None = None async def _get_actual_value_of_parameter_if_secret(workflow_run_id: str, parameter: str) -> Any: @@ -560,3 +591,34 @@ async def _get_actual_value_of_parameter_if_secret(workflow_run_id: str, paramet workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id) secret_value = workflow_run_context.get_original_secret_value_or_none(parameter) return secret_value if secret_value is not None else parameter + + +class ScriptRunContextManager: + """ + Manages the run context for code runs. + """ + + def __init__(self) -> None: + # self.run_contexts: dict[str, RunContext] = {} + self.run_context: RunContext | None = None + self.cached_fns: dict[str, Callable] = {} + + def get_run_context(self) -> RunContext | None: + return self.run_context + + def set_run_context(self, run_context: RunContext) -> None: + self.run_context = run_context + + def ensure_run_context(self) -> RunContext: + if not self.run_context: + raise Exception("Run context not found") + return self.run_context + + def set_cached_fn(self, cache_key: str, fn: Callable) -> None: + self.cached_fns[cache_key] = fn + + def get_cached_fn(self, cache_key: str) -> Callable | None: + return self.cached_fns.get(cache_key) + + +script_run_context_manager = ScriptRunContextManager() diff --git a/skyvern/core/script_generations/workflow_wrappers.py b/skyvern/core/script_generations/workflow_wrappers.py index db4b792b..e7bb0c6e 100644 --- a/skyvern/core/script_generations/workflow_wrappers.py +++ b/skyvern/core/script_generations/workflow_wrappers.py @@ -1,7 +1,6 @@ from typing import Any, Callable -from skyvern import RunContext, SkyvernPage -from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager +from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage, script_run_context_manager # Build a dummy workflow decorator diff --git a/skyvern/forge/sdk/core/skyvern_context.py b/skyvern/forge/sdk/core/skyvern_context.py index bc602729..298f600b 100644 --- a/skyvern/forge/sdk/core/skyvern_context.py +++ b/skyvern/forge/sdk/core/skyvern_context.py @@ -30,6 +30,7 @@ class SkyvernContext: script_id: str | None = None script_revision_id: str | None = None action_order: int = 0 + prompt: str | None = None def __repr__(self) -> str: return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, step_id={self.step_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, task_v2_id={self.task_v2_id}, max_steps_override={self.max_steps_override}, run_id={self.run_id})" diff --git a/skyvern/services/script_service.py b/skyvern/services/script_service.py index 7c15a399..ed963360 100644 --- a/skyvern/services/script_service.py +++ b/skyvern/services/script_service.py @@ -16,7 +16,7 @@ from skyvern.config import settings from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS from skyvern.core.script_generations.generate_script import _build_block_fn, create_script_block -from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager +from skyvern.core.script_generations.skyvern_page import script_run_context_manager from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound from skyvern.forge import app from skyvern.forge.prompts import prompt_engine @@ -942,8 +942,8 @@ async def run_task( url=url, ) # set the prompt in the RunContext - run_context = script_run_context_manager.ensure_run_context() - run_context.prompt = prompt + context = skyvern_context.ensure_context() + context.prompt = prompt if cache_key: try: @@ -972,7 +972,7 @@ async def run_task( ) finally: # clear the prompt in the RunContext - run_context.prompt = None + context.prompt = None else: if workflow_run_block_id: await _update_workflow_block( @@ -984,7 +984,7 @@ async def run_task( step_status=StepStatus.failed, failure_reason="Cache key is required", ) - run_context.prompt = None + context.prompt = None raise Exception("Cache key is required to run task block in a script") @@ -1001,8 +1001,8 @@ async def download( url=url, ) # set the prompt in the RunContext - run_context = script_run_context_manager.ensure_run_context() - run_context.prompt = prompt + context = skyvern_context.ensure_context() + context.prompt = prompt if cache_key: try: @@ -1031,7 +1031,7 @@ async def download( workflow_run_block_id=workflow_run_block_id, ) finally: - run_context.prompt = None + context.prompt = None else: if workflow_run_block_id: await _update_workflow_block( @@ -1043,7 +1043,7 @@ async def download( step_status=StepStatus.failed, failure_reason="Cache key is required", ) - run_context.prompt = None + context.prompt = None raise Exception("Cache key is required to run task block in a script") @@ -1060,8 +1060,8 @@ async def action( url=url, ) # set the prompt in the RunContext - run_context = script_run_context_manager.ensure_run_context() - run_context.prompt = prompt + context = skyvern_context.ensure_context() + context.prompt = prompt if cache_key: try: @@ -1089,7 +1089,7 @@ async def action( workflow_run_block_id=workflow_run_block_id, ) finally: - run_context.prompt = None + context.prompt = None else: if workflow_run_block_id: await _update_workflow_block( @@ -1101,7 +1101,7 @@ async def action( step_status=StepStatus.failed, failure_reason="Cache key is required", ) - run_context.prompt = None + context.prompt = None raise Exception("Cache key is required to run task block in a script") @@ -1118,8 +1118,8 @@ async def login( url=url, ) # set the prompt in the RunContext - run_context = script_run_context_manager.ensure_run_context() - run_context.prompt = prompt + context = skyvern_context.ensure_context() + context.prompt = prompt if cache_key: try: @@ -1147,7 +1147,7 @@ async def login( workflow_run_block_id=workflow_run_block_id, ) finally: - run_context.prompt = None + context.prompt = None else: if workflow_run_block_id: await _update_workflow_block( @@ -1159,7 +1159,7 @@ async def login( step_status=StepStatus.failed, failure_reason="Cache key is required", ) - run_context.prompt = None + context.prompt = None raise Exception("Cache key is required to run task block in a script") @@ -1178,8 +1178,8 @@ async def extract( url=url, ) # set the prompt in the RunContext - run_context = script_run_context_manager.ensure_run_context() - run_context.prompt = prompt + context = skyvern_context.ensure_context() + context.prompt = prompt output: dict[str, Any] | list | str | None = None if cache_key: @@ -1213,7 +1213,7 @@ async def extract( ) raise finally: - run_context.prompt = None + context.prompt = None else: if workflow_run_block_id: await _update_workflow_block( @@ -1225,7 +1225,7 @@ async def extract( step_status=StepStatus.failed, failure_reason="Cache key is required", ) - run_context.prompt = None + context.prompt = None raise Exception("Cache key is required to run task block in a script") @@ -1296,8 +1296,8 @@ async def generate_text( new_text = text or "" if intention and data: try: - run_context = script_run_context_manager.ensure_run_context() - prompt = run_context.prompt + context = skyvern_context.ensure_context() + prompt = context.prompt # Build the element tree of the current page for the prompt payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "") script_generation_input_text_prompt = prompt_engine.load_prompt(