parallel check user goal exp (#3873)

This commit is contained in:
pedrohsdb
2025-10-31 12:19:50 -07:00
committed by GitHub
parent 0e0ae81693
commit 06bb9efb4a
2 changed files with 340 additions and 26 deletions

View File

@@ -613,6 +613,10 @@ class ForgeAgent:
step=step,
page=await browser_state.get_working_page(),
task_block=task_block,
browser_state=browser_state,
scraped_page=detailed_output.scraped_page if detailed_output else None,
engine=engine,
complete_verification=complete_verification,
)
if is_task_completed is not None and maybe_last_step:
last_step = maybe_last_step
@@ -1303,7 +1307,17 @@ class ForgeAgent:
task.task_id,
properties={"task_url": task.url, "organization_id": task.organization_id},
)
if not disable_user_goal_check:
# Check if parallel verification is enabled
distinct_id = task.workflow_run_id if task.workflow_run_id else task.task_id
enable_parallel_verification = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
"ENABLE_PARALLEL_USER_GOAL_CHECK",
distinct_id,
properties={"organization_id": task.organization_id, "task_url": task.url},
)
if not disable_user_goal_check and not enable_parallel_verification:
# Standard synchronous verification
working_page = await browser_state.must_get_working_page()
complete_action = await self.check_user_goal_complete(
page=working_page,
@@ -1324,6 +1338,13 @@ class ForgeAgent:
)
detailed_agent_step_output.actions_and_results.append((complete_action, complete_results))
await self.record_artifacts_after_action(task, step, browser_state, engine)
elif enable_parallel_verification:
# Parallel verification enabled - defer check to handle_completed_step
LOG.info(
"Parallel verification enabled, deferring user goal check to handle_completed_step",
step_id=step.step_id,
task_id=task.task_id,
)
# if the last action is complete and is successful, check if there's a data extraction goal
# if task has navigation goal and extraction goal at the same time, handle ExtractAction before marking step as completed
@@ -1676,6 +1697,49 @@ class ForgeAgent:
return actions
async def _pre_scrape_for_next_step(
    self,
    task: Task,
    step: Step,
    browser_state: BrowserState,
    engine: RunEngine,
) -> ScrapedPage | None:
    """Scrape the current page ahead of the upcoming step.

    Scraping is the expensive part of a step (roughly 5-10 seconds), which is
    why it is kicked off concurrently with user-goal verification. Returns the
    scraped page on success, or None on any failure — the next step will then
    simply re-scrape on its own.
    """
    # CUA engines only need a single plain screenshot; every other engine gets
    # the full treatment: multiple screenshots, bounding boxes, and scrolling.
    is_cua_engine = engine in CUA_ENGINES
    screenshot_limit = 1 if is_cua_engine else settings.MAX_NUM_SCREENSHOTS
    try:
        pre_scraped = await scrape_website(
            browser_state,
            task.url,
            app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step),
            scrape_exclude=app.scrape_exclude,
            max_screenshot_number=screenshot_limit,
            draw_boxes=not is_cua_engine,
            scroll=not is_cua_engine,
        )
        LOG.info(
            "Pre-scraped page for next step in parallel with verification",
            step_id=step.step_id,
            num_elements=len(pre_scraped.elements) if pre_scraped else 0,
        )
        return pre_scraped
    except Exception:
        # Best-effort: a failed pre-scrape must never fail the run.
        LOG.warning(
            "Failed to pre-scrape for next step, will re-scrape on next step execution",
            step_id=step.step_id,
            exc_info=True,
        )
        return None
