Smarter select_option & input_text actions (#3440)

This commit is contained in:
Shuchang Zheng
2025-09-15 13:16:34 -07:00
committed by GitHub
parent 6f212ff327
commit 6ee329866b
10 changed files with 300 additions and 105 deletions

View File

@@ -353,15 +353,69 @@ def _action_to_stmt(act: dict[str, Any], task: dict[str, Any], assign_to_output:
) )
) )
elif method == "select_option": elif method == "select_option":
option = act.get("option", {})
value = option.get("value")
if value:
if act.get("field_name"):
option_value = cst.Subscript(
value=cst.Attribute(
value=cst.Name("context"),
attr=cst.Name("parameters"),
),
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
)
else:
option_value = _value(value)
args.append(
cst.Arg(
keyword=cst.Name("value"),
value=option_value,
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
),
)
args.append(
cst.Arg(
keyword=cst.Name("ai_infer"),
value=cst.Name("True"),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
)
elif method == "upload_file":
if act.get("field_name"):
file_url_value = cst.Subscript(
value=cst.Attribute(
value=cst.Name("context"),
attr=cst.Name("parameters"),
),
slice=[cst.SubscriptElement(slice=cst.Index(value=_value(act["field_name"])))],
)
else:
file_url_value = _value(act["file_url"])
args.append( args.append(
cst.Arg( cst.Arg(
keyword=cst.Name("option"), keyword=cst.Name("files"),
value=_value(act["option"]["value"]), value=file_url_value,
whitespace_after_arg=cst.ParenthesizedWhitespace( whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True, indent=True,
last_line=cst.SimpleWhitespace(INDENT), last_line=cst.SimpleWhitespace(INDENT),
), ),
), )
)
args.append(
cst.Arg(
keyword=cst.Name("ai_infer"),
value=cst.Name("True"),
whitespace_after_arg=cst.ParenthesizedWhitespace(
indent=True,
last_line=cst.SimpleWhitespace(INDENT),
),
)
) )
elif method == "wait": elif method == "wait":
args.append( args.append(

View File

@@ -15,6 +15,7 @@ LOG = structlog.get_logger(__name__)
# Initialize prompt engine # Initialize prompt engine
prompt_engine = PromptEngine("skyvern") prompt_engine = PromptEngine("skyvern")
CUSTOM_FIELD_ACTIONS = [ActionType.INPUT_TEXT, ActionType.UPLOAD_FILE, ActionType.SELECT_OPTION]
class GeneratedFieldMapping(BaseModel): class GeneratedFieldMapping(BaseModel):
@@ -39,34 +40,45 @@ async def generate_workflow_parameters_schema(
- field_mappings: Dictionary mapping action indices to field names for hydration - field_mappings: Dictionary mapping action indices to field names for hydration
""" """
# Extract all input_text actions # Extract all input_text actions
input_actions = [] custom_field_actions = []
action_index_map = {} action_index_map = {}
action_counter = 1 action_counter = 1
for task_id, actions in actions_by_task.items(): for task_id, actions in actions_by_task.items():
for action in actions: for action in actions:
if action.get("action_type") == ActionType.INPUT_TEXT: action_type = action.get("action_type", "")
input_actions.append( if action_type not in CUSTOM_FIELD_ACTIONS:
{ continue
"text": action.get("text", ""),
"intention": action.get("intention", ""), value = ""
"task_id": task_id, if action_type == ActionType.INPUT_TEXT:
"action_id": action.get("action_id", ""), value = action.get("text", "")
} elif action_type == ActionType.UPLOAD_FILE:
) value = action.get("file_url", "")
action_index_map[f"action_index_{action_counter}"] = { elif action_type == ActionType.SELECT_OPTION:
value = action.get("option", "")
custom_field_actions.append(
{
"action_type": action_type,
"value": value,
"intention": action.get("intention", ""),
"task_id": task_id, "task_id": task_id,
"action_id": action.get("action_id", ""), "action_id": action.get("action_id", ""),
} }
action_counter += 1 )
action_index_map[f"action_index_{action_counter}"] = {
"task_id": task_id,
"action_id": action.get("action_id", ""),
}
action_counter += 1
if not input_actions: if not custom_field_actions:
LOG.warning("No input_text actions found in workflow run") LOG.warning("No field_name_actions found in workflow run")
return _generate_empty_schema(), {} return _generate_empty_schema(), {}
# Generate field names using LLM # Generate field names using LLM
try: try:
field_mapping = await _generate_field_names_with_llm(input_actions) field_mapping = await _generate_field_names_with_llm(custom_field_actions)
# Generate the Pydantic schema code # Generate the Pydantic schema code
schema_code = _generate_pydantic_schema(field_mapping.schema_fields) schema_code = _generate_pydantic_schema(field_mapping.schema_fields)
@@ -86,7 +98,7 @@ async def generate_workflow_parameters_schema(
return _generate_empty_schema(), {} return _generate_empty_schema(), {}
async def _generate_field_names_with_llm(input_actions: List[Dict[str, Any]]) -> GeneratedFieldMapping: async def _generate_field_names_with_llm(custom_field_actions: List[Dict[str, Any]]) -> GeneratedFieldMapping:
""" """
Use LLM to generate field names from input actions. Use LLM to generate field names from input actions.
@@ -96,7 +108,9 @@ async def _generate_field_names_with_llm(input_actions: List[Dict[str, Any]]) ->
Returns: Returns:
GeneratedFieldMapping with field mappings and schema definitions GeneratedFieldMapping with field mappings and schema definitions
""" """
prompt = prompt_engine.load_prompt(template="generate-workflow-parameters", input_actions=input_actions) prompt = prompt_engine.load_prompt(
template="generate-workflow-parameters", custom_field_actions=custom_field_actions
)
response = await app.LLM_API_HANDLER(prompt=prompt, prompt_name="generate-workflow-parameters") response = await app.LLM_API_HANDLER(prompt=prompt, prompt_name="generate-workflow-parameters")
@@ -166,22 +180,22 @@ def hydrate_input_text_actions_with_field_names(
for action in actions: for action in actions:
action_copy = action.copy() action_copy = action.copy()
if action.get("action_type") == ActionType.INPUT_TEXT: if action.get("action_type") in CUSTOM_FIELD_ACTIONS:
action_id = action.get("action_id", "") action_id = action.get("action_id", "")
mapping_key = f"{task_id}:{action_id}" mapping_key = f"{task_id}:{action_id}"
if mapping_key in field_mappings: if mapping_key in field_mappings:
action_copy["field_name"] = field_mappings[mapping_key] action_copy["field_name"] = field_mappings[mapping_key]
else: # else:
# Fallback field name if mapping not found # # Fallback field name if mapping not found
intention = action.get("intention", "") # intention = action.get("intention", "")
if intention: # if intention:
# Simple field name generation from intention # # Simple field name generation from intention
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "") # field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
field_name = "".join(c for c in field_name if c.isalnum() or c == "_") # field_name = "".join(c for c in field_name if c.isalnum() or c == "_")
action_copy["field_name"] = field_name or "unknown_field" # action_copy["field_name"] = field_name or "unknown_field"
else: # else:
action_copy["field_name"] = "unknown_field" # action_copy["field_name"] = "unknown_field"
updated_actions.append(action_copy) updated_actions.append(action_copy)

View File

@@ -23,11 +23,15 @@ from skyvern.forge.sdk.core import skyvern_context
from skyvern.utils.prompt_engine import load_prompt_with_elements from skyvern.utils.prompt_engine import load_prompt_with_elements
from skyvern.webeye.actions import handler_utils from skyvern.webeye.actions import handler_utils
from skyvern.webeye.actions.action_types import ActionType from skyvern.webeye.actions.action_types import ActionType
from skyvern.webeye.actions.actions import Action, ActionStatus, ExtractAction, SelectOption from skyvern.webeye.actions.actions import Action, ActionStatus, ExtractAction, InputTextAction, SelectOption
from skyvern.webeye.actions.handler import handle_input_text_action, handle_select_option_action
from skyvern.webeye.actions.parse_actions import parse_actions
from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.browser_factory import BrowserState
from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website
LOG = structlog.get_logger() LOG = structlog.get_logger()
SELECT_OPTION_GOAL = """- The intention to select an option: {intention}.
- The overall goal that the user wants to achieve: {prompt}."""
class Driver(StrEnum): class Driver(StrEnum):
@@ -52,6 +56,12 @@ class ActionCall:
error: Exception | None = None # populated if failed error: Exception | None = None # populated if failed
async def _get_element_id_by_xpath(xpath: str, page: Page) -> str | None:
    """Look up the ``unique_id`` attribute of the element addressed by ``xpath``.

    Args:
        xpath: XPath expression; it is passed to ``page.locator`` with an
            ``xpath=`` prefix.
        page: Page object on which the locator is created.

    Returns:
        Whatever ``get_attribute("unique_id")`` yields for the located element
        (annotated as ``str | None``).
    """
    return await page.locator(f"xpath={xpath}").get_attribute("unique_id")
class SkyvernPage: class SkyvernPage:
""" """
A minimal adapter around the chosen driver that: A minimal adapter around the chosen driver that:
@@ -208,17 +218,20 @@ class SkyvernPage:
# Create action record. TODO: store more action fields # Create action record. TODO: store more action fields
kwargs = kwargs or {} kwargs = kwargs or {}
# we're using "value" instead of "text" for input text actions interface # we're using "value" instead of "text" for input text actions interface
text = kwargs.get("value", "") text = None
option_value = kwargs.get("option") select_option = None
select_option = SelectOption(value=option_value) if option_value else None
response: str | None = kwargs.get("response") response: str | None = kwargs.get("response")
file_url = kwargs.get("file_url")
if not response: if not response:
if action_type == ActionType.INPUT_TEXT: if action_type == ActionType.INPUT_TEXT:
text = str(call_result) text = str(call_result)
response = text response = text
elif action_type == ActionType.SELECT_OPTION: elif action_type == ActionType.SELECT_OPTION:
if select_option: option_value = str(call_result) or ""
response = select_option.value select_option = SelectOption(value=option_value)
response = option_value
elif action_type == ActionType.UPLOAD_FILE:
file_url = str(call_result)
action = Action( action = Action(
element_id="", element_id="",
@@ -234,6 +247,7 @@ class SkyvernPage:
reasoning=f"Auto-generated action for {action_type.value}", reasoning=f"Auto-generated action for {action_type.value}",
text=text, text=text,
option=select_option, option=select_option,
file_url=file_url,
response=response, response=response,
created_by="script", created_by="script",
) )
@@ -283,7 +297,8 @@ class SkyvernPage:
if screenshot: if screenshot:
# Create a minimal Step object for artifact creation # Create a minimal Step object for artifact creation
step = await app.DATABASE.get_step( step = await app.DATABASE.get_step(
context.task_id, context.step_id, organization_id=context.organization_id context.step_id,
organization_id=context.organization_id,
) )
if not step: if not step:
return return
@@ -415,17 +430,24 @@ class SkyvernPage:
context = skyvern_context.current() context = skyvern_context.current()
value = value or "" value = value or ""
transformed_value = value transformed_value = value
element_id: str | None = None
organization_id = context.organization_id if context else None
task_id = context.task_id if context else None
step_id = context.step_id if context else None
workflow_run_id = context.workflow_run_id if context else None
task = await app.DATABASE.get_task(task_id, organization_id) if task_id and organization_id else None
step = await app.DATABASE.get_step(step_id, organization_id) if step_id and organization_id else None
if ai_infer and intention: if ai_infer and intention:
try: try:
prompt = context.prompt if context else None prompt = context.prompt if context else None
# Build the element tree of the current page for the prompt # Build the element tree of the current page for the prompt
# clean up empty data values # clean up empty data values
data = {k: v for k, v in data.items() if v} if isinstance(data, dict) else (data or "") data = {k: v for k, v in data.items() if v} if isinstance(data, dict) else (data or "")
if (totp_identifier or totp_url) and context and context.organization_id and context.task_id: if (totp_identifier or totp_url) and context and organization_id and task_id:
verification_code = await poll_verification_code( verification_code = await poll_verification_code(
organization_id=context.organization_id, organization_id=organization_id,
task_id=context.task_id, task_id=task_id,
workflow_run_id=context.workflow_run_id, workflow_run_id=workflow_run_id,
totp_identifier=totp_identifier, totp_identifier=totp_identifier,
totp_verification_url=totp_url, totp_verification_url=totp_url,
) )
@@ -439,6 +461,10 @@ class SkyvernPage:
else: else:
data = {SPECIAL_FIELD_VERIFICATION_CODE: verification_code} data = {SPECIAL_FIELD_VERIFICATION_CODE: verification_code}
refreshed_page = await self.scraped_page.generate_scraped_page_without_screenshots()
self.scraped_page = refreshed_page
# get the element_id by the xpath
element_id = await _get_element_id_by_xpath(xpath, self.page)
payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "") payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
script_generation_input_text_prompt = prompt_engine.load_prompt( script_generation_input_text_prompt = prompt_engine.load_prompt(
template="script-generation-input-text-generatiion", template="script-generation-input-text-generatiion",
@@ -449,7 +475,7 @@ class SkyvernPage:
json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER( json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER(
prompt=script_generation_input_text_prompt, prompt=script_generation_input_text_prompt,
prompt_name="script-generation-input-text-generatiion", prompt_name="script-generation-input-text-generatiion",
organization_id=context.organization_id if context else None, organization_id=organization_id,
) )
value = json_response.get("answer", value) value = json_response.get("answer", value)
except Exception: except Exception:
@@ -458,39 +484,119 @@ class SkyvernPage:
if context and context.workflow_run_id: if context and context.workflow_run_id:
transformed_value = await _get_actual_value_of_parameter_if_secret(context.workflow_run_id, value) transformed_value = await _get_actual_value_of_parameter_if_secret(context.workflow_run_id, value)
locator = self.page.locator(f"xpath={xpath}") if element_id and organization_id and task and step:
await handler_utils.input_sequentially(locator, transformed_value, timeout=timeout) action = InputTextAction(
element_id=element_id,
text=value,
status=ActionStatus.pending,
organization_id=organization_id,
workflow_run_id=workflow_run_id,
task_id=task_id,
step_id=context.step_id if context else None,
reasoning=intention,
intention=intention,
response=value,
)
await handle_input_text_action(action, self.page, self.scraped_page, task, step)
else:
locator = self.page.locator(f"xpath={xpath}")
await handler_utils.input_sequentially(locator, transformed_value, timeout=timeout)
return value return value
@action_wrap(ActionType.UPLOAD_FILE)
async def upload_file(
    self,
    xpath: str,
    files: str,
    ai_infer: bool = False,
    intention: str | None = None,
    data: str | dict[str, Any] | None = None,
) -> str:
    """Download a file and set it on the file input located by ``xpath``.

    When ``ai_infer`` is true and an ``intention`` is given, the
    ``script-generation-file-url-generation`` prompt is sent to the single-input
    agent LLM and the ``answer`` field of its JSON response replaces ``files``.
    Any failure on that path is logged and the originally supplied ``files``
    value is used instead.

    Args:
        xpath: XPath of the file input element.
        files: URL of the file to upload; passed to ``download_file``.
        ai_infer: Whether to ask the LLM for the file URL.
        intention: Intention text for this upload; required for the LLM path.
        data: Extra context for the prompt. Falsy entries of a dict are dropped
            and dicts are JSON-encoded before being rendered.

    Returns:
        The file URL that was actually downloaded and uploaded.
    """
    if ai_infer and intention:
        try:
            context = skyvern_context.current()
            prompt = context.prompt if context else None
            data = {k: v for k, v in data.items() if v} if isinstance(data, dict) else (data or "")
            payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
            script_generation_file_url_prompt = prompt_engine.load_prompt(
                template="script-generation-file-url-generation",
                intention=intention,
                data=payload_str,
                goal=prompt,
            )
            json_response = await app.SINGLE_INPUT_AGENT_LLM_API_HANDLER(
                prompt=script_generation_file_url_prompt,
                prompt_name="script-generation-file-url-generation",
                organization_id=context.organization_id if context else None,
            )
            # Keep the recorded URL when the LLM answer is missing, null or empty,
            # so download_file never receives a non-URL value from this path.
            files = json_response.get("answer") or files
        except Exception:
            LOG.exception(f"Failed to adapt value for upload file action on xpath={xpath}, file={files}")

    file_path = await download_file(files)
    locator = self.page.locator(f"xpath={xpath}")
    await locator.set_input_files(file_path)
    return files
@action_wrap(ActionType.SELECT_OPTION)
async def select_option(
    self,
    xpath: str,
    value: str,
    ai_infer: bool = False,
    intention: str | None = None,
    data: str | dict[str, Any] | None = None,
    timeout: float = settings.BROWSER_ACTION_TIMEOUT_MS,
) -> str:
    """Select an option on the element at ``xpath`` and return the option value.

    AI path: taken when ``ai_infer`` is true, ``intention`` is given, and the
    current Skyvern context resolves to both a task and a step. The page is
    re-scraped, the ``single-select-action`` prompt is sent to the select agent
    LLM, and the first parsed action is executed with
    ``handle_select_option_action``. Failures on this path are logged, not
    raised.

    Direct path: taken in every other case. ``value`` is selected through the
    xpath locator.

    Args:
        xpath: XPath of the select element (used by the direct path).
        value: Recorded option value; a falsy value is treated as ``""``.
        ai_infer: Whether to let the LLM choose the option.
        intention: Intention text for this selection; required for the AI path.
        data: Extra context for the prompt. Falsy entries of a dict are dropped
            and dicts are JSON-encoded before being rendered.
        timeout: Timeout passed to ``locator.select_option`` on the direct path.

    Returns:
        The value (or, failing that, the label) of the option chosen by the LLM
        once an action was parsed; otherwise ``value`` (or ``""``).
    """
    option_value = value or ""
    context = skyvern_context.current()

    # Bind task/step unconditionally so the path decision below never reads
    # unbound names and always ends in exactly one of the two paths.
    task = None
    step = None
    if context and context.task_id and context.step_id and context.organization_id:
        task = await app.DATABASE.get_task(context.task_id, organization_id=context.organization_id)
        step = await app.DATABASE.get_step(context.step_id, organization_id=context.organization_id)

    if not (ai_infer and intention and context and task and step):
        # Direct path: no AI requested, or not enough context to run it.
        locator = self.page.locator(f"xpath={xpath}")
        await locator.select_option(option_value, timeout=timeout)
        return option_value

    try:
        prompt = context.prompt
        data = {k: v for k, v in data.items() if v} if isinstance(data, dict) else (data or "")
        payload_str = json.dumps(data) if isinstance(data, (dict, list)) else (data or "")
        # Re-scrape so the element tree in the prompt reflects the current page.
        refreshed_page = await self.scraped_page.generate_scraped_page_without_screenshots()
        self.scraped_page = refreshed_page
        element_tree = refreshed_page.build_element_tree()
        merged_goal = SELECT_OPTION_GOAL.format(intention=intention, prompt=prompt)
        single_select_prompt = prompt_engine.load_prompt(
            template="single-select-action",
            navigation_payload_str=payload_str,
            navigation_goal=merged_goal,
            current_url=self.page.url,
            elements=element_tree,
            local_datetime=datetime.now(context.tz_info or datetime.now().astimezone().tzinfo).isoformat(),
        )
        json_response = await app.SELECT_AGENT_LLM_API_HANDLER(
            prompt=single_select_prompt,
            prompt_name="single-select-action",
            organization_id=context.organization_id,
        )
        actions = parse_actions(task, step.step_id, step.order, self.scraped_page, json_response["actions"])
        if actions:
            action = actions[0]
            if not action.option:
                raise ValueError("SelectOptionAction requires an 'option' field")
            option_value = action.option.value or action.option.label or ""
            await handle_select_option_action(
                action=action,
                page=self.page,
                scraped_page=self.scraped_page,
                task=task,
                step=step,
            )
        else:
            # No exception is active here, so log an error rather than
            # LOG.exception (which would attach an empty traceback).
            LOG.error(f"Failed to parse actions for select option action on xpath={xpath}, value={value}")
    except Exception:
        # Best effort: the AI path does not retry with the direct locator.
        LOG.exception(f"Failed to adapt value for select option action on xpath={xpath}, value={value}")
    return option_value
@action_wrap(ActionType.WAIT) @action_wrap(ActionType.WAIT)
async def wait( async def wait(
@@ -556,7 +662,8 @@ class SkyvernPage:
step = None step = None
if context and context.organization_id and context.task_id and context.step_id: if context and context.organization_id and context.task_id and context.step_id:
step = await app.DATABASE.get_step( step = await app.DATABASE.get_step(
task_id=context.task_id, step_id=context.step_id, organization_id=context.organization_id step_id=context.step_id,
organization_id=context.organization_id,
) )
result = await app.EXTRACTION_LLM_API_HANDLER( result = await app.EXTRACTION_LLM_API_HANDLER(

View File

@@ -80,9 +80,8 @@ from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock,
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
from skyvern.schemas.runs import CUA_ENGINES, RunEngine from skyvern.schemas.runs import CUA_ENGINES, RunEngine
from skyvern.schemas.steps import AgentStepOutput from skyvern.schemas.steps import AgentStepOutput
from skyvern.services import run_service from skyvern.services import run_service, service_utils
from skyvern.services.action_service import get_action_history from skyvern.services.action_service import get_action_history
from skyvern.services.task_v1_service import is_cua_task
from skyvern.utils.image_resizer import Resolution from skyvern.utils.image_resizer import Resolution
from skyvern.utils.prompt_engine import MaxStepsReasonResponse, load_prompt_with_elements from skyvern.utils.prompt_engine import MaxStepsReasonResponse, load_prompt_with_elements
from skyvern.webeye.actions.action_types import ActionType from skyvern.webeye.actions.action_types import ActionType
@@ -1669,7 +1668,7 @@ class ForgeAgent:
) )
scroll = True scroll = True
llm_key_override = task.llm_key llm_key_override = task.llm_key
if await is_cua_task(task=task): if await service_utils.is_cua_task(task=task):
scroll = False scroll = False
llm_key_override = None llm_key_override = None
@@ -2709,7 +2708,7 @@ class ForgeAgent:
steps_results.append(step_result) steps_results.append(step_result)
scroll = True scroll = True
if await is_cua_task(task=task): if await service_utils.is_cua_task(task=task):
scroll = False scroll = False
screenshots: list[bytes] = [] screenshots: list[bytes] = []
@@ -2971,7 +2970,7 @@ class ForgeAgent:
verification_code_check=False, verification_code_check=False,
) )
llm_key_override = task.llm_key llm_key_override = task.llm_key
if await is_cua_task(task=task): if await service_utils.is_cua_task(task=task):
llm_key_override = None llm_key_override = None
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler( llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
llm_key_override, default=app.LLM_API_HANDLER llm_key_override, default=app.LLM_API_HANDLER

View File

@@ -1,18 +1,19 @@
You are an expert at analyzing user interface automation actions and generating meaningful field names for data structures. You are an expert at analyzing user interface automation actions and generating meaningful field names for data structures.
Given a list of input_text actions with their intentions and text values, generate appropriate field names for a Pydantic BaseModel class called "GeneratedWorkflowParameters". Given a list of input_text, upload_file and select_option actions with their intentions and values, generate appropriate field names for a Pydantic BaseModel class called "GeneratedWorkflowParameters".
## Rules: ## Rules:
1. Field names should be valid Python identifiers (snake_case, no spaces, no special characters except underscore) 1. Field names should be valid Python identifiers (snake_case, no spaces, no special characters except underscore)
2. Field names should be descriptive and based on the intention of the action 2. Field names should be descriptive and based on the intention of the action
3. If multiple actions input the same text value, they should map to the same field name 3. If multiple actions use the same text value, they should map to the same field name
4. Field names should be concise but clear about what data they represent 4. Field names should be concise but clear about what data they represent
5. Avoid generic names like "field1", "input1" - use meaningful names based on the intention 5. Avoid generic names like "field1", "input1" - use meaningful names based on the intention
## Input Actions: ## Actions:
{% for action in input_actions %} {% for action in custom_field_actions %}
Action {{ loop.index }}: Action {{ loop.index }}:
- Text: "{{ action.text }}" - Action type: "{{ action.action_type }}"
- Value: "{{ action.value }}"
- Intention: "{{ action.intention }}" - Intention: "{{ action.intention }}"
{% endfor %} {% endfor %}

View File

@@ -0,0 +1,17 @@
# Goal
You are an expert in uploading files on a webpage. Help the user figure out the specific file url to use to upload a file.
# Provided information:{% if goal %}
- User's overall goal: {{ goal }}{% endif %}
- Context and details: {{ data }}
- The question or the intention for this file upload action: {{ intention }}
# Output
- Your answer should be a valid url to a file.
- YOUR RESPONSE HAS TO BE IN JSON FORMAT. DO NOT RETURN ANYTHING ELSE.
- DO NOT INCLUDE ANY UNRELATED INFORMATION OR UNNECESSARY DETAILS IN YOUR ANSWER.
EXAMPLE RESPONSE FORMAT:
{
"answer": "string"
}

View File

@@ -326,7 +326,7 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True) LOG.error("UnexpectedError", exc_info=True)
raise raise
async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None: async def get_step(self, step_id: str, organization_id: str | None = None) -> Step | None:
try: try:
async with self.Session() as session: async with self.Session() as session:
if step := ( if step := (
@@ -588,7 +588,7 @@ class AgentDB:
step.cached_token_count = incremental_cached_tokens + (step.cached_token_count or 0) step.cached_token_count = incremental_cached_tokens + (step.cached_token_count or 0)
await session.commit() await session.commit()
updated_step = await self.get_step(task_id, step_id, organization_id) updated_step = await self.get_step(step_id, organization_id)
if not updated_step: if not updated_step:
raise NotFoundError("Step not found") raise NotFoundError("Step not found")
return updated_step return updated_step

View File

@@ -0,0 +1,28 @@
from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.schemas.runs import CUA_ENGINES, CUA_RUN_TYPES
async def is_cua_task(
    *,
    task: Task,
) -> bool:
    """Return True if the run, engine, or task indicates a CUA task.

    Two lookups are made, in order:

    1. If the task has a ``workflow_run_id``, its workflow run block is fetched
       by task id; a block engine that is set and listed in ``CUA_ENGINES``
       makes this a CUA task.
    2. Otherwise, or when the block engine did not match, the run is fetched
       with the task id as run id; a ``task_run_type`` listed in
       ``CUA_RUN_TYPES`` makes this a CUA task.

    Args:
        task: Task to classify (keyword-only).

    Returns:
        True when either check matches, False otherwise.
    """
    if task.workflow_run_id:
        # Task-based block: the engine stored on the block run decides.
        run_block = await app.DATABASE.get_workflow_run_block_by_task_id(
            task_id=task.task_id,
            organization_id=task.organization_id,
        )
        if run_block.engine is not None and run_block.engine in CUA_ENGINES:
            return True

    task_run = await app.DATABASE.get_run(
        run_id=task.task_id,
        organization_id=task.organization_id,
    )
    return bool(task_run and task_run.task_run_type in CUA_RUN_TYPES)

View File

@@ -14,7 +14,7 @@ from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, TaskStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, TaskStatus
from skyvern.schemas.runs import CUA_ENGINES, CUA_RUN_TYPES, RunEngine, RunType from skyvern.schemas.runs import RunEngine, RunType
LOG = structlog.get_logger() LOG = structlog.get_logger()
@@ -150,28 +150,3 @@ async def get_task_v1_response(task_id: str, organization_id: str | None = None)
return await app.agent.build_task_response( return await app.agent.build_task_response(
task=task_obj, last_step=latest_step, failure_reason=failure_reason, need_browser_log=True task=task_obj, last_step=latest_step, failure_reason=failure_reason, need_browser_log=True
) )
async def is_cua_task(
*,
task: Task,
) -> bool:
"""Return True if the run, engine, or task indicates a CUA task."""
if task.workflow_run_id:
# it's a task based block, should look up the block run to see if it's a CUA task
block = await app.DATABASE.get_workflow_run_block_by_task_id(
task_id=task.task_id,
organization_id=task.organization_id,
)
if block.engine is not None and block.engine in CUA_ENGINES:
return True
run = await app.DATABASE.get_run(
run_id=task.task_id,
organization_id=task.organization_id,
)
if run and run.task_run_type in CUA_RUN_TYPES:
return True
return False

View File

@@ -69,8 +69,8 @@ from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.services.bitwarden import BitwardenConstants from skyvern.forge.sdk.services.bitwarden import BitwardenConstants
from skyvern.forge.sdk.services.credentials import AzureVaultConstants, OnePasswordConstants from skyvern.forge.sdk.services.credentials import AzureVaultConstants, OnePasswordConstants
from skyvern.forge.sdk.trace import TraceManager from skyvern.forge.sdk.trace import TraceManager
from skyvern.services import service_utils
from skyvern.services.action_service import get_action_history from skyvern.services.action_service import get_action_history
from skyvern.services.task_v1_service import is_cua_task
from skyvern.utils.prompt_engine import ( from skyvern.utils.prompt_engine import (
CheckDateFormatResponse, CheckDateFormatResponse,
CheckPhoneNumberFormatResponse, CheckPhoneNumberFormatResponse,
@@ -3599,7 +3599,7 @@ async def extract_information_for_navigation_goal(
) )
llm_key_override = task.llm_key llm_key_override = task.llm_key
if await is_cua_task(task=task): if await service_utils.is_cua_task(task=task):
# CUA tasks should use the default data extraction llm key # CUA tasks should use the default data extraction llm key
llm_key_override = None llm_key_override = None