batch input field processing for input actions when caching and running workflows with code (#4250)
This commit is contained in:
@@ -61,6 +61,8 @@ from skyvern.schemas.scripts import (
|
||||
ScriptStatus,
|
||||
)
|
||||
from skyvern.schemas.workflows import BlockResult, BlockStatus, BlockType, FileStorageType, FileType
|
||||
from skyvern.webeye.actions.action_types import ActionType
|
||||
from skyvern.webeye.actions.actions import Action
|
||||
from skyvern.webeye.scraper.scraped_page import ElementTreeFormat
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
@@ -352,6 +354,7 @@ async def execute_script(
|
||||
workflow_run_id=workflow_run_id,
|
||||
browser_session_id=browser_session_id,
|
||||
script_id=script_id,
|
||||
script_revision_id=script.script_revision_id,
|
||||
)
|
||||
else:
|
||||
# Execute synchronously
|
||||
@@ -362,6 +365,8 @@ async def execute_script(
|
||||
organization_id=organization_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
browser_session_id=browser_session_id,
|
||||
script_id=script_id,
|
||||
script_revision_id=script.script_revision_id,
|
||||
)
|
||||
else:
|
||||
LOG.error("Script main.py not found", script_path=script_path, script_id=script_id)
|
||||
@@ -686,6 +691,162 @@ async def _run_cached_function(cached_fn: Callable) -> Any:
|
||||
return await cached_fn(page=run_context.page, context=run_context)
|
||||
|
||||
|
||||
def _determine_action_ai_mode(
    action: Action,
    merged_value: str | None,
) -> str:
    """
    Pick the AI mode ("proactive" or "fallback") for a cached input/select action.

    Returns "proactive" when the action has a mini agent, when its TOTP timing
    info marks it as part of a TOTP sequence, or when no usable (non-blank)
    merged value is available. Returns "fallback" only when a non-blank merged
    value exists and neither of the first two conditions applies.

    NOTE(review): presumably "proactive" means the AI handles the action up
    front and "fallback" means the cached value is used first; confirm against
    the code that consumes these modes.
    """
    # Actions backed by a mini agent always go through the AI.
    if action.has_mini_agent:
        return "proactive"

    # NOTE: additional heuristics (location/date-related input context flags
    # and `totp_code_required`) existed here as commented-out code and are
    # currently disabled.

    totp_info = action.totp_timing_info
    if totp_info and totp_info.get("is_totp_sequence"):
        return "proactive"

    # A value consisting only of whitespace counts as "no value".
    has_usable_value = bool(merged_value) and bool(str(merged_value).strip())
    return "fallback" if has_usable_value else "proactive"
|
||||
|
||||
|
||||
def _clear_cached_block_overrides(cache_key: str) -> None:
    """
    Drop any per-block AI override and action counter stored under ``cache_key``.

    Does nothing when there is no current Skyvern context. Missing keys are
    ignored, so this is safe to call even if nothing was ever stored.
    """
    ctx = skyvern_context.current()
    if ctx:
        # Same order as before: overrides first, then counters.
        for registry_name in ("action_ai_overrides", "action_counters"):
            getattr(ctx, registry_name).pop(cache_key, None)
|
||||
|
||||
|
||||
def _coerce_merged_values(llm_response: Any) -> dict[str, Any]:
    """Normalise an LLM response into a ``field_name -> value`` dict.

    Accepts an already-parsed dict or a JSON string. Anything else, including
    JSON that decodes to a non-object (list, string, number, null), yields ``{}``
    so callers can always use ``.get`` on the result.
    """
    if isinstance(llm_response, str):
        try:
            llm_response = json.loads(llm_response)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return {}
    return llm_response if isinstance(llm_response, dict) else {}


async def _prepare_cached_block_inputs(cache_key: str, prompt: str | None, step_id: str | None = None) -> None:
    """
    Seed run-context parameters for a cached block's text inputs using one batched LLM call.

    Looks up the script block stored under ``cache_key`` for the current script revision, loads the
    actions recorded for the task behind that block, pairs its INPUT_TEXT actions positionally with the
    block's ``input_fields``, and asks the script-generation LLM for every field value in a single
    prompt. String values whose action resolves to "fallback" mode (see ``_determine_action_ai_mode``)
    are written into the script run context's ``parameters`` under the field name.

    This is best-effort: a failed lookup or LLM call is logged and leaves the run context untouched
    instead of raising.

    Args:
        cache_key: Label of the cached script block.
        prompt: Block-level prompt given to the LLM as context; ``None`` is sent as an empty string.
        step_id: Optional step id; when given, the step is loaded and passed to the LLM handler.
    """
    context = skyvern_context.current()
    if not context or not context.organization_id or not context.script_revision_id:
        return

    try:
        script_block = await app.DATABASE.get_script_block_by_label(
            organization_id=context.organization_id,
            script_revision_id=context.script_revision_id,
            script_block_label=cache_key,
        )
    except Exception:
        LOG.warning("Failed to load script block for cached block inputs", cache_key=cache_key, exc_info=True)
        return

    if not script_block:
        return
    input_fields: list[str] = script_block.input_fields or []
    workflow_run_block_id = script_block.workflow_run_block_id
    if not input_fields or not workflow_run_block_id:
        return

    try:
        source_block = await app.DATABASE.get_workflow_run_block(
            workflow_run_block_id=workflow_run_block_id,
            organization_id=context.organization_id,
        )
    except Exception:
        LOG.warning(
            "Failed to load source workflow run block for cached block inputs",
            cache_key=cache_key,
            workflow_run_block_id=workflow_run_block_id,
            exc_info=True,
        )
        return

    # Guard against a missing block as well as a block without a task.
    task_id = source_block.task_id if source_block else None
    if not task_id:
        return

    try:
        # Actions are ordered by created_at.
        actions = await app.DATABASE.get_task_actions_hydrated(task_id=task_id, organization_id=context.organization_id)
    except Exception:
        LOG.warning(
            "Failed to load task actions for cached block inputs",
            cache_key=cache_key,
            task_id=task_id,
            exc_info=True,
        )
        return

    # TODO: how to support select_option actions?
    input_actions = [action for action in actions if action.action_type == ActionType.INPUT_TEXT]
    if not input_actions:
        return

    # Pair each input action with a field name purely by position. Once the field
    # names run out, the remaining actions get None and are skipped below.
    field_iter = iter(input_fields)
    action_entries: list[tuple[Action, str | None]] = [(action, next(field_iter, None)) for action in input_actions]

    run_context = script_run_context_manager.get_run_context()
    if not run_context:
        return

    merged_values: dict[str, Any] = {}
    try:
        field_prompts = []
        for action, field_name in action_entries:
            if not field_name:
                continue
            # Prefer the input/select context's intention over the action-level text.
            prompt_text = action.intention or action.reasoning or ""
            if action.input_or_select_context and action.input_or_select_context.intention:
                prompt_text = action.input_or_select_context.intention
            field_prompts.append({"name": field_name, "prompt": prompt_text})

        if field_prompts:
            # Only truthy parameter values are forwarded to the LLM, stringified.
            parameters = {key: str(value) for key, value in run_context.parameters.items() if value}
            serialized_params = json.dumps(parameters)
            merged_prompt = (
                "You are helping fill web form fields for a workflow block.\n"
                f"Block prompt/context:\n{prompt or ''}\n\n"
                f"Workflow parameters (as JSON):\n{serialized_params}\n\n"
                "Return a JSON object mapping field_name -> value for the following fields.\n"
                "Leave value empty string if it cannot be determined.\n"
                f"Fields:\n{json.dumps(field_prompts)}"
            )
            step = None
            if step_id:
                step = await app.DATABASE.get_step(step_id=step_id, organization_id=context.organization_id)
            llm_response = await app.SCRIPT_GENERATION_LLM_API_HANDLER(
                prompt=merged_prompt,
                prompt_name="merged-block-inputs",
                step=step,
            )
            merged_values = _coerce_merged_values(llm_response)
    except Exception:
        LOG.warning("Failed to compute merged inputs for cached block", cache_key=cache_key, exc_info=True)
        merged_values = {}

    for action, field_name in action_entries:
        if not field_name:
            continue
        merged_value = merged_values.get(field_name, "")
        # Only string values are ever seeded; other JSON types are ignored.
        if not isinstance(merged_value, str):
            continue
        if _determine_action_ai_mode(action, merged_value) == "fallback":
            # Seed the run context parameters with merged values for cached execution.
            run_context.parameters[field_name] = merged_value

    # NOTE: per-action AI-mode overrides (context.action_ai_overrides / context.action_counters)
    # are currently not seeded here; only run-context parameters are updated.
|
||||
|
||||
|
||||
async def _detect_user_defined_errors(
|
||||
task: Task,
|
||||
step: Step,
|
||||
@@ -1199,6 +1360,7 @@ async def _regenerate_script_block_after_ai_fallback(
|
||||
block_label=existing_block.script_block_label,
|
||||
workflow_run_id=existing_block.workflow_run_id,
|
||||
workflow_run_block_id=existing_block.workflow_run_block_id,
|
||||
input_fields=existing_block.input_fields,
|
||||
)
|
||||
block_file_content_bytes = (
|
||||
block_file_content if isinstance(block_file_content, bytes) else block_file_content.encode("utf-8")
|
||||
@@ -1357,6 +1519,7 @@ async def run_task(
|
||||
context = skyvern_context.ensure_context()
|
||||
context.prompt = prompt
|
||||
try:
|
||||
await _prepare_cached_block_inputs(cache_key, prompt)
|
||||
output = await _run_cached_function(cached_fn)
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
@@ -1389,6 +1552,7 @@ async def run_task(
|
||||
finally:
|
||||
# clear the prompt in the RunContext
|
||||
context.prompt = None
|
||||
_clear_cached_block_overrides(cache_key)
|
||||
else:
|
||||
block_validation_output = await _validate_and_get_output_parameter(label)
|
||||
task_block = NavigationBlock(
|
||||
@@ -1444,6 +1608,7 @@ async def download(
|
||||
context.prompt = prompt
|
||||
|
||||
try:
|
||||
await _prepare_cached_block_inputs(cache_key, prompt)
|
||||
await _run_cached_function(cached_fn)
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
@@ -1471,6 +1636,7 @@ async def download(
|
||||
)
|
||||
finally:
|
||||
context.prompt = None
|
||||
_clear_cached_block_overrides(cache_key)
|
||||
else:
|
||||
block_validation_output = await _validate_and_get_output_parameter(label)
|
||||
file_download_block = FileDownloadBlock(
|
||||
@@ -1525,6 +1691,7 @@ async def action(
|
||||
context.prompt = prompt
|
||||
|
||||
try:
|
||||
await _prepare_cached_block_inputs(cache_key, prompt)
|
||||
await _run_cached_function(cached_fn)
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
@@ -1553,6 +1720,7 @@ async def action(
|
||||
)
|
||||
finally:
|
||||
context.prompt = None
|
||||
_clear_cached_block_overrides(cache_key)
|
||||
else:
|
||||
block_validation_output = await _validate_and_get_output_parameter(label)
|
||||
action_block = ActionBlock(
|
||||
@@ -1609,6 +1777,7 @@ async def login(
|
||||
context = skyvern_context.ensure_context()
|
||||
context.prompt = prompt
|
||||
try:
|
||||
await _prepare_cached_block_inputs(cache_key, prompt)
|
||||
await _run_cached_function(cached_fn)
|
||||
|
||||
# Update block status to completed if workflow block was created
|
||||
@@ -1637,6 +1806,7 @@ async def login(
|
||||
)
|
||||
finally:
|
||||
context.prompt = None
|
||||
_clear_cached_block_overrides(cache_key)
|
||||
else:
|
||||
block_validation_output = await _validate_and_get_output_parameter(label)
|
||||
login_block = LoginBlock(
|
||||
@@ -1804,13 +1974,23 @@ async def run_script(
|
||||
organization_id: str | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
script_id: str | None = None,
|
||||
script_revision_id: str | None = None,
|
||||
) -> None:
|
||||
# register the script run
|
||||
context = skyvern_context.current()
|
||||
if not context:
|
||||
context = skyvern_context.ensure_context()
|
||||
skyvern_context.set(skyvern_context.SkyvernContext())
|
||||
context = skyvern_context.SkyvernContext()
|
||||
skyvern_context.set(context)
|
||||
|
||||
context.browser_session_id = browser_session_id
|
||||
if organization_id:
|
||||
context.organization_id = organization_id
|
||||
if script_id:
|
||||
context.script_id = script_id
|
||||
if script_revision_id:
|
||||
context.script_revision_id = script_revision_id
|
||||
|
||||
if workflow_run_id and organization_id:
|
||||
workflow_run = await app.DATABASE.get_workflow_run(
|
||||
workflow_run_id=workflow_run_id, organization_id=organization_id
|
||||
|
||||
Reference in New Issue
Block a user