ai_adapt_value for text input (#3354)
This commit is contained in:
@@ -126,12 +126,12 @@ def _generate_text_call(text_value: str, intention: str, parameter_key: str) ->
|
||||
last_line=cst.SimpleWhitespace(DOUBLE_INDENT),
|
||||
),
|
||||
args=[
|
||||
# First positional argument: context.generated_parameters['parameter_key']
|
||||
# First positional argument: context.parameters['parameter_key']
|
||||
cst.Arg(
|
||||
value=cst.Subscript(
|
||||
value=cst.Attribute(
|
||||
value=cst.Name("context"),
|
||||
attr=cst.Name("generated_parameters"),
|
||||
attr=cst.Name("parameters"),
|
||||
),
|
||||
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(parameter_key)))],
|
||||
),
|
||||
@@ -247,20 +247,21 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
|
||||
)
|
||||
|
||||
if method in ["type", "fill"]:
|
||||
# Get intention from action
|
||||
intention = act.get("intention") or act.get("reasoning") or ""
|
||||
|
||||
# Use generate_text call if field_name is available, otherwise fallback to direct value
|
||||
# Use context.parameters if field_name is available, otherwise fallback to direct value
|
||||
if act.get("field_name"):
|
||||
text_value = _generate_text_call(
|
||||
text_value=act["text"], intention=intention, parameter_key=act["field_name"]
|
||||
text_value = cst.Subscript(
|
||||
value=cst.Attribute(
|
||||
value=cst.Name("context"),
|
||||
attr=cst.Name("parameters"),
|
||||
),
|
||||
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
|
||||
)
|
||||
else:
|
||||
text_value = _value(act["text"])
|
||||
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("text"),
|
||||
keyword=cst.Name("value"),
|
||||
value=text_value,
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
@@ -268,6 +269,16 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst.
|
||||
),
|
||||
)
|
||||
)
|
||||
args.append(
|
||||
cst.Arg(
|
||||
keyword=cst.Name("ai_adapt_value"),
|
||||
value=cst.Name("True"),
|
||||
whitespace_after_arg=cst.ParenthesizedWhitespace(
|
||||
indent=True,
|
||||
last_line=cst.SimpleWhitespace(INDENT),
|
||||
),
|
||||
)
|
||||
)
|
||||
elif method == "select_option":
|
||||
args.append(
|
||||
cst.Arg(
|
||||
|
||||
@@ -2,8 +2,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
|
||||
from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage
|
||||
from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage, script_run_context_manager
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.workflow.models.parameter import WorkflowParameterType
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
from typing import Callable
|
||||
|
||||
from skyvern.core.script_generations.skyvern_page import RunContext
|
||||
|
||||
|
||||
class ScriptRunContextManager:
|
||||
"""
|
||||
Manages the run context for code runs.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# self.run_contexts: dict[str, RunContext] = {}
|
||||
self.run_context: RunContext | None = None
|
||||
self.cached_fns: dict[str, Callable] = {}
|
||||
|
||||
def get_run_context(self) -> RunContext | None:
|
||||
return self.run_context
|
||||
|
||||
def set_run_context(self, run_context: RunContext) -> None:
|
||||
self.run_context = run_context
|
||||
|
||||
def ensure_run_context(self) -> RunContext:
|
||||
if not self.run_context:
|
||||
raise Exception("Run context not found")
|
||||
return self.run_context
|
||||
|
||||
def set_cached_fn(self, cache_key: str, fn: Callable) -> None:
|
||||
self.cached_fns[cache_key] = fn
|
||||
|
||||
def get_cached_fn(self, cache_key: str) -> Callable | None:
|
||||
return self.cached_fns.get(cache_key)
|
||||
|
||||
|
||||
script_run_context_manager = ScriptRunContextManager()
|
||||
@@ -8,6 +8,7 @@ from datetime import datetime, timezone
|
||||
from enum import StrEnum
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import structlog
|
||||
from playwright.async_api import Page
|
||||
|
||||
from skyvern.config import settings
|
||||
@@ -24,6 +25,8 @@ from skyvern.webeye.actions.actions import Action, ActionStatus, ExtractAction,
|
||||
from skyvern.webeye.browser_factory import BrowserState
|
||||
from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class Driver(StrEnum):
|
||||
PLAYWRIGHT = "playwright"
|
||||
@@ -196,7 +199,8 @@ class SkyvernPage:
|
||||
|
||||
# Create action record. TODO: store more action fields
|
||||
kwargs = kwargs or {}
|
||||
text = kwargs.get("text")
|
||||
# we're using "value" instead of "text" for input text actions interface
|
||||
text = kwargs.get("value", "")
|
||||
option_value = kwargs.get("option")
|
||||
select_option = SelectOption(value=option_value) if option_value else None
|
||||
response: str | None = kwargs.get("response")
|
||||
@@ -314,7 +318,7 @@ class SkyvernPage:
|
||||
current_url=self.page.url,
|
||||
elements=element_tree,
|
||||
local_datetime=datetime.now(context.tz_info or datetime.now().astimezone().tzinfo).isoformat(),
|
||||
user_context=getattr(context, "prompt", None),
|
||||
# user_context=getattr(context, "prompt", None),
|
||||
)
|
||||
json_response = await app.SINGLE_CLICK_AGENT_LLM_API_HANDLER(
|
||||
prompt=single_click_prompt,
|
||||
@@ -334,28 +338,31 @@ class SkyvernPage:
|
||||
async def fill(
|
||||
self,
|
||||
xpath: str,
|
||||
text: str,
|
||||
value: str,
|
||||
ai_adapt_value: bool = False,
|
||||
intention: str | None = None,
|
||||
data: str | dict[str, Any] | None = None,
|
||||
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
|
||||
) -> None:
|
||||
await self._input_text(xpath, text, intention, data, timeout)
|
||||
await self._input_text(xpath, value, ai_adapt_value, intention, data, timeout)
|
||||
|
||||
@action_wrap(ActionType.INPUT_TEXT)
|
||||
async def type(
|
||||
self,
|
||||
xpath: str,
|
||||
text: str,
|
||||
value: str,
|
||||
ai_adapt_value: bool = False,
|
||||
intention: str | None = None,
|
||||
data: str | dict[str, Any] | None = None,
|
||||
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
|
||||
) -> None:
|
||||
await self._input_text(xpath, text, intention, data, timeout)
|
||||
await self._input_text(xpath, value, ai_adapt_value, intention, data, timeout)
|
||||
|
||||
async def _input_text(
|
||||
self,
|
||||
xpath: str,
|
||||
text: str,
|
||||
value: str,
|
||||
ai_adapt_value: bool = False,
|
||||
intention: str | None = None,
|
||||
data: str | dict[str, Any] | None = None,
|
||||
timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
|
||||
@@ -372,11 +379,33 @@ class SkyvernPage:
|
||||
"""
|
||||
# format the text with the actual value of the parameter if it's a secret when running a workflow
|
||||
context = skyvern_context.current()
|
||||
value = value or ""
|
||||
if context and context.workflow_run_id:
|
||||
text = await _get_actual_value_of_parameter_if_secret(context.workflow_run_id, text)
|
||||
value = await _get_actual_value_of_parameter_if_secret(context.workflow_run_id, value)
|
||||
|
||||
if ai_adapt_value and intention:
|
||||
try:
|
||||
prompt = context.prompt if context else None
|
||||
# Build the element tree of the current page for the prompt
|
||||
# clean up empty data values
|
||||
data = {k: v for k, v in data.items() if v} if isinstance(data, dict) else (data or "")
|
||||
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
|
||||
script_generation_input_text_prompt = prompt_engine.load_prompt(
|
||||
template="script-generation-input-text-generatiion",
|
||||
intention=intention,
|
||||
data=payload_str,
|
||||
goal=prompt,
|
||||
)
|
||||
json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER(
|
||||
prompt=script_generation_input_text_prompt,
|
||||
prompt_name="script-generation-input-text-generatiion",
|
||||
)
|
||||
value = json_response.get("answer", value)
|
||||
except Exception:
|
||||
LOG.exception(f"Failed to adapt value for input text action on xpath={xpath}, value={value}")
|
||||
|
||||
locator = self.page.locator(f"xpath={xpath}")
|
||||
await handler_utils.input_sequentially(locator, text, timeout=timeout)
|
||||
await handler_utils.input_sequentially(locator, value, timeout=timeout)
|
||||
|
||||
@action_wrap(ActionType.UPLOAD_FILE)
|
||||
async def upload_file(
|
||||
@@ -542,11 +571,13 @@ class RunContext:
|
||||
self.original_parameters = parameters
|
||||
self.generated_parameters = generated_parameters
|
||||
self.parameters = copy.deepcopy(parameters)
|
||||
# if generated_parameters:
|
||||
# self.parameters.update(generated_parameters)
|
||||
if generated_parameters:
|
||||
# hydrate the generated parameter fields in the run context parameters
|
||||
for key, value in generated_parameters.items():
|
||||
if key not in self.parameters:
|
||||
self.parameters[key] = value
|
||||
self.page = page
|
||||
self.trace: list[ActionCall] = []
|
||||
self.prompt: str | None = None
|
||||
|
||||
|
||||
async def _get_actual_value_of_parameter_if_secret(workflow_run_id: str, parameter: str) -> Any:
|
||||
@@ -560,3 +591,34 @@ async def _get_actual_value_of_parameter_if_secret(workflow_run_id: str, paramet
|
||||
workflow_run_context = app.WORKFLOW_CONTEXT_MANAGER.get_workflow_run_context(workflow_run_id)
|
||||
secret_value = workflow_run_context.get_original_secret_value_or_none(parameter)
|
||||
return secret_value if secret_value is not None else parameter
|
||||
|
||||
|
||||
class ScriptRunContextManager:
|
||||
"""
|
||||
Manages the run context for code runs.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# self.run_contexts: dict[str, RunContext] = {}
|
||||
self.run_context: RunContext | None = None
|
||||
self.cached_fns: dict[str, Callable] = {}
|
||||
|
||||
def get_run_context(self) -> RunContext | None:
|
||||
return self.run_context
|
||||
|
||||
def set_run_context(self, run_context: RunContext) -> None:
|
||||
self.run_context = run_context
|
||||
|
||||
def ensure_run_context(self) -> RunContext:
|
||||
if not self.run_context:
|
||||
raise Exception("Run context not found")
|
||||
return self.run_context
|
||||
|
||||
def set_cached_fn(self, cache_key: str, fn: Callable) -> None:
|
||||
self.cached_fns[cache_key] = fn
|
||||
|
||||
def get_cached_fn(self, cache_key: str) -> Callable | None:
|
||||
return self.cached_fns.get(cache_key)
|
||||
|
||||
|
||||
script_run_context_manager = ScriptRunContextManager()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
from skyvern import RunContext, SkyvernPage
|
||||
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
|
||||
from skyvern.core.script_generations.skyvern_page import RunContext, SkyvernPage, script_run_context_manager
|
||||
|
||||
|
||||
# Build a dummy workflow decorator
|
||||
|
||||
@@ -30,6 +30,7 @@ class SkyvernContext:
|
||||
script_id: str | None = None
|
||||
script_revision_id: str | None = None
|
||||
action_order: int = 0
|
||||
prompt: str | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, step_id={self.step_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, task_v2_id={self.task_v2_id}, max_steps_override={self.max_steps_override}, run_id={self.run_id})"
|
||||
|
||||
@@ -16,7 +16,7 @@ from skyvern.config import settings
|
||||
from skyvern.constants import GET_DOWNLOADED_FILES_TIMEOUT
|
||||
from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS
|
||||
from skyvern.core.script_generations.generate_script import _build_block_fn, create_script_block
|
||||
from skyvern.core.script_generations.script_run_context_manager import script_run_context_manager
|
||||
from skyvern.core.script_generations.skyvern_page import script_run_context_manager
|
||||
from skyvern.exceptions import ScriptNotFound, WorkflowRunNotFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
@@ -942,8 +942,8 @@ async def run_task(
|
||||
url=url,
|
||||
)
|
||||
# set the prompt in the RunContext
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
run_context.prompt = prompt
|
||||
context = skyvern_context.ensure_context()
|
||||
context.prompt = prompt
|
||||
|
||||
if cache_key:
|
||||
try:
|
||||
@@ -972,7 +972,7 @@ async def run_task(
|
||||
)
|
||||
finally:
|
||||
# clear the prompt in the RunContext
|
||||
run_context.prompt = None
|
||||
context.prompt = None
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -984,7 +984,7 @@ async def run_task(
|
||||
step_status=StepStatus.failed,
|
||||
failure_reason="Cache key is required",
|
||||
)
|
||||
run_context.prompt = None
|
||||
context.prompt = None
|
||||
raise Exception("Cache key is required to run task block in a script")
|
||||
|
||||
|
||||
@@ -1001,8 +1001,8 @@ async def download(
|
||||
url=url,
|
||||
)
|
||||
# set the prompt in the RunContext
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
run_context.prompt = prompt
|
||||
context = skyvern_context.ensure_context()
|
||||
context.prompt = prompt
|
||||
|
||||
if cache_key:
|
||||
try:
|
||||
@@ -1031,7 +1031,7 @@ async def download(
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
)
|
||||
finally:
|
||||
run_context.prompt = None
|
||||
context.prompt = None
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -1043,7 +1043,7 @@ async def download(
|
||||
step_status=StepStatus.failed,
|
||||
failure_reason="Cache key is required",
|
||||
)
|
||||
run_context.prompt = None
|
||||
context.prompt = None
|
||||
raise Exception("Cache key is required to run task block in a script")
|
||||
|
||||
|
||||
@@ -1060,8 +1060,8 @@ async def action(
|
||||
url=url,
|
||||
)
|
||||
# set the prompt in the RunContext
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
run_context.prompt = prompt
|
||||
context = skyvern_context.ensure_context()
|
||||
context.prompt = prompt
|
||||
|
||||
if cache_key:
|
||||
try:
|
||||
@@ -1089,7 +1089,7 @@ async def action(
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
)
|
||||
finally:
|
||||
run_context.prompt = None
|
||||
context.prompt = None
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -1101,7 +1101,7 @@ async def action(
|
||||
step_status=StepStatus.failed,
|
||||
failure_reason="Cache key is required",
|
||||
)
|
||||
run_context.prompt = None
|
||||
context.prompt = None
|
||||
raise Exception("Cache key is required to run task block in a script")
|
||||
|
||||
|
||||
@@ -1118,8 +1118,8 @@ async def login(
|
||||
url=url,
|
||||
)
|
||||
# set the prompt in the RunContext
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
run_context.prompt = prompt
|
||||
context = skyvern_context.ensure_context()
|
||||
context.prompt = prompt
|
||||
|
||||
if cache_key:
|
||||
try:
|
||||
@@ -1147,7 +1147,7 @@ async def login(
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
)
|
||||
finally:
|
||||
run_context.prompt = None
|
||||
context.prompt = None
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -1159,7 +1159,7 @@ async def login(
|
||||
step_status=StepStatus.failed,
|
||||
failure_reason="Cache key is required",
|
||||
)
|
||||
run_context.prompt = None
|
||||
context.prompt = None
|
||||
raise Exception("Cache key is required to run task block in a script")
|
||||
|
||||
|
||||
@@ -1178,8 +1178,8 @@ async def extract(
|
||||
url=url,
|
||||
)
|
||||
# set the prompt in the RunContext
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
run_context.prompt = prompt
|
||||
context = skyvern_context.ensure_context()
|
||||
context.prompt = prompt
|
||||
output: dict[str, Any] | list | str | None = None
|
||||
|
||||
if cache_key:
|
||||
@@ -1213,7 +1213,7 @@ async def extract(
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
run_context.prompt = None
|
||||
context.prompt = None
|
||||
else:
|
||||
if workflow_run_block_id:
|
||||
await _update_workflow_block(
|
||||
@@ -1225,7 +1225,7 @@ async def extract(
|
||||
step_status=StepStatus.failed,
|
||||
failure_reason="Cache key is required",
|
||||
)
|
||||
run_context.prompt = None
|
||||
context.prompt = None
|
||||
raise Exception("Cache key is required to run task block in a script")
|
||||
|
||||
|
||||
@@ -1296,8 +1296,8 @@ async def generate_text(
|
||||
new_text = text or ""
|
||||
if intention and data:
|
||||
try:
|
||||
run_context = script_run_context_manager.ensure_run_context()
|
||||
prompt = run_context.prompt
|
||||
context = skyvern_context.ensure_context()
|
||||
prompt = context.prompt
|
||||
# Build the element tree of the current page for the prompt
|
||||
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
|
||||
script_generation_input_text_prompt = prompt_engine.load_prompt(
|
||||
|
||||
Reference in New Issue
Block a user