async def complete_verify(
self, page: Page, scraped_page: ScrapedPage, task: Task, step: Step
) -> CompleteVerifyResult:
@@ -1929,37 +1993,63 @@ class ForgeAgent:
browser_state: BrowserState,
engine: RunEngine,
) -> tuple[ScrapedPage, str, bool]:
# start the async tasks while running scrape_website
if engine not in CUA_ENGINES:
self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
# Scrape the web page and get the screenshot and the elements
# HACK: try scrape_website three time to handle screenshot timeout
# first time: normal scrape to take screenshot
# second time: try again the normal scrape, (stopping window loading before scraping barely helps, but causing problem)
# third time: reload the page before scraping
# Check if we have pre-scraped data from parallel verification optimization
context = skyvern_context.current()
scraped_page: ScrapedPage | None = None
extract_action_prompt = ""
use_caching = False
for idx, scrape_type in enumerate(SCRAPE_TYPE_ORDER):
try:
scraped_page = await self._scrape_with_type(
task=task,
step=step,
browser_state=browser_state,
scrape_type=scrape_type,
engine=engine,
if (
context
and context.next_step_pre_scraped_data
and context.next_step_pre_scraped_data.get("step_id") == step.step_id
):
scraped_page = context.next_step_pre_scraped_data.get("scraped_page")
if scraped_page:
timestamp = context.next_step_pre_scraped_data.get("timestamp")
age_seconds = (datetime.now(UTC) - timestamp).total_seconds() if timestamp else 0
LOG.info(
"Using pre-scraped data from parallel verification optimization",
step_id=step.step_id,
num_elements=len(scraped_page.elements),
age_seconds=age_seconds,
)
break
except (FailedToTakeScreenshot, ScrapingFailed) as e:
if idx < len(SCRAPE_TYPE_ORDER) - 1:
continue
LOG.exception(f"{e.__class__.__name__} happened in two normal attempts and reload-page retry")
raise e
# Clear the cached data
context.next_step_pre_scraped_data = None
# If we don't have pre-scraped data, scrape normally
if scraped_page is None:
# start the async tasks while running scrape_website
if engine not in CUA_ENGINES:
self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
# Scrape the web page and get the screenshot and the elements
# HACK: try scrape_website three time to handle screenshot timeout
# first time: normal scrape to take screenshot
# second time: try again the normal scrape, (stopping window loading before scraping barely helps, but causing problem)
# third time: reload the page before scraping
extract_action_prompt = ""
use_caching = False
for idx, scrape_type in enumerate(SCRAPE_TYPE_ORDER):
try:
scraped_page = await self._scrape_with_type(
task=task,
step=step,
browser_state=browser_state,
scrape_type=scrape_type,
engine=engine,
)
break
except (FailedToTakeScreenshot, ScrapingFailed) as e:
if idx < len(SCRAPE_TYPE_ORDER) - 1:
continue
LOG.exception(f"{e.__class__.__name__} happened in two normal attempts and reload-page retry")
raise e
if scraped_page is None:
raise EmptyScrapePage()
extract_action_prompt = ""
use_caching = False
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.HTML_SCRAPE,
@@ -2805,6 +2895,180 @@ class ForgeAgent:
**updates,
)
async def _handle_completed_step_with_parallel_verification(
    self,
    organization: Organization,
    task: Task,
    step: Step,
    page: Page | None,
    browser_state: BrowserState,
    scraped_page: ScrapedPage,
    engine: RunEngine,
    task_block: BaseTaskBlock | None = None,
) -> tuple[bool | None, Step | None, Step | None]:
    """
    Handle completed step with parallel verification optimization.

    Runs two tasks in parallel:
    1. Verify if user goal is complete (check-user-goal)
    2. Pre-scrape page for next step

    If goal is complete, cancel pre-scraping and mark task done.
    If goal not complete, use pre-scraped data for next step execution.

    Note: This should only be called when verification is needed (i.e., when
    the standard flow would have called check_user_goal_complete in agent_step).

    Returns:
        A (is_task_completed, last_step, next_step) tuple with exactly one of
        these three shapes:
        - (True, last_step, None): the user goal was verified complete and the
          task was marked completed;
        - (False, last_step, None): the max-step budget was exhausted and the
          task was marked failed;
        - (None, None, next_step): execution should continue with next_step.
    """
    LOG.info(
        "Starting parallel user goal verification optimization",
        step_id=step.step_id,
        task_id=task.task_id,
    )
    # Task 1: Verify user goal (typically 2-5 seconds)
    verification_task = asyncio.create_task(
        self.check_user_goal_complete(
            page=page,
            scraped_page=scraped_page,
            task=task,
            step=step,
        ),
        name=f"verify_goal_{step.step_id}",
    )
    # Task 2: Pre-scrape for next step (typically 5-10 seconds)
    pre_scrape_task = asyncio.create_task(
        self._pre_scrape_for_next_step(
            task=task,
            step=step,
            browser_state=browser_state,
            engine=engine,
        ),
        name=f"pre_scrape_{step.step_id}",
    )
    # Wait for verification to complete first (faster of the two).
    # A verification failure is treated the same as "goal not achieved" so
    # the run continues with the next step instead of crashing here.
    try:
        complete_action = await verification_task
    except Exception:
        LOG.warning(
            "User goal verification failed in parallel mode, will continue with next step",
            step_id=step.step_id,
            exc_info=True,
        )
        complete_action = None
    if complete_action is not None:
        # Goal achieved! Cancel the pre-scraping task
        LOG.info(
            "Parallel verification: goal achieved, cancelling pre-scraping",
            step_id=step.step_id,
            task_id=task.task_id,
        )
        pre_scrape_task.cancel()
        # Await the cancelled task so the CancelledError is consumed and the
        # task object is fully cleaned up (avoids "task was never retrieved"
        # warnings); any other cleanup failure is logged and ignored.
        try:
            await pre_scrape_task  # Clean up the cancelled task
        except asyncio.CancelledError:
            LOG.debug("Pre-scraping cancelled successfully", step_id=step.step_id)
        except Exception:
            LOG.debug("Pre-scraping task cleanup failed", step_id=step.step_id, exc_info=True)
        # Mark task as complete
        # Note: Step is already marked as completed by agent_step
        # We don't add the complete action to the step output since the step is already finalized
        LOG.info(
            "Parallel verification: goal achieved, marking task as complete",
            step_id=step.step_id,
            task_id=task.task_id,
        )
        last_step = await self.update_step(step, is_last=True)
        extracted_information = await self.get_extracted_information_for_task(task)
        await self.update_task(
            task,
            status=TaskStatus.completed,
            extracted_information=extracted_information,
        )
        return True, last_step, None
    else:
        # Goal not achieved - wait for pre-scraping to complete
        LOG.info(
            "Parallel verification: goal not achieved, using pre-scraped data for next step",
            step_id=step.step_id,
            task_id=task.task_id,
        )
        # Pre-scraping is best-effort: if it failed, the next step falls back
        # to scraping the page itself.
        try:
            pre_scraped_page = await pre_scrape_task
        except Exception:
            LOG.warning(
                "Pre-scraping failed, next step will re-scrape",
                step_id=step.step_id,
                exc_info=True,
            )
            pre_scraped_page = None
        # Check max steps before creating next step
        context = skyvern_context.current()
        override_max_steps_per_run = context.max_steps_override if context else None
        # Precedence: per-run override > task setting > org setting > global default.
        max_steps_per_run = (
            override_max_steps_per_run
            or task.max_steps_per_run
            or organization.max_steps_per_run
            or settings.MAX_STEPS_PER_RUN
        )
        if step.order + 1 >= max_steps_per_run:
            LOG.info(
                "Step completed but max steps reached, marking task as failed",
                step_order=step.order,
                step_retry=step.retry_index,
                max_steps=max_steps_per_run,
            )
            last_step = await self.update_step(step, is_last=True)
            generated_failure_reason = await self.summary_failure_reason_for_max_steps(
                organization=organization,
                task=task,
                step=step,
                page=page,
            )
            failure_reason = f"Reached the maximum steps ({max_steps_per_run}). Possible failure reasons: {generated_failure_reason.reasoning}"
            errors = [ReachMaxStepsError().model_dump()] + [
                error.model_dump() for error in generated_failure_reason.errors
            ]
            await self.update_task(
                task,
                status=TaskStatus.failed,
                failure_reason=failure_reason,
                errors=errors,
            )
            return False, last_step, None
        # Create next step
        next_step = await app.DATABASE.create_step(
            task_id=task.task_id,
            order=step.order + 1,
            retry_index=0,
            organization_id=task.organization_id,
        )
        # Store pre-scraped data in context for next step to use
        # (keyed by the new step_id plus a timestamp; the consumer reads and
        # clears context.next_step_pre_scraped_data when the step runs)
        if pre_scraped_page:
            context = skyvern_context.ensure_context()
            context.next_step_pre_scraped_data = {
                "step_id": next_step.step_id,
                "scraped_page": pre_scraped_page,
                "timestamp": datetime.now(UTC),
            }
            LOG.info(
                "Stored pre-scraped data for next step",
                step_id=next_step.step_id,
                num_elements=len(pre_scraped_page.elements),
            )
        return None, None, next_step
