scope termination-aware verification to file download fallback (#4043)
This commit is contained in:
@@ -501,6 +501,7 @@ class ForgeAgent:
|
|||||||
task = await self.update_task_errors_from_detailed_output(task, detailed_output) # type: ignore
|
task = await self.update_task_errors_from_detailed_output(task, detailed_output) # type: ignore
|
||||||
retry = False
|
retry = False
|
||||||
|
|
||||||
|
download_detected = False
|
||||||
if task_block and task_block.complete_on_download and task.workflow_run_id:
|
if task_block and task_block.complete_on_download and task.workflow_run_id:
|
||||||
workflow_download_directory = get_path_for_workflow_download_directory(
|
workflow_download_directory = get_path_for_workflow_download_directory(
|
||||||
context.run_id if context and context.run_id else task.workflow_run_id
|
context.run_id if context and context.run_id else task.workflow_run_id
|
||||||
@@ -538,6 +539,7 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
list_files_after = list_files_after + browser_session_downloaded_files_after
|
list_files_after = list_files_after + browser_session_downloaded_files_after
|
||||||
if len(list_files_after) > len(list_files_before):
|
if len(list_files_after) > len(list_files_before):
|
||||||
|
download_detected = True
|
||||||
files_to_rename = list(set(list_files_after) - set(list_files_before))
|
files_to_rename = list(set(list_files_after) - set(list_files_before))
|
||||||
for file in files_to_rename:
|
for file in files_to_rename:
|
||||||
if file.startswith("s3://"):
|
if file.startswith("s3://"):
|
||||||
@@ -599,6 +601,28 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
return last_step, detailed_output, None
|
return last_step, detailed_output, None
|
||||||
|
|
||||||
|
if (
|
||||||
|
task_block
|
||||||
|
and isinstance(task_block, FileDownloadBlock)
|
||||||
|
and task_block.complete_on_download
|
||||||
|
and task.workflow_run_id
|
||||||
|
and not download_detected
|
||||||
|
):
|
||||||
|
handled, fallback_last_step = await self._handle_file_download_verification_fallback(
|
||||||
|
organization=organization,
|
||||||
|
task=task,
|
||||||
|
step=step,
|
||||||
|
browser_state=browser_state,
|
||||||
|
task_block=task_block,
|
||||||
|
detailed_output=detailed_output,
|
||||||
|
engine=engine,
|
||||||
|
api_key=api_key,
|
||||||
|
close_browser_on_completion=close_browser_on_completion,
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
|
)
|
||||||
|
if handled and fallback_last_step:
|
||||||
|
return fallback_last_step, detailed_output, None
|
||||||
|
|
||||||
# If the step failed, mark the step as failed and retry
|
# If the step failed, mark the step as failed and retry
|
||||||
if step.status == StepStatus.failed:
|
if step.status == StepStatus.failed:
|
||||||
maybe_next_step = await self.handle_failed_step(organization, task, step)
|
maybe_next_step = await self.handle_failed_step(organization, task, step)
|
||||||
@@ -893,6 +917,172 @@ class ForgeAgent:
|
|||||||
@TraceManager.traced_async(
|
@TraceManager.traced_async(
|
||||||
ignore_inputs=["browser_state", "organization", "task_block", "cua_response", "llm_caller"]
|
ignore_inputs=["browser_state", "organization", "task_block", "cua_response", "llm_caller"]
|
||||||
)
|
)
|
||||||
|
async def _handle_file_download_verification_fallback(
    self,
    *,
    organization: Organization,
    task: Task,
    step: Step,
    browser_state: BrowserState,
    task_block: FileDownloadBlock,
    detailed_output: DetailedAgentStepOutput | None,
    engine: RunEngine,
    api_key: str | None,
    close_browser_on_completion: bool,
    browser_session_id: str | None,
) -> tuple[bool, Step | None]:
    """Run goal verification for a file download block and finalise the task if it is decisive.

    The verification only runs when the
    ``USE_TERMINATION_AWARE_COMPLETE_VERIFICATION`` experiment is enabled for
    this task. When the verification yields a complete or terminate action,
    the action is executed and recorded on the step, the step is marked as
    the last one, the task status is updated and the task is cleaned up.

    Args:
        organization: Accepted for call-site symmetry; not read by this method.
        task: Task being executed; updated to ``terminated`` or ``completed``
            when the verification is decisive.
        step: Step whose output receives the verification action and results.
        browser_state: Used to obtain the working page and to record artifacts.
        task_block: The file download block that triggered the fallback.
        detailed_output: Output of the step; its ``scraped_page`` is required.
        engine: Forwarded to ``record_artifacts_after_action``.
        api_key: Forwarded to ``clean_up_task``.
        close_browser_on_completion: Forwarded to ``clean_up_task``.
        browser_session_id: Forwarded to ``clean_up_task``.

    Returns:
        ``(False, None)`` when the verification was skipped, failed, or did
        not produce an action, so the caller should continue with its
        standard flow. ``(True, last_step)`` once the task has been finalised
        as terminated or completed.
    """
    # Without a scraped page there is nothing to verify against.
    if detailed_output is None or detailed_output.scraped_page is None:
        return False, None

    try:
        distinct_id = task.workflow_run_id if task.workflow_run_id else task.task_id
        use_termination_prompt = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
            "USE_TERMINATION_AWARE_COMPLETE_VERIFICATION",
            distinct_id,
            properties={"organization_id": task.organization_id},
        )
    except Exception as error:  # pragma: no cover - defensive logging
        LOG.warning(
            "Failed to check USE_TERMINATION_AWARE_COMPLETE_VERIFICATION experiment; skipping download fallback verification",
            task_id=task.task_id,
            workflow_run_id=task.workflow_run_id,
            error=str(error),
        )
        return False, None

    if not use_termination_prompt:
        return False, None

    try:
        page = await browser_state.get_working_page()
        if page is None:
            page = await browser_state.must_get_working_page()
    except Exception:
        LOG.warning(
            "File download fallback verification could not fetch working page, skipping verification",
            task_id=task.task_id,
            workflow_run_id=task.workflow_run_id,
            exc_info=True,
        )
        return False, None

    try:
        # The experiment decision must be forwarded explicitly:
        # check_user_goal_complete defaults use_termination_prompt to False,
        # which would silently select the legacy verification prompt.
        fallback_action = await self.check_user_goal_complete(
            page=page,
            scraped_page=detailed_output.scraped_page,
            task=task,
            step=step,
            task_block=task_block,
            use_termination_prompt=use_termination_prompt,
        )
    except Exception:
        LOG.warning(
            "File download fallback verification failed, continuing with standard flow",
            task_id=task.task_id,
            workflow_run_id=task.workflow_run_id,
            exc_info=True,
        )
        return False, None

    if fallback_action is None:
        LOG.info(
            "File download fallback verification completed with continue status",
            task_id=task.task_id,
            workflow_run_id=task.workflow_run_id,
        )
        return False, None

    LOG.info(
        "File download fallback verification returned decisive action",
        task_id=task.task_id,
        workflow_run_id=task.workflow_run_id,
        action_type=fallback_action.action_type if isinstance(fallback_action, Action) else "unknown",
    )

    # Make sure every collection we are about to append to exists.
    if step.output is None:
        step.output = AgentStepOutput(action_results=[], actions_and_results=[], errors=[])
    if step.output.action_results is None:
        step.output.action_results = []
    if step.output.actions_and_results is None:
        step.output.actions_and_results = []
    if step.output.errors is None:
        step.output.errors = []
    if detailed_output.actions_and_results is None:
        detailed_output.actions_and_results = []

    persisted_action = cast(Action, fallback_action)
    if isinstance(persisted_action, (CompleteAction, TerminateAction)):
        # Attach the identifiers that tie the action to this task and step.
        persisted_action.organization_id = task.organization_id
        persisted_action.workflow_run_id = task.workflow_run_id
        persisted_action.task_id = task.task_id
        persisted_action.step_id = step.step_id
        persisted_action.step_order = step.order
        # The action is appended after the ones already recorded on the step.
        persisted_action.action_order = len(step.output.actions_and_results)

    action_results = await ActionHandler.handle_action(
        detailed_output.scraped_page,
        task,
        step,
        page,
        persisted_action,
    )
    await self.record_artifacts_after_action(task, step, browser_state, engine)

    step.output.action_results.extend(action_results)
    step.output.actions_and_results.append((persisted_action, action_results))
    detailed_output.actions_and_results.append((persisted_action, action_results))
    if isinstance(persisted_action, DecisiveAction) and persisted_action.errors:
        step.output.errors.extend(persisted_action.errors)

    if isinstance(persisted_action, TerminateAction):
        LOG.warning(
            "File download fallback verification determined workflow should terminate",
            task_id=task.task_id,
            workflow_run_id=task.workflow_run_id,
            reasoning=persisted_action.reasoning,
        )
        last_step = await self.update_step(step, output=step.output, is_last=True)

        # Prefer the structured errors over the free-form reasoning when present.
        task_errors = None
        failure_reason = persisted_action.reasoning
        if persisted_action.errors:
            task_errors = [error.model_dump() for error in persisted_action.errors]
            failure_reason = "; ".join(error.reasoning for error in persisted_action.errors)

        updated_task = await self.update_task(
            task,
            status=TaskStatus.terminated,
            failure_reason=failure_reason,
            errors=task_errors,
        )
        await self.clean_up_task(
            task=updated_task,
            last_step=last_step,
            api_key=api_key,
            close_browser_on_completion=close_browser_on_completion,
            browser_session_id=browser_session_id,
        )
        return True, last_step

    LOG.info(
        "File download fallback verification marked task as complete",
        task_id=task.task_id,
        workflow_run_id=task.workflow_run_id,
    )
    last_step = await self.update_step(step, output=step.output, is_last=True)
    extracted_information = await self.get_extracted_information_for_task(task)
    updated_task = await self.update_task(
        task,
        status=TaskStatus.completed,
        extracted_information=extracted_information,
    )
    await self.clean_up_task(
        task=updated_task,
        last_step=last_step,
        api_key=api_key,
        close_browser_on_completion=close_browser_on_completion,
        browser_session_id=browser_session_id,
    )
    return True, last_step
|
||||||
|
|
||||||
async def agent_step(
|
async def agent_step(
|
||||||
self,
|
self,
|
||||||
task: Task,
|
task: Task,
|
||||||
@@ -2024,7 +2214,14 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def complete_verify(
|
async def complete_verify(
|
||||||
self, page: Page, scraped_page: ScrapedPage, task: Task, step: Step, task_block: BaseTaskBlock | None = None
|
self,
|
||||||
|
page: Page,
|
||||||
|
scraped_page: ScrapedPage,
|
||||||
|
task: Task,
|
||||||
|
step: Step,
|
||||||
|
task_block: BaseTaskBlock | None = None,
|
||||||
|
*,
|
||||||
|
use_termination_prompt: bool = False,
|
||||||
) -> CompleteVerifyResult:
|
) -> CompleteVerifyResult:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Checking if user goal is achieved after re-scraping the page",
|
"Checking if user goal is achieved after re-scraping the page",
|
||||||
@@ -2042,35 +2239,6 @@ class ForgeAgent:
|
|||||||
if task.include_action_history_in_verification:
|
if task.include_action_history_in_verification:
|
||||||
actions_and_results_str = await self._get_action_results(task, current_step=step)
|
actions_and_results_str = await self._get_action_results(task, current_step=step)
|
||||||
|
|
||||||
# Check if we should use the termination-aware prompt (experiment)
|
|
||||||
# Only enabled for file download blocks
|
|
||||||
use_termination_prompt = False
|
|
||||||
is_file_download_block = task_block is not None and isinstance(task_block, FileDownloadBlock)
|
|
||||||
|
|
||||||
if is_file_download_block:
|
|
||||||
try:
|
|
||||||
distinct_id = task.workflow_run_id if task.workflow_run_id else task.task_id
|
|
||||||
use_termination_prompt = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
|
|
||||||
"USE_TERMINATION_AWARE_COMPLETE_VERIFICATION",
|
|
||||||
distinct_id,
|
|
||||||
properties={"organization_id": task.organization_id},
|
|
||||||
)
|
|
||||||
if use_termination_prompt:
|
|
||||||
LOG.info(
|
|
||||||
"Experiment enabled: using termination-aware complete verification prompt for file download block",
|
|
||||||
task_id=task.task_id,
|
|
||||||
workflow_run_id=task.workflow_run_id,
|
|
||||||
organization_id=task.organization_id,
|
|
||||||
block_type="file_download",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
LOG.warning(
|
|
||||||
"Failed to check USE_TERMINATION_AWARE_COMPLETE_VERIFICATION experiment; using legacy behavior",
|
|
||||||
task_id=task.task_id,
|
|
||||||
workflow_run_id=task.workflow_run_id,
|
|
||||||
error=str(e),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Select the appropriate template based on experiment
|
# Select the appropriate template based on experiment
|
||||||
template_name = "check-user-goal-with-termination" if use_termination_prompt else "check-user-goal"
|
template_name = "check-user-goal-with-termination" if use_termination_prompt else "check-user-goal"
|
||||||
prompt_name = "check-user-goal-with-termination" if use_termination_prompt else "check-user-goal"
|
prompt_name = "check-user-goal-with-termination" if use_termination_prompt else "check-user-goal"
|
||||||
@@ -2133,7 +2301,14 @@ class ForgeAgent:
|
|||||||
return CompleteVerifyResult.model_validate(verification_result)
|
return CompleteVerifyResult.model_validate(verification_result)
|
||||||
|
|
||||||
async def check_user_goal_complete(
|
async def check_user_goal_complete(
|
||||||
self, page: Page, scraped_page: ScrapedPage, task: Task, step: Step, task_block: BaseTaskBlock | None = None
|
self,
|
||||||
|
page: Page,
|
||||||
|
scraped_page: ScrapedPage,
|
||||||
|
task: Task,
|
||||||
|
step: Step,
|
||||||
|
task_block: BaseTaskBlock | None = None,
|
||||||
|
*,
|
||||||
|
use_termination_prompt: bool = False,
|
||||||
) -> CompleteAction | TerminateAction | None:
|
) -> CompleteAction | TerminateAction | None:
|
||||||
try:
|
try:
|
||||||
verification_result = await self.complete_verify(
|
verification_result = await self.complete_verify(
|
||||||
@@ -2142,17 +2317,24 @@ class ForgeAgent:
|
|||||||
task=task,
|
task=task,
|
||||||
step=step,
|
step=step,
|
||||||
task_block=task_block,
|
task_block=task_block,
|
||||||
|
use_termination_prompt=use_termination_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if we should terminate instead of complete
|
# Check if we should terminate instead of complete
|
||||||
# Note: This requires the USE_TERMINATION_AWARE_COMPLETE_VERIFICATION experiment to be enabled
|
|
||||||
if verification_result.is_terminate:
|
if verification_result.is_terminate:
|
||||||
LOG.warning(
|
if use_termination_prompt:
|
||||||
"Periodic verification determined task should terminate (termination-aware experiment)",
|
LOG.warning(
|
||||||
workflow_run_id=task.workflow_run_id,
|
"Periodic verification determined task should terminate (termination-aware experiment)",
|
||||||
thoughts=verification_result.thoughts,
|
workflow_run_id=task.workflow_run_id,
|
||||||
status=verification_result.status if verification_result.status else "legacy",
|
thoughts=verification_result.thoughts,
|
||||||
)
|
status=verification_result.status if verification_result.status else "legacy",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
"Periodic verification determined task should terminate",
|
||||||
|
workflow_run_id=task.workflow_run_id,
|
||||||
|
thoughts=verification_result.thoughts,
|
||||||
|
)
|
||||||
return TerminateAction(
|
return TerminateAction(
|
||||||
reasoning=verification_result.thoughts,
|
reasoning=verification_result.thoughts,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user