From db68d8a60c00a79b5d984dfc1f3a19fc5e048ee7 Mon Sep 17 00:00:00 2001 From: pedrohsdb Date: Wed, 19 Nov 2025 17:34:08 -0800 Subject: [PATCH] scope termination-aware verification to file download fallback (#4043) --- skyvern/forge/agent.py | 258 +++++++++++++++++++++++++++++++++++------ 1 file changed, 220 insertions(+), 38 deletions(-) diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 50e577d7..2fec70bb 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -501,6 +501,7 @@ class ForgeAgent: task = await self.update_task_errors_from_detailed_output(task, detailed_output) # type: ignore retry = False + download_detected = False if task_block and task_block.complete_on_download and task.workflow_run_id: workflow_download_directory = get_path_for_workflow_download_directory( context.run_id if context and context.run_id else task.workflow_run_id @@ -538,6 +539,7 @@ class ForgeAgent: ) list_files_after = list_files_after + browser_session_downloaded_files_after if len(list_files_after) > len(list_files_before): + download_detected = True files_to_rename = list(set(list_files_after) - set(list_files_before)) for file in files_to_rename: if file.startswith("s3://"): @@ -599,6 +601,28 @@ class ForgeAgent: ) return last_step, detailed_output, None + if ( + task_block + and isinstance(task_block, FileDownloadBlock) + and task_block.complete_on_download + and task.workflow_run_id + and not download_detected + ): + handled, fallback_last_step = await self._handle_file_download_verification_fallback( + organization=organization, + task=task, + step=step, + browser_state=browser_state, + task_block=task_block, + detailed_output=detailed_output, + engine=engine, + api_key=api_key, + close_browser_on_completion=close_browser_on_completion, + browser_session_id=browser_session_id, + ) + if handled and fallback_last_step: + return fallback_last_step, detailed_output, None + # If the step failed, mark the step as failed and retry if step.status == StepStatus.failed: maybe_next_step = await self.handle_failed_step(organization, task, step) @@ -893,6 +917,172 @@ class ForgeAgent: @TraceManager.traced_async( ignore_inputs=["browser_state", "organization", "task_block", "cua_response", "llm_caller"] ) + async def _handle_file_download_verification_fallback( + self, + *, + organization: Organization, + task: Task, + step: Step, + browser_state: BrowserState, + task_block: FileDownloadBlock, + detailed_output: DetailedAgentStepOutput | None, + engine: RunEngine, + api_key: str | None, + close_browser_on_completion: bool, + browser_session_id: str | None, + ) -> tuple[bool, Step | None]: + if detailed_output is None or detailed_output.scraped_page is None: + return False, None + + try: + distinct_id = task.workflow_run_id if task.workflow_run_id else task.task_id + use_termination_prompt = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached( + "USE_TERMINATION_AWARE_COMPLETE_VERIFICATION", + distinct_id, + properties={"organization_id": task.organization_id}, + ) + except Exception as error: # pragma: no cover - defensive logging + LOG.warning( + "Failed to check USE_TERMINATION_AWARE_COMPLETE_VERIFICATION experiment; skipping download fallback verification", + task_id=task.task_id, + workflow_run_id=task.workflow_run_id, + error=str(error), + ) + return False, None + + if not use_termination_prompt: + return False, None + + try: + page = await browser_state.get_working_page() + if page is None: + page = await browser_state.must_get_working_page() + except Exception: + LOG.warning( + "File download fallback verification could not fetch working page, skipping verification", + task_id=task.task_id, + workflow_run_id=task.workflow_run_id, + exc_info=True, + ) + return False, None + + try: + fallback_action = await self.check_user_goal_complete( + page=page, + scraped_page=detailed_output.scraped_page, + task=task, + step=step, + task_block=task_block, + ) + except Exception: + LOG.warning( + "File download fallback verification failed, continuing with standard flow", + task_id=task.task_id, + workflow_run_id=task.workflow_run_id, + exc_info=True, + ) + return False, None + + if fallback_action is None: + LOG.info( + "File download fallback verification completed with continue status", + task_id=task.task_id, + workflow_run_id=task.workflow_run_id, + ) + return False, None + + LOG.info( + "File download fallback verification returned decisive action", + task_id=task.task_id, + workflow_run_id=task.workflow_run_id, + action_type=fallback_action.action_type if isinstance(fallback_action, Action) else "unknown", + ) + + if step.output is None: + step.output = AgentStepOutput(action_results=[], actions_and_results=[], errors=[]) + if step.output.action_results is None: + step.output.action_results = [] + if step.output.actions_and_results is None: + step.output.actions_and_results = [] + if step.output.errors is None: + step.output.errors = [] + if detailed_output.actions_and_results is None: + detailed_output.actions_and_results = [] + + persisted_action = cast(Action, fallback_action) + if isinstance(persisted_action, (CompleteAction, TerminateAction)): + persisted_action.organization_id = task.organization_id + persisted_action.workflow_run_id = task.workflow_run_id + persisted_action.task_id = task.task_id + persisted_action.step_id = step.step_id + persisted_action.step_order = step.order + persisted_action.action_order = len(step.output.actions_and_results) + + action_results = await ActionHandler.handle_action( + detailed_output.scraped_page, + task, + step, + page, + persisted_action, + ) + await self.record_artifacts_after_action(task, step, browser_state, engine) + + step.output.action_results.extend(action_results) + step.output.actions_and_results.append((persisted_action, action_results)) + detailed_output.actions_and_results.append((persisted_action, action_results)) + if isinstance(persisted_action, DecisiveAction) and persisted_action.errors: + step.output.errors.extend(persisted_action.errors) + + if isinstance(persisted_action, TerminateAction): + LOG.warning( + "File download fallback verification determined workflow should terminate", + task_id=task.task_id, + workflow_run_id=task.workflow_run_id, + reasoning=persisted_action.reasoning, + ) + last_step = await self.update_step(step, output=step.output, is_last=True) + task_errors = None + if persisted_action.errors: + task_errors = [error.model_dump() for error in persisted_action.errors] + failure_reason = persisted_action.reasoning + if persisted_action.errors: + failure_reason = "; ".join(error.reasoning for error in persisted_action.errors) + updated_task = await self.update_task( + task, + status=TaskStatus.terminated, + failure_reason=failure_reason, + errors=task_errors, + ) + await self.clean_up_task( + task=updated_task, + last_step=last_step, + api_key=api_key, + close_browser_on_completion=close_browser_on_completion, + browser_session_id=browser_session_id, + ) + return True, last_step + + LOG.info( + "File download fallback verification marked task as complete", + task_id=task.task_id, + workflow_run_id=task.workflow_run_id, + ) + last_step = await self.update_step(step, output=step.output, is_last=True) + extracted_information = await self.get_extracted_information_for_task(task) + updated_task = await self.update_task( + task, + status=TaskStatus.completed, + extracted_information=extracted_information, + ) + await self.clean_up_task( + task=updated_task, + last_step=last_step, + api_key=api_key, + close_browser_on_completion=close_browser_on_completion, + browser_session_id=browser_session_id, + ) + return True, last_step + async def agent_step( self, task: Task, @@ -2024,7 +2214,14 @@ class ForgeAgent: ) async def complete_verify( - self, page: Page, scraped_page: ScrapedPage, task: Task, step: Step, task_block: BaseTaskBlock | None = None + self, + page: Page, + scraped_page: ScrapedPage, + task: Task, + step: Step, + task_block: BaseTaskBlock | None = None, + *, + use_termination_prompt: bool = False, ) -> CompleteVerifyResult: LOG.info( "Checking if user goal is achieved after re-scraping the page", @@ -2042,35 +2239,6 @@ class ForgeAgent: if task.include_action_history_in_verification: actions_and_results_str = await self._get_action_results(task, current_step=step) - # Check if we should use the termination-aware prompt (experiment) - # Only enabled for file download blocks - use_termination_prompt = False - is_file_download_block = task_block is not None and isinstance(task_block, FileDownloadBlock) - - if is_file_download_block: - try: - distinct_id = task.workflow_run_id if task.workflow_run_id else task.task_id - use_termination_prompt = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached( - "USE_TERMINATION_AWARE_COMPLETE_VERIFICATION", - distinct_id, - properties={"organization_id": task.organization_id}, - ) - if use_termination_prompt: - LOG.info( - "Experiment enabled: using termination-aware complete verification prompt for file download block", - task_id=task.task_id, - workflow_run_id=task.workflow_run_id, - organization_id=task.organization_id, - block_type="file_download", - ) - except Exception as e: - LOG.warning( - "Failed to check USE_TERMINATION_AWARE_COMPLETE_VERIFICATION experiment; using legacy behavior", - task_id=task.task_id, - workflow_run_id=task.workflow_run_id, - error=str(e), - ) - # Select the appropriate template based on experiment template_name = "check-user-goal-with-termination" if use_termination_prompt else "check-user-goal" prompt_name = "check-user-goal-with-termination" if use_termination_prompt else "check-user-goal" @@ -2133,7 +2301,14 @@ class ForgeAgent: return CompleteVerifyResult.model_validate(verification_result) async def check_user_goal_complete( - self, page: Page, scraped_page: ScrapedPage, task: Task, step: Step, task_block: BaseTaskBlock | None = None + self, + page: Page, + scraped_page: ScrapedPage, + task: Task, + step: Step, + task_block: BaseTaskBlock | None = None, + *, + use_termination_prompt: bool = False, ) -> CompleteAction | TerminateAction | None: try: verification_result = await self.complete_verify( @@ -2142,17 +2317,24 @@ class ForgeAgent: task=task, step=step, task_block=task_block, + use_termination_prompt=use_termination_prompt, ) # Check if we should terminate instead of complete - # Note: This requires the USE_TERMINATION_AWARE_COMPLETE_VERIFICATION experiment to be enabled if verification_result.is_terminate: - LOG.warning( - "Periodic verification determined task should terminate (termination-aware experiment)", - workflow_run_id=task.workflow_run_id, - thoughts=verification_result.thoughts, - status=verification_result.status if verification_result.status else "legacy", - ) + if use_termination_prompt: + LOG.warning( + "Periodic verification determined task should terminate (termination-aware experiment)", + workflow_run_id=task.workflow_run_id, + thoughts=verification_result.thoughts, + status=verification_result.status if verification_result.status else "legacy", + ) + else: + LOG.warning( + "Periodic verification determined task should terminate", + workflow_run_id=task.workflow_run_id, + thoughts=verification_result.thoughts, + ) return TerminateAction( reasoning=verification_result.thoughts, )