async def handle_failed_step(self, organization: Organization, task: Task, step: Step) -> Step | None:
max_retries_per_step = (
organization.max_retries_per_step
@@ -3087,7 +3351,53 @@ class ForgeAgent:
step: Step,
page: Page | None,
task_block: BaseTaskBlock | None = None,
browser_state: BrowserState | None = None,
scraped_page: ScrapedPage | None = None,
engine: RunEngine = RunEngine.skyvern_v1,
complete_verification: bool = True,
) -> tuple[bool | None, Step | None, Step | None]:
# Check if parallel verification should be used
# Only use it when we have the required data AND when verification would normally happen
should_verify = (
complete_verification
and not step.is_goal_achieved()
and not step.is_terminated()
and not isinstance(task_block, ActionBlock)
and (task.navigation_goal or task.complete_criterion)
)
if should_verify and browser_state and scraped_page:
try:
distinct_id = task.workflow_run_id if task.workflow_run_id else task.task_id
enable_parallel_verification = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
"ENABLE_PARALLEL_USER_GOAL_CHECK",
distinct_id,
properties={"organization_id": task.organization_id, "task_url": task.url},
)
if enable_parallel_verification:
LOG.info(
"Parallel verification enabled, using optimized flow",
step_id=step.step_id,
task_id=task.task_id,
)
return await self._handle_completed_step_with_parallel_verification(
organization=organization,
task=task,
step=step,
page=page,
browser_state=browser_state,
scraped_page=scraped_page,
engine=engine,
task_block=task_block,
)
except Exception:
LOG.warning(
"Failed to check parallel verification feature flag, using standard flow",
step_id=step.step_id,
exc_info=True,
)
if step.is_goal_achieved():
LOG.info(
"Step completed and goal achieved, marking task as completed",

View File

@@ -55,6 +55,10 @@ class SkyvernContext:
# next blocks won't consider the page as a magic link page
magic_link_pages: dict[str, Page] = field(default_factory=dict)
# parallel verification optimization
# stores pre-scraped data for next step to avoid re-scraping
next_step_pre_scraped_data: dict[str, Any] | None = None
"""
Example output value:
{"loop_value": "str", "output_parameter": "the key of the parameter", "output_value": Any}