"""add_engine_to_workflow_run_block

Revision ID: 2be3e0ba85ff
Revises: 2c6b27e8e961
Create Date: 2025-06-17 07:23:13.753617+00:00

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "2be3e0ba85ff"
down_revision: Union[str, None] = "2c6b27e8e961"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Add the nullable ``engine`` column and a non-unique index on ``task_id``."""
    engine_column = sa.Column("engine", sa.String(), nullable=True)
    op.add_column("workflow_run_blocks", engine_column)

    task_id_index = op.f("ix_workflow_run_blocks_task_id")
    op.create_index(task_id_index, "workflow_run_blocks", ["task_id"], unique=False)


def downgrade() -> None:
    """Undo ``upgrade``: drop the ``task_id`` index first, then the ``engine`` column."""
    task_id_index = op.f("ix_workflow_run_blocks_task_id")
    op.drop_index(task_id_index, table_name="workflow_run_blocks")
    op.drop_column("workflow_run_blocks", "engine")
It will significantly slow down CUA tasks + if engine in CUA_ENGINES: + complete_verification = False + workflow_run: WorkflowRun | None = None if task.workflow_run_id: workflow_run = await app.DATABASE.get_workflow_run( @@ -1575,10 +1582,9 @@ class ForgeAgent: step_id=step.step_id, workflow_run_id=task.workflow_run_id, ) - run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id) scroll = True llm_key_override = task.llm_key - if run_obj and run_obj.task_run_type in CUA_RUN_TYPES: + if await is_cua_task(task=task): scroll = False llm_key_override = None @@ -2628,9 +2634,8 @@ class ForgeAgent: step_result["actions_result"] = action_result_summary steps_results.append(step_result) - run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id) scroll = True - if run_obj and run_obj.task_run_type in CUA_RUN_TYPES: + if await is_cua_task(task=task): scroll = False screenshots: list[bytes] = [] @@ -2880,8 +2885,7 @@ class ForgeAgent: expire_verification_code=True, ) llm_key_override = task.llm_key - run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id) - if run_obj and run_obj.task_run_type in CUA_RUN_TYPES: + if await is_cua_task(task=task): llm_key_override = None llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler( llm_key_override, default=app.LLM_API_HANDLER diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index c384bc51..6e32c262 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -95,7 +95,7 @@ from skyvern.forge.sdk.workflow.models.workflow import ( WorkflowRunStatus, WorkflowStatus, ) -from skyvern.schemas.runs import ProxyLocation, RunType +from skyvern.schemas.runs import ProxyLocation, RunEngine, RunType from skyvern.webeye.actions.actions import Action from skyvern.webeye.actions.models import AgentStepOutput @@ -2707,6 +2707,7 @@ class AgentDB: status: BlockStatus = 
async def get_workflow_run_block_by_task_id(
    self,
    task_id: str,
    organization_id: str | None = None,
) -> WorkflowRunBlock:
    """Fetch the first workflow run block row linked to ``task_id``.

    The lookup filters on both ``task_id`` and ``organization_id``. When a row
    is found and carries a task id, the task is loaded through
    ``self.get_task`` and attached during conversion.

    Raises:
        NotFoundError: no row matches the given task id / organization id.
    """
    query = select(WorkflowRunBlockModel).filter_by(
        task_id=task_id,
        organization_id=organization_id,
    )
    async with self.Session() as session:
        block_model = (await session.scalars(query)).first()
        if block_model:
            linked_task = None
            if block_model.task_id:
                linked_task = await self.get_task(block_model.task_id, organization_id=organization_id)
            return convert_to_workflow_run_block(block_model, task=linked_task)
    # Raised after the session context has exited, as in the original flow.
    raise NotFoundError(f"WorkflowRunBlock not found by {task_id}")
a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -573,13 +573,14 @@ class WorkflowRunBlockModel(Base): parent_workflow_run_block_id = Column(String, nullable=True) organization_id = Column(String, nullable=True) description = Column(String, nullable=True) - task_id = Column(String, nullable=True) + task_id = Column(String, index=True, nullable=True) label = Column(String, nullable=True) block_type = Column(String, nullable=False) status = Column(String, nullable=False) output = Column(JSON, nullable=True) continue_on_failure = Column(Boolean, nullable=False, default=False) failure_reason = Column(String, nullable=True) + engine = Column(String, nullable=True) # for loop block loop_values = Column(JSON, nullable=True) diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index 441a07e3..1d51d008 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -460,6 +460,7 @@ def convert_to_workflow_run_block( output=workflow_run_block_model.output, continue_on_failure=workflow_run_block_model.continue_on_failure, failure_reason=workflow_run_block_model.failure_reason, + engine=workflow_run_block_model.engine, task_id=workflow_run_block_model.task_id, loop_values=workflow_run_block_model.loop_values, current_value=workflow_run_block_model.current_value, diff --git a/skyvern/forge/sdk/schemas/workflow_runs.py b/skyvern/forge/sdk/schemas/workflow_runs.py index 4ba1ed2c..6574ffdd 100644 --- a/skyvern/forge/sdk/schemas/workflow_runs.py +++ b/skyvern/forge/sdk/schemas/workflow_runs.py @@ -8,6 +8,7 @@ from pydantic import BaseModel from skyvern.forge.sdk.schemas.task_v2 import Thought from skyvern.forge.sdk.workflow.models.block import BlockType +from skyvern.schemas.runs import RunEngine from skyvern.webeye.actions.actions import Action @@ -24,6 +25,7 @@ class WorkflowRunBlock(BaseModel): output: dict | list | str | None = None continue_on_failure: bool = False failure_reason: str | None = None + engine: 
async def is_cua_task(
    *,
    task: Task,
) -> bool:
    """Return True if the task's workflow block engine or its run type marks it as CUA.

    Two lookups are tried in order:

    1. If the task belongs to a workflow run, the workflow run block linked to
       the task is fetched and its ``engine`` is checked against ``CUA_ENGINES``.
    2. Otherwise (or if step 1 did not say "CUA"), the run record for the task
       is fetched and its ``task_run_type`` is checked against ``CUA_RUN_TYPES``.

    A missing workflow run block is not an error here: the block lookup raises
    ``NotFoundError`` when no row is linked to the task, and this helper is
    called from the agent step, summary, and data-extraction paths, so letting
    that propagate would fail the whole step. In that case the function falls
    through to the run-type check.

    Args:
        task: The task to classify. Only ``task_id``, ``organization_id`` and
            ``workflow_run_id`` are read.

    Returns:
        True when either lookup identifies the task as CUA, False otherwise.
    """
    # Local import keeps this block self-contained.
    # NOTE(review): confirm NotFoundError lives in skyvern.forge.sdk.db.exceptions.
    from skyvern.forge.sdk.db.exceptions import NotFoundError

    if task.workflow_run_id:
        # Task-based block: the engine is recorded on the workflow run block.
        try:
            block = await app.DATABASE.get_workflow_run_block_by_task_id(
                task_id=task.task_id,
                organization_id=task.organization_id,
            )
        except NotFoundError:
            # No block row is linked to this task; fall back to the run type.
            LOG.debug(
                "No workflow run block found for task, falling back to run type",
                task_id=task.task_id,
                workflow_run_id=task.workflow_run_id,
            )
        else:
            if block.engine is not None and block.engine in CUA_ENGINES:
                return True

    run = await app.DATABASE.get_run(
        run_id=task.task_id,
        organization_id=task.organization_id,
    )
    return bool(run and run.task_run_type in CUA_RUN_TYPES)