disable complete verification when CUA engine (#2728)

Co-authored-by: lawyzheng <lawyzheng1106@gmail.com>
This commit is contained in:
Shuchang Zheng
2025-06-17 00:25:58 -07:00
committed by GitHub
parent b241185aae
commit f1bc1a03db
9 changed files with 107 additions and 13 deletions

View File

@@ -73,8 +73,9 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, Tas
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
from skyvern.schemas.runs import CUA_ENGINES, CUA_RUN_TYPES, RunEngine
from skyvern.schemas.runs import CUA_ENGINES, RunEngine
from skyvern.services import run_service
from skyvern.services.task_v1_service import is_cua_task
from skyvern.utils.image_resizer import Resolution
from skyvern.utils.prompt_engine import load_prompt_with_elements
from skyvern.webeye.actions.action_types import ActionType
@@ -268,6 +269,12 @@ class ForgeAgent:
cua_response: OpenAIResponse | None = None,
llm_caller: LLMCaller | None = None,
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
# do not need to do complete verification when it's a CUA task
# 1. CUA executes only one action step by step -- it's pretty less likely to have a hallucination for completion or forget to return a complete
# 2. It will significantly slow down CUA tasks
if engine in CUA_ENGINES:
complete_verification = False
workflow_run: WorkflowRun | None = None
if task.workflow_run_id:
workflow_run = await app.DATABASE.get_workflow_run(
@@ -1575,10 +1582,9 @@ class ForgeAgent:
step_id=step.step_id,
workflow_run_id=task.workflow_run_id,
)
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
scroll = True
llm_key_override = task.llm_key
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
if await is_cua_task(task=task):
scroll = False
llm_key_override = None
@@ -2628,9 +2634,8 @@ class ForgeAgent:
step_result["actions_result"] = action_result_summary
steps_results.append(step_result)
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
scroll = True
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
if await is_cua_task(task=task):
scroll = False
screenshots: list[bytes] = []
@@ -2880,8 +2885,7 @@ class ForgeAgent:
expire_verification_code=True,
)
llm_key_override = task.llm_key
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
if await is_cua_task(task=task):
llm_key_override = None
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
llm_key_override, default=app.LLM_API_HANDLER

View File

@@ -95,7 +95,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunStatus,
WorkflowStatus,
)
from skyvern.schemas.runs import ProxyLocation, RunType
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunType
from skyvern.webeye.actions.actions import Action
from skyvern.webeye.actions.models import AgentStepOutput
@@ -2707,6 +2707,7 @@ class AgentDB:
status: BlockStatus = BlockStatus.running,
output: dict | list | str | None = None,
continue_on_failure: bool = False,
engine: RunEngine | None = None,
) -> WorkflowRunBlock:
async with self.Session() as session:
new_workflow_run_block = WorkflowRunBlockModel(
@@ -2719,6 +2720,7 @@ class AgentDB:
status=status,
output=output,
continue_on_failure=continue_on_failure,
engine=engine,
)
session.add(new_workflow_run_block)
await session.commit()
@@ -2759,6 +2761,7 @@ class AgentDB:
wait_sec: int | None = None,
description: str | None = None,
block_workflow_run_id: str | None = None,
engine: str | None = None,
) -> WorkflowRunBlock:
async with self.Session() as session:
workflow_run_block = (
@@ -2799,6 +2802,8 @@ class AgentDB:
workflow_run_block.description = description
if block_workflow_run_id:
workflow_run_block.block_workflow_run_id = block_workflow_run_id
if engine:
workflow_run_block.engine = engine
await session.commit()
await session.refresh(workflow_run_block)
else:
@@ -2830,6 +2835,25 @@ class AgentDB:
return convert_to_workflow_run_block(workflow_run_block, task=task)
raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found")
async def get_workflow_run_block_by_task_id(
self,
task_id: str,
organization_id: str | None = None,
) -> WorkflowRunBlock:
async with self.Session() as session:
workflow_run_block = (
await session.scalars(
select(WorkflowRunBlockModel).filter_by(task_id=task_id).filter_by(organization_id=organization_id)
)
).first()
if workflow_run_block:
task = None
task_id = workflow_run_block.task_id
if task_id:
task = await self.get_task(task_id, organization_id=organization_id)
return convert_to_workflow_run_block(workflow_run_block, task=task)
raise NotFoundError(f"WorkflowRunBlock not found by {task_id}")
async def get_workflow_run_blocks(
self,
workflow_run_id: str,

View File

@@ -573,13 +573,14 @@ class WorkflowRunBlockModel(Base):
parent_workflow_run_block_id = Column(String, nullable=True)
organization_id = Column(String, nullable=True)
description = Column(String, nullable=True)
task_id = Column(String, nullable=True)
task_id = Column(String, index=True, nullable=True)
label = Column(String, nullable=True)
block_type = Column(String, nullable=False)
status = Column(String, nullable=False)
output = Column(JSON, nullable=True)
continue_on_failure = Column(Boolean, nullable=False, default=False)
failure_reason = Column(String, nullable=True)
engine = Column(String, nullable=True)
# for loop block
loop_values = Column(JSON, nullable=True)

View File

@@ -460,6 +460,7 @@ def convert_to_workflow_run_block(
output=workflow_run_block_model.output,
continue_on_failure=workflow_run_block_model.continue_on_failure,
failure_reason=workflow_run_block_model.failure_reason,
engine=workflow_run_block_model.engine,
task_id=workflow_run_block_model.task_id,
loop_values=workflow_run_block_model.loop_values,
current_value=workflow_run_block_model.current_value,

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel
from skyvern.forge.sdk.schemas.task_v2 import Thought
from skyvern.forge.sdk.workflow.models.block import BlockType
from skyvern.schemas.runs import RunEngine
from skyvern.webeye.actions.actions import Action
@@ -24,6 +25,7 @@ class WorkflowRunBlock(BaseModel):
output: dict | list | str | None = None
continue_on_failure: bool = False
failure_reason: str | None = None
engine: RunEngine | None = None
task_id: str | None = None
url: str | None = None
navigation_goal: str | None = None

View File

@@ -288,7 +288,11 @@ class Block(BaseModel, abc.ABC):
**kwargs: dict,
) -> BlockResult:
workflow_run_block_id = None
engine: RunEngine | None = None
try:
if isinstance(self, BaseTaskBlock):
engine = self.engine
workflow_run_block = await app.DATABASE.create_workflow_run_block(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
@@ -296,6 +300,7 @@ class Block(BaseModel, abc.ABC):
label=self.label,
block_type=self.block_type,
continue_on_failure=self.continue_on_failure,
engine=engine,
)
workflow_run_block_id = workflow_run_block.workflow_run_block_id