ai_adapt_value for text input (#3354)

This commit is contained in:
Shuchang Zheng
2025-09-03 16:44:52 -07:00
committed by GitHub
parent 32771bdd19
commit 55d847461e
7 changed files with 120 additions and 82 deletions

View File

@@ -126,12 +126,12 @@ def _generate_text_call(text_value: str, intention: str, parameter_key: str) ->
last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
),
args=[
# First positional argument: context.generated_parameters['parameter_key']
# First positional argument: context.parameters['parameter_key']
cst.Arg(
value=cst.Subscript(
value=cst.Attribute(
value=cst.Name("context"),
attr=cst.Name("generated_parameters"),
attr=cst.Name("parameters"),
),
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(parameter_key)))],
),
@@ -247,20 +247,21 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
)
if method in ["type", "fill"]:
# Get intention from action
intention = act.get("intention") or act.get("reasoning") or ""
# Use generate_text call if field_name is available, otherwise fallback to direct value
# Use context.parameters if field_name is available, otherwise fallback to direct value
if act.get("field_name"):
text_value = _generate_text_call(
text_value=act["text"], intention=intention, parameter_key=act["field_name"]
text_value = cst.Subscript(
value=cst.Attribute(
value=cst.Name("context"),
attr=cst.Name("parameters"),
),
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
)
else:
text_value = _value(act["text"])
args.append(
cst.Arg(
keyword=cst.Name("text"),
keyword=cst.Name("value"),
value=text_value,
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
@@ -268,6 +269,16 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
),
)
)
args.append(
cst.Arg(
keyword=cst.Name("ai_adapt_value"),
value=cst.Name("True"),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
elif method == "select_option":
args.append(
cst.Arg(

View File

@@ -2,8 +2,7 @@ from typing import Any
from pydantic import BaseModel
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage
from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage, script_run_context_manager
from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameterType

View File

@@ -1,34 +0,0 @@
from typing import Callable
from skyvern.core.script_generations.skyvern_page import RunContext
class ScriptRunContextManager:
    """
    Manages the run context for code runs.

    Holds a single ``run_context`` slot (``None`` until one is set) and a
    dictionary of cached callables keyed by cache key.
    """

    def __init__(self) -> None:
        # Only one run context is tracked at a time.
        self.run_context: RunContext | None = None
        self.cached_fns: dict[str, Callable] = {}

    def get_run_context(self) -> RunContext | None:
        """Return the stored run context, or ``None`` if none has been set."""
        return self.run_context

    def set_run_context(self, run_context: RunContext) -> None:
        """Store ``run_context``, replacing any previously stored one."""
        self.run_context = run_context

    def ensure_run_context(self) -> RunContext:
        """
        Return the stored run context.

        Raises:
            Exception: if no run context has been set.
        """
        # Compare against None explicitly: a context object that happens to
        # evaluate as falsy is still a context and must not be reported missing.
        if self.run_context is None:
            raise Exception("Run context not found")
        return self.run_context

    def set_cached_fn(self, cache_key: str, fn: Callable) -> None:
        """Register ``fn`` under ``cache_key``, overwriting any existing entry."""
        self.cached_fns[cache_key] = fn

    def get_cached_fn(self, cache_key: str) -> Callable | None:
        """Return the callable registered under ``cache_key``, or ``None``."""
        return self.cached_fns.get(cache_key)


# Module-level shared instance imported by the script runtime.
script_run_context_manager = ScriptRunContextManager()

View File

@@ -8,6 +8,7 @@ from datetime import datetime, timezone
from enum import StrEnum
from typing import Any, Callable, Literal
import structlog
from playwright.async_api import Page
from skyvern.config import settings
@@ -24,6 +25,8 @@ from skyvern.webeye.actions.actions import Action, ActionStatus, ExtractAction,
from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website
LOG = structlog.get_logger()
class Driver(StrEnum):
PLAYWRIGHT = "playwright"
@@ -196,7 +199,8 @@ class SkyvernPage:
# Create action record. TODO: store more action fields
kwargs = kwargs or {}
text = kwargs.get("text")
# we're using "value" instead of "text" for input text actions interface
text = kwargs.get("value", "")
option_value = kwargs.get("option")
select_option = SelectOption(value=option_value) if option_value else None
response: str | None = kwargs.get("response")
@@ -314,7 +318,7 @@ class SkyvernPage:
current_url=self.page.url,
elements=element_tree,
local_datetime=datetime.now(context.tz_info or datetime.now().astimezone().tzinfo).isoformat(),
user_context=getattr(context, "prompt", None),
# user_context=getattr(context, "prompt", None),
)
json_response = await app.SINGLE_CLICK_AGENT_LLM_API_HANDLER(
prompt=single_click_prompt,
@@ -334,28 +338,31 @@ class SkyvernPage:
async def fill(
self,
xpath: str,
text: str,
value: str,
ai_adapt_value: bool = False,
intention: str | None = None,
data: str | dict[str, Any] | None = None,
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
) -> None:
await self._input_text(xpath, text, intention, data, timeout)
await self._input_text(xpath, value, ai_adapt_value, intention, data, timeout)
@action_wrap(ActionType.INPUT_TEXT)
async def type(
self,
xpath: str,
text: str,
value: str,
ai_adapt_value: bool = False,
intention: str | None = None,
data: str | dict[str, Any] | None = None,
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
) -> None:
await self._input_text(xpath, text, intention, data, timeout)
await self._input_text(xpath, value, ai_adapt_value, intention, data, timeout)
async def _input_text(
self,
xpath: str,
text: str,
value: str,
ai_adapt_value: bool = False,
intention: str | None = None,
data: str | dict[str, Any] | None = None,
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
@@ -372,11 +379,33 @@ class SkyvernPage:
"""
# format the text with the actual value of the parameter if it's a secret when running a workflow
context = skyvern_context.current()
value = value or ""
if context and context.workflow_run_id:
text = await _get_actual_value_of_parameter_if_secret(context.workflow_run_id, text)
value = await _get_actual_value_of_parameter_if_secret(context.workflow_run_id, value)
if ai_adapt_value and intention:
try:
prompt = context.prompt if context else None
# Build the element tree of the current page for the prompt
# clean up empty data values
data = {k: v for k, v in data.items() if v} if isinstance(data, dict) else (data or "")
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
script_generation_input_text_prompt = prompt_engine.load_prompt(
template="script-generation-input-text-generatiion",
intention=intention,
data=payload_str,
goal=prompt,
)
json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER(
prompt=script_generation_input_text_prompt,
prompt_name="script-generation-input-text-generatiion",
)
value = json_response.get("answer", value)
except Exception:
LOG.exception(f"Failed to adapt value for input text action on xpath={xpath}, value={value}")
locator = self.page.locator(f"xpath={xpath}")
await handler_utils.input_sequentially(locator, text, timeout=timeout)
await handler_utils.input_sequentially(locator, value, timeout=timeout)
@action_wrap(ActionType.UPLOAD_FILE)
async def upload_file(
@@ -542,11 +571,13 @@ class RunContext:
self.original_parameters = parameters
self.generated_parameters = generated_parameters
self.parameters = copy.deepcopy(parameters)
# if generated_parameters:
# self.parameters.update(generated_parameters)
if generated_parameters:
# hydrate the generated parameter fields in the run context parameters
for key, value in generated_parameters.items():
if key not in self.parameters:
self.parameters[key] = value
self.page = page
self.trace: list[ActionCall] = []
self.prompt: str | None = None
async def _get_actual_value_of_parameter_if_secret(workflow_run_id: str, parameter: str) -> Any:
@@ -560,3 +591,34 @@ async def _get_actual_value_of_parameter_if_secret(workflow_run_id: str, paramet
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
secret_value = workflow_run_context.get_original_secret_value_or_none(parameter)
return secret_value if secret_value is not None else parameter
class ScriptRunContextManager:
    """
    Manages the run context for code runs.

    Holds a single ``run_context`` slot (``None`` until one is set) and a
    dictionary of cached callables keyed by cache key.
    """

    def __init__(self) -> None:
        # Only one run context is tracked at a time.
        self.run_context: RunContext | None = None
        self.cached_fns: dict[str, Callable] = {}

    def get_run_context(self) -> RunContext | None:
        """Return the stored run context, or ``None`` if none has been set."""
        return self.run_context

    def set_run_context(self, run_context: RunContext) -> None:
        """Store ``run_context``, replacing any previously stored one."""
        self.run_context = run_context

    def ensure_run_context(self) -> RunContext:
        """
        Return the stored run context.

        Raises:
            Exception: if no run context has been set.
        """
        # Compare against None explicitly: a context object that happens to
        # evaluate as falsy is still a context and must not be reported missing.
        if self.run_context is None:
            raise Exception("Run context not found")
        return self.run_context

    def set_cached_fn(self, cache_key: str, fn: Callable) -> None:
        """Register ``fn`` under ``cache_key``, overwriting any existing entry."""
        self.cached_fns[cache_key] = fn

    def get_cached_fn(self, cache_key: str) -> Callable | None:
        """Return the callable registered under ``cache_key``, or ``None``."""
        return self.cached_fns.get(cache_key)


# Module-level shared instance imported by the script runtime.
script_run_context_manager = ScriptRunContextManager()

View File

@@ -1,7 +1,6 @@
from typing import Any, Callable
from skyvern import RunContext, SkyvernPage
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage, script_run_context_manager
# Build a dummy workflow decorator

View File

@@ -30,6 +30,7 @@ class SkyvernContext:
script_id: str | None = None
script_revision_id: str | None = None
action_order: int = 0
prompt: str | None = None
def __repr__(self) -> str:
return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, step_id={self.step_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, task_v2_id={self.task_v2_id}, max_steps_override={self.max_steps_override}, run_id={self.run_id})"

View File

@@ -16,7 +16,7 @@ from skyvern.config import settings
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
from skyvern.core.script_generations.generate_script import _build_block_fn, create_script_block
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
from skyvern.core.script_generations.skyvern_page import script_run_context_manager
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
@@ -942,8 +942,8 @@ async def run_task(
url=url,
)
# set the prompt in the SkyvernContext
run_context = script_run_context_manager.ensure_run_context()
run_context.prompt = prompt
context = skyvern_context.ensure_context()
context.prompt = prompt
if cache_key:
try:
@@ -972,7 +972,7 @@ async def run_task(
)
finally:
# clear the prompt in the SkyvernContext
run_context.prompt = None
context.prompt = None
else:
if workflow_run_block_id:
await _update_workflow_block(
@@ -984,7 +984,7 @@ async def run_task(
step_status=StepStatus.failed,
failure_reason="Cache key is required",
)
run_context.prompt = None
context.prompt = None
raise Exception("Cache key is required to run task block in a script")
@@ -1001,8 +1001,8 @@ async def download(
url=url,
)
# set the prompt in the SkyvernContext
run_context = script_run_context_manager.ensure_run_context()
run_context.prompt = prompt
context = skyvern_context.ensure_context()
context.prompt = prompt
if cache_key:
try:
@@ -1031,7 +1031,7 @@ async def download(
workflow_run_block_id=workflow_run_block_id,
)
finally:
run_context.prompt = None
context.prompt = None
else:
if workflow_run_block_id:
await _update_workflow_block(
@@ -1043,7 +1043,7 @@ async def download(
step_status=StepStatus.failed,
failure_reason="Cache key is required",
)
run_context.prompt = None
context.prompt = None
raise Exception("Cache key is required to run task block in a script")
@@ -1060,8 +1060,8 @@ async def action(
url=url,
)
# set the prompt in the SkyvernContext
run_context = script_run_context_manager.ensure_run_context()
run_context.prompt = prompt
context = skyvern_context.ensure_context()
context.prompt = prompt
if cache_key:
try:
@@ -1089,7 +1089,7 @@ async def action(
workflow_run_block_id=workflow_run_block_id,
)
finally:
run_context.prompt = None
context.prompt = None
else:
if workflow_run_block_id:
await _update_workflow_block(
@@ -1101,7 +1101,7 @@ async def action(
step_status=StepStatus.failed,
failure_reason="Cache key is required",
)
run_context.prompt = None
context.prompt = None
raise Exception("Cache key is required to run task block in a script")
@@ -1118,8 +1118,8 @@ async def login(
url=url,
)
# set the prompt in the SkyvernContext
run_context = script_run_context_manager.ensure_run_context()
run_context.prompt = prompt
context = skyvern_context.ensure_context()
context.prompt = prompt
if cache_key:
try:
@@ -1147,7 +1147,7 @@ async def login(
workflow_run_block_id=workflow_run_block_id,
)
finally:
run_context.prompt = None
context.prompt = None
else:
if workflow_run_block_id:
await _update_workflow_block(
@@ -1159,7 +1159,7 @@ async def login(
step_status=StepStatus.failed,
failure_reason="Cache key is required",
)
run_context.prompt = None
context.prompt = None
raise Exception("Cache key is required to run task block in a script")
@@ -1178,8 +1178,8 @@ async def extract(
url=url,
)
# set the prompt in the SkyvernContext
run_context = script_run_context_manager.ensure_run_context()
run_context.prompt = prompt
context = skyvern_context.ensure_context()
context.prompt = prompt
output: dict[str, Any] | list | str | None = None
if cache_key:
@@ -1213,7 +1213,7 @@ async def extract(
)
raise
finally:
run_context.prompt = None
context.prompt = None
else:
if workflow_run_block_id:
await _update_workflow_block(
@@ -1225,7 +1225,7 @@ async def extract(
step_status=StepStatus.failed,
failure_reason="Cache key is required",
)
run_context.prompt = None
context.prompt = None
raise Exception("Cache key is required to run task block in a script")
@@ -1296,8 +1296,8 @@ async def generate_text(
new_text = text or ""
if intention and data:
try:
run_context = script_run_context_manager.ensure_run_context()
prompt = run_context.prompt
context = skyvern_context.ensure_context()
prompt = context.prompt
# Build the element tree of the current page for the prompt
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
script_generation_input_text_prompt = prompt_engine.load_prompt(