parallelize goal check within task (#3997)

This commit is contained in:
pedrohsdb
2025-11-13 17:18:32 -08:00
committed by GitHub
parent a95837783a
commit b7e28b075c
5 changed files with 675 additions and 330 deletions

View File

@@ -82,6 +82,11 @@ class MissingElement(SkyvernException):
)
class MissingExtractActionsResponse(SkyvernException):
    """Raised when no extract-actions LLM response is available for a step."""

    def __init__(self) -> None:
        message = "extract-actions response missing"
        super().__init__(message)
class MultipleElementsFound(SkyvernException):
def __init__(self, num: int, selector: str | None = None, element_id: str | None = None):
super().__init__(

View File

@@ -6,6 +6,7 @@ import random
import re
import string
from asyncio.exceptions import CancelledError
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Tuple, cast
@@ -48,6 +49,7 @@ from skyvern.exceptions import (
InvalidTaskStatusTransition,
InvalidWorkflowTaskURLState,
MissingBrowserStatePage,
MissingExtractActionsResponse,
NoTOTPVerificationCodeFound,
ScrapingFailed,
SkyvernException,
@@ -81,7 +83,7 @@ from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.models import SpeculativeLLMMetadata, Step, StepStatus
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, TaskStatus
@@ -136,6 +138,15 @@ EXTRACT_ACTION_PROMPT_NAME = "extract-actions"
EXTRACT_ACTION_CACHE_KEY_PREFIX = f"{EXTRACT_ACTION_TEMPLATE}-static"
@dataclass
class SpeculativePlan:
    # Next-step plan computed speculatively (while the current step's goal
    # verification runs) so the next step can skip re-scraping and the
    # extract-actions LLM call if the goal turns out not to be complete.
    scraped_page: ScrapedPage  # page scraped ahead of time for the next step
    extract_action_prompt: str  # rendered extract-actions prompt for that page
    use_caching: bool  # whether prompt caching was enabled when building the prompt
    llm_json_response: dict[str, Any] | None  # speculative extract-actions response; None if the LLM call was skipped/failed
    llm_metadata: SpeculativeLLMMetadata | None = None  # artifacts/token usage to persist once the plan is consumed
class ActionLinkedNode:
def __init__(self, action: Action) -> None:
self.action = action
@@ -915,19 +926,35 @@ class ForgeAgent:
organization=organization, task=task, step=step, browser_state=browser_state
)
(
scraped_page,
extract_action_prompt,
use_caching,
) = await self.build_and_record_step_prompt(
task,
step,
browser_state,
engine,
)
speculative_plan: SpeculativePlan | None = None
reuse_speculative_llm_response = False
speculative_llm_metadata: SpeculativeLLMMetadata | None = None
if context:
speculative_plan = context.speculative_plans.pop(step.step_id, None)
if speculative_plan:
step.is_speculative = False
scraped_page = speculative_plan.scraped_page
extract_action_prompt = speculative_plan.extract_action_prompt
use_caching = speculative_plan.use_caching
json_response = speculative_plan.llm_json_response
reuse_speculative_llm_response = json_response is not None
speculative_llm_metadata = speculative_plan.llm_metadata
else:
(
scraped_page,
extract_action_prompt,
use_caching,
) = await self.build_and_record_step_prompt(
task,
step,
browser_state,
engine,
)
json_response = None
detailed_agent_step_output.scraped_page = scraped_page
detailed_agent_step_output.extract_action_prompt = extract_action_prompt
json_response = None
actions: list[Action]
if engine == RunEngine.openai_cua:
@@ -986,12 +1013,20 @@ class ForgeAgent:
if context:
context.use_prompt_caching = True
json_response = await llm_api_handler(
prompt=extract_action_prompt,
prompt_name="extract-actions",
step=step,
screenshots=scraped_page.screenshots,
)
if not reuse_speculative_llm_response:
json_response = await llm_api_handler(
prompt=extract_action_prompt,
prompt_name="extract-actions",
step=step,
screenshots=scraped_page.screenshots,
)
else:
LOG.debug(
"Using speculative extract-actions response",
step_id=step.step_id,
)
if json_response is None:
raise MissingExtractActionsResponse()
try:
otp_json_response, otp_actions = await self.handle_potential_OTP_actions(
task, step, scraped_page, browser_state, json_response
@@ -1035,6 +1070,14 @@ class ForgeAgent:
)
]
if reuse_speculative_llm_response and speculative_llm_metadata:
await self._persist_speculative_llm_metadata(
step,
speculative_llm_metadata,
screenshots=scraped_page.screenshots,
)
speculative_llm_metadata = None
detailed_agent_step_output.actions = actions
if len(actions) == 0:
LOG.info(
@@ -1308,6 +1351,7 @@ class ForgeAgent:
break
task_completes_on_download = task_block and task_block.complete_on_download and task.workflow_run_id
enable_parallel_verification = False
if (
not has_decisive_action
and not task_completes_on_download
@@ -1385,6 +1429,8 @@ class ForgeAgent:
status=StepStatus.completed,
output=detailed_agent_step_output.to_agent_step_output(),
)
if enable_parallel_verification:
completed_step.speculative_original_status = StepStatus.completed
return completed_step, detailed_agent_step_output.get_clean_detailed_output()
except CancelledError:
LOG.exception(
@@ -1748,51 +1794,229 @@ class ForgeAgent:
return draw_boxes
async def _pre_scrape_for_next_step(
async def _speculate_next_step_plan(
self,
task: Task,
step: Step,
current_step: Step,
next_step: Step,
browser_state: BrowserState,
engine: RunEngine,
) -> ScrapedPage | None:
"""
Pre-scrape the page for the next step while verification is running.
This is the expensive operation (5-10 seconds) that we want to run in parallel.
"""
try:
max_screenshot_number = settings.MAX_NUM_SCREENSHOTS
draw_boxes = True
scroll = True
if engine in CUA_ENGINES:
max_screenshot_number = 1
draw_boxes = False
scroll = False
# Check PostHog feature flag to skip screenshot annotations
draw_boxes = await self._should_skip_screenshot_annotations(task, draw_boxes)
scraped_page = await scrape_website(
browser_state,
task.url,
app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step),
scrape_exclude=app.scrape_exclude,
max_screenshot_number=max_screenshot_number,
draw_boxes=draw_boxes,
scroll=scroll,
)
) -> SpeculativePlan | None:
if engine in CUA_ENGINES:
LOG.info(
"Pre-scraped page for next step in parallel with verification",
step_id=step.step_id,
num_elements=len(scraped_page.elements) if scraped_page else 0,
"Skipping speculative extract-actions for CUA engine",
step_id=current_step.step_id,
task_id=task.task_id,
)
return None
try:
next_step.is_speculative = True
scraped_page, extract_action_prompt, use_caching = await self.build_and_record_step_prompt(
task,
next_step,
browser_state,
engine,
persist_artifacts=False,
)
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
task.llm_key,
default=app.LLM_API_HANDLER,
)
llm_json_response = await llm_api_handler(
prompt=extract_action_prompt,
prompt_name="extract-actions",
step=next_step,
screenshots=scraped_page.screenshots,
)
LOG.info(
"Speculative extract-actions completed",
current_step_id=current_step.step_id,
synthetic_step_id=next_step.step_id,
)
metadata_copy = None
if next_step.speculative_llm_metadata is not None:
metadata_copy = next_step.speculative_llm_metadata.model_copy()
next_step.speculative_llm_metadata = None
next_step.is_speculative = False
return SpeculativePlan(
scraped_page=scraped_page,
extract_action_prompt=extract_action_prompt,
use_caching=use_caching,
llm_json_response=llm_json_response,
llm_metadata=metadata_copy,
)
return scraped_page
except Exception:
LOG.warning(
"Failed to pre-scrape for next step, will re-scrape on next step execution",
"Failed to run speculative extract-actions",
step_id=current_step.step_id,
exc_info=True,
)
next_step.is_speculative = False
return None
async def _persist_speculative_llm_metadata(
    self,
    step: Step,
    metadata: SpeculativeLLMMetadata,
    *,
    screenshots: list[bytes] | None = None,
) -> None:
    """Persist LLM artifacts and token/cost accounting that were captured
    during a speculative extract-actions call onto the now-real step.

    Artifacts are written in the same order a non-speculative call would
    have produced them; token/cost increments are applied both to the DB
    row and to the in-memory ``step`` object.
    """
    if not metadata:
        return
    LOG.debug("Persisting speculative LLM metadata")

    # The prompt artifact is the only one that carries screenshots.
    if metadata.prompt:
        await app.ARTIFACT_MANAGER.create_llm_artifact(
            data=metadata.prompt.encode("utf-8"),
            artifact_type=ArtifactType.LLM_PROMPT,
            screenshots=screenshots,
            step=step,
        )
    # Remaining artifacts follow the same persist shape; keep their order.
    for payload, artifact_type in (
        (metadata.llm_request_json, ArtifactType.LLM_REQUEST),
        (metadata.llm_response_json, ArtifactType.LLM_RESPONSE),
        (metadata.parsed_response_json, ArtifactType.LLM_RESPONSE_PARSED),
        (metadata.rendered_response_json, ArtifactType.LLM_RESPONSE_RENDERED),
    ):
        if payload:
            await app.ARTIFACT_MANAGER.create_llm_artifact(
                data=payload.encode("utf-8"),
                artifact_type=artifact_type,
                step=step,
            )

    def _positive_or_none(value: float | int | None) -> float | int | None:
        # Zero, negative, and None all mean "nothing to record".
        return value if value and value > 0 else None

    incremental_cost = _positive_or_none(metadata.llm_cost)
    incremental_input_tokens = _positive_or_none(metadata.input_tokens)
    incremental_output_tokens = _positive_or_none(metadata.output_tokens)
    incremental_reasoning_tokens = _positive_or_none(metadata.reasoning_tokens)
    incremental_cached_tokens = _positive_or_none(metadata.cached_tokens)

    has_usage = any(
        value is not None
        for value in (
            incremental_cost,
            incremental_input_tokens,
            incremental_output_tokens,
            incremental_reasoning_tokens,
            incremental_cached_tokens,
        )
    )
    if has_usage:
        await app.DATABASE.update_step(
            task_id=step.task_id,
            step_id=step.step_id,
            organization_id=step.organization_id,
            incremental_cost=incremental_cost,
            incremental_input_tokens=incremental_input_tokens,
            incremental_output_tokens=incremental_output_tokens,
            incremental_reasoning_tokens=incremental_reasoning_tokens,
            incremental_cached_tokens=incremental_cached_tokens,
        )

    # Mirror the DB increments on the in-memory step object.
    if incremental_input_tokens:
        step.input_token_count += incremental_input_tokens
    if incremental_output_tokens:
        step.output_token_count += incremental_output_tokens
    if incremental_reasoning_tokens:
        step.reasoning_token_count = (step.reasoning_token_count or 0) + incremental_reasoning_tokens
    if incremental_cached_tokens:
        step.cached_token_count = (step.cached_token_count or 0) + incremental_cached_tokens
    if incremental_cost:
        step.step_cost += incremental_cost

    step.speculative_llm_metadata = None
async def _persist_speculative_metadata_for_discarded_plan(
    self,
    step: Step,
    speculative_task: asyncio.Future[SpeculativePlan | None],
    *,
    cancel_step: bool = False,
) -> None:
    """Await a speculative plan that will not be used, persist any LLM
    metadata it produced (so cost/token accounting is not lost), then
    clean up the speculative step.

    :param step: the speculative step the plan was computed for.
    :param speculative_task: the in-flight speculative extract-actions task.
    :param cancel_step: when True, mark the speculative step as canceled
        after metadata handling.

    Fix: cleanup (clearing ``is_speculative`` and, when requested,
    cancelling the step) now runs on every exit path. Previously, if
    ``_persist_speculative_llm_metadata`` raised, the except branch only
    logged and returned, leaving the step flagged speculative and never
    cancelled.
    """
    plan: SpeculativePlan | None = None
    try:
        # shield(): keep the speculative task running to completion even if
        # the surrounding coroutine gets cancelled.
        plan = await asyncio.shield(speculative_task)
    except CancelledError:
        LOG.debug(
            "Speculative extract-actions cancelled before metadata persistence",
            step_id=step.step_id,
        )
    except Exception:
        LOG.debug(
            "Speculative extract-actions failed before metadata persistence",
            step_id=step.step_id,
            exc_info=True,
        )

    if plan and plan.llm_metadata:
        try:
            await self._persist_speculative_llm_metadata(
                step,
                plan.llm_metadata,
            )
        except Exception:
            # Best effort: accounting loss is logged, but cleanup below
            # must still happen.
            LOG.warning(
                "Failed to persist speculative llm metadata for discarded plan",
                step_id=step.step_id,
                exc_info=True,
            )

    step.is_speculative = False
    if cancel_step:
        await self._cancel_speculative_step(step)
async def _cancel_speculative_step(self, step: Step) -> None:
    """Best-effort transition of a speculative step to ``canceled``.

    A failure to persist the cancellation is logged and swallowed so the
    caller's flow is never interrupted by cleanup.
    """
    if step.status != StepStatus.canceled:
        try:
            refreshed = await self.update_step(step, status=StepStatus.canceled)
            # Reflect the persisted state on the in-memory step.
            step.status = refreshed.status
            step.is_speculative = False
        except Exception:
            LOG.warning(
                "Failed to cancel speculative step",
                step_id=step.step_id,
                exc_info=True,
            )
return None
async def complete_verify(
self, page: Page, scraped_page: ScrapedPage, task: Task, step: Step, task_block: BaseTaskBlock | None = None
@@ -2099,6 +2323,8 @@ class ForgeAgent:
step: Step,
browser_state: BrowserState,
engine: RunEngine,
*,
persist_artifacts: bool = True,
) -> tuple[ScrapedPage, str, bool]:
# Check if we have pre-scraped data from parallel verification optimization
context = skyvern_context.current()
@@ -2178,11 +2404,12 @@ class ForgeAgent:
extract_action_prompt = ""
use_caching = False
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.HTML_SCRAPE,
data=scraped_page.html.encode(),
)
if persist_artifacts:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.HTML_SCRAPE,
data=scraped_page.html.encode(),
)
LOG.info(
"Scraped website",
step_order=step.order,
@@ -2191,6 +2418,7 @@ class ForgeAgent:
url=task.url,
)
# TODO: we only use HTML element for now, introduce a way to switch in the future
enable_speed_optimizations = getattr(context, "enable_speed_optimizations", False)
element_tree_format = ElementTreeFormat.HTML
# OPTIMIZATION: Use economy tree (skip SVGs) when ENABLE_SPEED_OPTIMIZATIONS is enabled
@@ -2248,31 +2476,32 @@ class ForgeAgent:
expire_verification_code=True,
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.VISIBLE_ELEMENTS_ID_CSS_MAP,
data=json.dumps(scraped_page.id_to_css_dict, indent=2).encode(),
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.VISIBLE_ELEMENTS_ID_FRAME_MAP,
data=json.dumps(scraped_page.id_to_frame_dict, indent=2).encode(),
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE,
data=json.dumps(scraped_page.element_tree, indent=2).encode(),
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE_TRIMMED,
data=json.dumps(scraped_page.element_tree_trimmed, indent=2).encode(),
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE_IN_PROMPT,
data=element_tree_in_prompt.encode(),
)
if persist_artifacts:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.VISIBLE_ELEMENTS_ID_CSS_MAP,
data=json.dumps(scraped_page.id_to_css_dict, indent=2).encode(),
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.VISIBLE_ELEMENTS_ID_FRAME_MAP,
data=json.dumps(scraped_page.id_to_frame_dict, indent=2).encode(),
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE,
data=json.dumps(scraped_page.element_tree, indent=2).encode(),
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE_TRIMMED,
data=json.dumps(scraped_page.element_tree_trimmed, indent=2).encode(),
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE_IN_PROMPT,
data=element_tree_in_prompt.encode(),
)
return scraped_page, extract_action_prompt, use_caching
@@ -2480,6 +2709,16 @@ class ForgeAgent:
task_llm_key=task.llm_key,
effective_llm_key=effective_llm_key,
)
enable_speed_optimizations = context.enable_speed_optimizations
element_tree_format = ElementTreeFormat.HTML
if enable_speed_optimizations:
if step.retry_index == 0:
elements_for_prompt = scraped_page.build_economy_elements_tree(element_tree_format)
else:
elements_for_prompt = scraped_page.build_element_tree(element_tree_format)
else:
elements_for_prompt = scraped_page.build_element_tree(element_tree_format)
if template == EXTRACT_ACTION_TEMPLATE and cache_enabled:
try:
# Try to load split templates for caching
@@ -2501,7 +2740,11 @@ class ForgeAgent:
"has_magic_link_page": context.has_magic_link_page(task.task_id),
}
static_prompt = prompt_engine.load_prompt(f"{template}-static", **prompt_kwargs)
dynamic_prompt = prompt_engine.load_prompt(f"{template}-dynamic", **prompt_kwargs)
dynamic_prompt = prompt_engine.load_prompt(
f"{template}-dynamic",
elements=elements_for_prompt,
**prompt_kwargs,
)
# Store static prompt for caching and continue sending it alongside the dynamic section.
# Vertex explicit caching expects the static content to still be present in the request so the
@@ -3250,12 +3493,11 @@ class ForgeAgent:
the standard flow would have called check_user_goal_complete in agent_step).
"""
LOG.info(
"Starting parallel user goal verification optimization",
"Starting parallel user goal verification with speculative extract-actions",
step_id=step.step_id,
task_id=task.task_id,
)
# Task 1: Verify user goal (typically 2-5 seconds)
verification_task = asyncio.create_task(
self.check_user_goal_complete(
page=page,
@@ -3267,18 +3509,31 @@ class ForgeAgent:
name=f"verify_goal_{step.step_id}",
)
# Task 2: Pre-scrape for next step (typically 5-10 seconds)
pre_scrape_task = asyncio.create_task(
self._pre_scrape_for_next_step(
next_step = await app.DATABASE.create_step(
task_id=task.task_id,
order=step.order + 1,
retry_index=0,
organization_id=task.organization_id,
)
LOG.debug(
"Waiting before launching speculative plan",
step_id=step.step_id,
task_id=task.task_id,
)
await asyncio.sleep(1.0)
speculative_task = asyncio.create_task(
self._speculate_next_step_plan(
task=task,
step=step,
current_step=step,
next_step=next_step,
browser_state=browser_state,
engine=engine,
),
name=f"pre_scrape_{step.step_id}",
name=f"speculate_next_step_{step.step_id}",
)
# Wait for verification to complete first (faster of the two)
try:
complete_action = await verification_task
except Exception:
@@ -3290,25 +3545,15 @@ class ForgeAgent:
complete_action = None
if complete_action is not None:
# Goal achieved or should terminate! Cancel the pre-scraping task
is_terminate = isinstance(complete_action, TerminateAction)
LOG.info(
"Parallel verification: goal achieved or termination required, cancelling pre-scraping",
step_id=step.step_id,
task_id=task.task_id,
is_terminate=is_terminate,
asyncio.create_task(
self._persist_speculative_metadata_for_discarded_plan(
next_step,
speculative_task,
cancel_step=True,
)
)
pre_scrape_task.cancel()
try:
await pre_scrape_task # Clean up the cancelled task
except asyncio.CancelledError:
LOG.debug("Pre-scraping cancelled successfully", step_id=step.step_id)
except Exception:
LOG.debug("Pre-scraping task cleanup failed", step_id=step.step_id, exc_info=True)
working_page = page
if working_page is None:
working_page = await browser_state.must_get_working_page()
working_page = page or await browser_state.must_get_working_page()
if step.output is None:
step.output = AgentStepOutput(action_results=[], actions_and_results=[], errors=[])
@@ -3333,21 +3578,27 @@ class ForgeAgent:
if isinstance(persisted_action, DecisiveAction) and persisted_action.errors:
step.output.errors.extend(persisted_action.errors)
if is_terminate:
# Mark task as terminated/failed
# Note: This requires the USE_TERMINATION_AWARE_COMPLETE_VERIFICATION experiment to be enabled
if isinstance(persisted_action, TerminateAction):
LOG.warning(
"Parallel verification: termination required, marking task as terminated (termination-aware experiment)",
"Parallel verification: termination required, marking task as terminated",
step_id=step.step_id,
task_id=task.task_id,
reasoning=complete_action.reasoning,
)
last_step = await self.update_step(step, output=step.output, is_last=True)
final_status = step.speculative_original_status or StepStatus.completed
step.speculative_original_status = None
step.status = final_status
last_step = await self.update_step(
step,
status=final_status,
output=step.output,
is_last=True,
)
task_errors = None
if isinstance(persisted_action, TerminateAction) and persisted_action.errors:
if persisted_action.errors:
task_errors = [error.model_dump() for error in persisted_action.errors]
failure_reason = persisted_action.reasoning
if isinstance(persisted_action, TerminateAction) and persisted_action.errors:
if persisted_action.errors:
failure_reason = "; ".join(error.reasoning for error in persisted_action.errors)
await self.update_task(
task,
@@ -3356,102 +3607,108 @@ class ForgeAgent:
errors=task_errors,
)
return True, last_step, None
else:
# Mark task as complete
# Note: Step is already marked as completed by agent_step
# We don't add the complete action to the step output since the step is already finalized
LOG.info(
"Parallel verification: goal achieved, marking task as complete",
step_id=step.step_id,
task_id=task.task_id,
)
last_step = await self.update_step(step, output=step.output, is_last=True)
extracted_information = await self.get_extracted_information_for_task(task)
await self.update_task(
task,
status=TaskStatus.completed,
extracted_information=extracted_information,
)
return True, last_step, None
else:
# Goal not achieved - wait for pre-scraping to complete
LOG.info(
"Parallel verification: goal not achieved, using pre-scraped data for next step",
"Parallel verification: goal achieved, marking task as complete",
step_id=step.step_id,
task_id=task.task_id,
)
final_status = step.speculative_original_status or StepStatus.completed
step.speculative_original_status = None
step.status = final_status
last_step = await self.update_step(
step,
status=final_status,
output=step.output,
is_last=True,
)
extracted_information = await self.get_extracted_information_for_task(task)
await self.update_task(
task,
status=TaskStatus.completed,
extracted_information=extracted_information,
)
return True, last_step, None
try:
pre_scraped_page = await pre_scrape_task
except Exception:
LOG.warning(
"Pre-scraping failed, next step will re-scrape",
step_id=step.step_id,
exc_info=True,
)
pre_scraped_page = None
LOG.info(
"Parallel verification: goal not achieved, awaiting speculative extract-actions",
step_id=step.step_id,
task_id=task.task_id,
)
# Check max steps before creating next step
context = skyvern_context.current()
override_max_steps_per_run = context.max_steps_override if context else None
max_steps_per_run = (
override_max_steps_per_run
or task.max_steps_per_run
or organization.max_steps_per_run
or settings.MAX_STEPS_PER_RUN
try:
speculative_plan = await speculative_task
except CancelledError:
LOG.debug("Speculative extract-actions cancelled after verification finished", step_id=step.step_id)
speculative_plan = None
except Exception:
LOG.warning(
"Speculative extract-actions failed, next step will run sequentially",
step_id=step.step_id,
exc_info=True,
)
speculative_plan = None
context = skyvern_context.current()
override_max_steps_per_run = context.max_steps_override if context else None
max_steps_per_run = (
override_max_steps_per_run
or task.max_steps_per_run
or organization.max_steps_per_run
or settings.MAX_STEPS_PER_RUN
)
if step.order + 1 >= max_steps_per_run:
LOG.info(
"Step completed but max steps reached, marking task as failed",
step_order=step.order,
step_retry=step.retry_index,
max_steps=max_steps_per_run,
)
final_status = step.speculative_original_status or StepStatus.completed
step.speculative_original_status = None
step.status = final_status
last_step = await self.update_step(
step,
status=final_status,
output=step.output,
is_last=True,
)
if step.order + 1 >= max_steps_per_run:
LOG.info(
"Step completed but max steps reached, marking task as failed",
step_order=step.order,
step_retry=step.retry_index,
max_steps=max_steps_per_run,
)
last_step = await self.update_step(step, is_last=True)
generated_failure_reason = await self.summary_failure_reason_for_max_steps(
organization=organization,
task=task,
step=step,
page=page,
)
failure_reason = f"Reached the maximum steps ({max_steps_per_run}). Possible failure reasons: {generated_failure_reason.reasoning}"
errors = [ReachMaxStepsError().model_dump()] + [
error.model_dump() for error in generated_failure_reason.errors
]
generated_failure_reason = await self.summary_failure_reason_for_max_steps(
organization=organization,
task=task,
step=step,
page=page,
)
failure_reason = f"Reached the maximum steps ({max_steps_per_run}). Possible failure reasons: {generated_failure_reason.reasoning}"
errors = [ReachMaxStepsError().model_dump()] + [
error.model_dump() for error in generated_failure_reason.errors
]
await self._cancel_speculative_step(next_step)
await self.update_task(
task,
status=TaskStatus.failed,
failure_reason=failure_reason,
errors=errors,
)
return False, last_step, None
await self.update_task(
task,
status=TaskStatus.failed,
failure_reason=failure_reason,
errors=errors,
)
return False, last_step, None
# Create next step
next_step = await app.DATABASE.create_step(
task_id=task.task_id,
order=step.order + 1,
retry_index=0,
organization_id=task.organization_id,
if speculative_plan:
context = skyvern_context.ensure_context()
context.speculative_plans[next_step.step_id] = speculative_plan
LOG.info(
"Stored speculative extract-actions plan for next step",
current_step_id=step.step_id,
next_step_id=next_step.step_id,
)
# Store pre-scraped data in context for next step to use
if pre_scraped_page:
context = skyvern_context.ensure_context()
context.next_step_pre_scraped_data = {
"step_id": next_step.step_id,
"scraped_page": pre_scraped_page,
"timestamp": datetime.now(UTC),
}
LOG.info(
"Stored pre-scraped data for next step",
step_id=next_step.step_id,
num_elements=len(pre_scraped_page.elements),
)
step.status = step.speculative_original_status or StepStatus.completed
step.speculative_original_status = None
return None, None, next_step
return None, None, next_step
async def handle_failed_step(self, organization: Organization, task: Task, step: Step) -> Step | None:
max_retries_per_step = (

View File

@@ -29,7 +29,7 @@ from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.models import SpeculativeLLMMetadata, Step
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought
from skyvern.forge.sdk.trace import TraceManager
@@ -260,7 +260,8 @@ class LLMAPIHandlerFactory:
)
context = skyvern_context.current()
if context and len(context.hashed_href_map) > 0:
is_speculative_step = step.is_speculative if step else False
if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
artifact_type=ArtifactType.HASHED_HREF_MAP,
@@ -270,14 +271,16 @@ class LLMAPIHandlerFactory:
ai_suggestion=ai_suggestion,
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=prompt.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT,
screenshots=screenshots,
step=step,
task_v2=task_v2,
thought=thought,
)
llm_prompt_value = prompt
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=llm_prompt_value.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT,
screenshots=screenshots,
step=step,
task_v2=task_v2,
thought=thought,
)
# Build messages and apply caching in one step
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
@@ -330,21 +333,22 @@ class LLMAPIHandlerFactory:
cache_attached=True,
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
"model": llm_key,
"messages": messages,
**parameters,
"vertex_cache_attached": vertex_cache_attached,
}
).encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
llm_request_payload = {
"model": llm_key,
"messages": messages,
**parameters,
"vertex_cache_attached": vertex_cache_attached,
}
llm_request_json = json.dumps(llm_request_payload)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=llm_request_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
try:
response = await router.acompletion(
model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters
@@ -382,14 +386,16 @@ class LLMAPIHandlerFactory:
)
raise LLMProviderError(llm_key) from e
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=response.model_dump_json(indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
llm_response_json = response.model_dump_json(indent=2)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=llm_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
prompt_tokens = 0
completion_tokens = 0
reasoning_tokens = 0
@@ -424,7 +430,7 @@ class LLMAPIHandlerFactory:
# Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
if cached_tokens == 0:
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
if step:
if step and not is_speculative_step:
await app.DATABASE.update_step(
task_id=step.task_id,
step_id=step.step_id,
@@ -446,28 +452,33 @@ class LLMAPIHandlerFactory:
cached_token_count=cached_tokens if cached_tokens > 0 else None,
)
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
if context and len(context.hashed_href_map) > 0:
llm_content = json.dumps(parsed_response)
rendered_content = Template(llm_content).render(context.hashed_href_map)
parsed_response = json.loads(rendered_content)
parsed_response_json = json.dumps(parsed_response, indent=2)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
data=parsed_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
rendered_response_json = None
if context and len(context.hashed_href_map) > 0:
llm_content = json.dumps(parsed_response)
rendered_content = Template(llm_content).render(context.hashed_href_map)
parsed_response = json.loads(rendered_content)
rendered_response_json = json.dumps(parsed_response, indent=2)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=rendered_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
# Track LLM API handler duration, token counts, and cost
organization_id = organization_id or (
step.organization_id if step else (thought.organization_id if thought else None)
@@ -489,6 +500,23 @@ class LLMAPIHandlerFactory:
llm_cost=llm_cost if llm_cost > 0 else None,
)
if step and is_speculative_step:
step.speculative_llm_metadata = SpeculativeLLMMetadata(
prompt=llm_prompt_value,
llm_request_json=llm_request_json,
llm_response_json=llm_response_json,
parsed_response_json=parsed_response_json,
rendered_response_json=rendered_response_json,
llm_key=llm_key,
model=main_model_group,
duration_seconds=duration_seconds,
input_tokens=prompt_tokens if prompt_tokens > 0 else None,
output_tokens=completion_tokens if completion_tokens > 0 else None,
reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None,
cached_tokens=cached_tokens if cached_tokens > 0 else None,
llm_cost=llm_cost if llm_cost > 0 else None,
)
return parsed_response
llm_api_handler_with_router_and_fallback.llm_key = llm_key # type: ignore[attr-defined]
@@ -547,7 +575,8 @@ class LLMAPIHandlerFactory:
)
context = skyvern_context.current()
if context and len(context.hashed_href_map) > 0:
is_speculative_step = step.is_speculative if step else False
if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
artifact_type=ArtifactType.HASHED_HREF_MAP,
@@ -557,15 +586,17 @@ class LLMAPIHandlerFactory:
ai_suggestion=ai_suggestion,
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=prompt.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT,
screenshots=screenshots,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
llm_prompt_value = prompt
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=llm_prompt_value.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT,
screenshots=screenshots,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
if not llm_config.supports_vision:
screenshots = None
@@ -630,22 +661,23 @@ class LLMAPIHandlerFactory:
cache_attached=True,
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
"model": model_name,
"messages": messages,
# we're not using active_parameters here because it may contain sensitive information
**parameters,
"vertex_cache_attached": vertex_cache_attached,
}
).encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
llm_request_payload = {
"model": model_name,
"messages": messages,
# we're not using active_parameters here because it may contain sensitive information
**parameters,
"vertex_cache_attached": vertex_cache_attached,
}
llm_request_json = json.dumps(llm_request_payload)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=llm_request_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
t_llm_request = time.perf_counter()
try:
@@ -692,14 +724,16 @@ class LLMAPIHandlerFactory:
)
raise LLMProviderError(llm_key) from e
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=response.model_dump_json(indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
llm_response_json = response.model_dump_json(indent=2)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=llm_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
prompt_tokens = 0
completion_tokens = 0
@@ -912,7 +946,8 @@ class LLMCaller:
active_parameters.update(self.llm_config.litellm_params) # type: ignore
context = skyvern_context.current()
if context and len(context.hashed_href_map) > 0:
is_speculative_step = step.is_speculative if step else False
if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"),
artifact_type=ArtifactType.HASHED_HREF_MAP,
@@ -939,7 +974,8 @@ class LLMCaller:
tool["display_width_px"] = target_dimension["width"]
screenshots = resize_screenshots(screenshots, target_dimension)
if prompt:
llm_prompt_value = prompt or ""
if prompt and step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=prompt.encode("utf-8"),
artifact_type=ArtifactType.LLM_PROMPT,
@@ -971,21 +1007,22 @@ class LLMCaller:
screenshots,
message_pattern=message_pattern,
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
"model": self.llm_config.model_name,
"messages": messages,
# we're not using active_parameters here because it may contain sensitive information
**parameters,
}
).encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
llm_request_payload = {
"model": self.llm_config.model_name,
"messages": messages,
# we're not using active_parameters here because it may contain sensitive information
**parameters,
}
llm_request_json = json.dumps(llm_request_payload)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=llm_request_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
t_llm_request = time.perf_counter()
try:
response = await self._dispatch_llm_call(
@@ -1019,17 +1056,19 @@ class LLMCaller:
LOG.exception("LLM request failed unexpectedly", llm_key=self.llm_key)
raise LLMProviderError(self.llm_key) from e
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=response.model_dump_json(indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
llm_response_json = response.model_dump_json(indent=2)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=llm_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
call_stats = await self.get_call_stats(response)
if step:
if step and not is_speculative_step:
await app.DATABASE.update_step(
task_id=step.task_id,
step_id=step.step_id,
@@ -1051,6 +1090,34 @@ class LLMCaller:
thought_cost=call_stats.llm_cost,
)
parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
parsed_response_json = json.dumps(parsed_response, indent=2)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=parsed_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
rendered_response_json = None
if context and len(context.hashed_href_map) > 0:
llm_content = json.dumps(parsed_response)
rendered_content = Template(llm_content).render(context.hashed_href_map)
parsed_response = json.loads(rendered_content)
rendered_response_json = json.dumps(parsed_response, indent=2)
if step and not is_speculative_step:
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=rendered_response_json.encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
organization_id = organization_id or (
step.organization_id if step else (thought.organization_id if thought else None)
)
@@ -1071,32 +1138,27 @@ class LLMCaller:
cached_tokens=call_stats.cached_tokens if call_stats and call_stats.cached_tokens else None,
llm_cost=call_stats.llm_cost if call_stats and call_stats.llm_cost else None,
)
if step and is_speculative_step:
step.speculative_llm_metadata = SpeculativeLLMMetadata(
prompt=llm_prompt_value,
llm_request_json=llm_request_json,
llm_response_json=llm_response_json,
parsed_response_json=parsed_response_json,
rendered_response_json=rendered_response_json,
llm_key=self.llm_key,
model=self.llm_config.model_name,
duration_seconds=duration_seconds,
input_tokens=call_stats.input_tokens,
output_tokens=call_stats.output_tokens,
reasoning_tokens=call_stats.reasoning_tokens,
cached_tokens=call_stats.cached_tokens,
llm_cost=call_stats.llm_cost,
)
if raw_response:
return response.model_dump(exclude_none=True)
parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
if context and len(context.hashed_href_map) > 0:
llm_content = json.dumps(parsed_response)
rendered_content = Template(llm_content).render(context.hashed_href_map)
parsed_response = json.loads(rendered_content)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
artifact_type=ArtifactType.LLM_RESPONSE_RENDERED,
step=step,
task_v2=task_v2,
thought=thought,
ai_suggestion=ai_suggestion,
)
return parsed_response
def get_screenshot_resize_target_dimension(self, window_dimension: Resolution | None) -> Resolution:

View File

@@ -62,6 +62,7 @@ class SkyvernContext:
# parallel verification optimization
# stores pre-scraped data for next step to avoid re-scraping
next_step_pre_scraped_data: dict[str, Any] | None = None
speculative_plans: dict[str, Any] = field(default_factory=dict)
"""
Example output value:

View File

@@ -39,6 +39,22 @@ class StepStatus(StrEnum):
return self in status_is_terminal
class SpeculativeLLMMetadata(BaseModel):
    """In-memory record of an LLM call made during a speculative step.

    Assigned to ``Step.speculative_llm_metadata`` when ``step.is_speculative``
    is true. In speculative mode the usual LLM_* artifacts (prompt, request,
    response, parsed/rendered response) are not persisted via the artifact
    manager, so this object carries those payloads and the call's token/cost
    statistics instead.
    """

    # Prompt text that was sent to the LLM (always captured).
    prompt: str
    # JSON-serialized request payload (model, messages, parameters).
    llm_request_json: str
    # Raw provider response serialized to JSON; None if not captured.
    llm_response_json: str | None = None
    # Response after parsing into structured JSON.
    parsed_response_json: str | None = None
    # Parsed response after rendering hashed-href placeholders back in;
    # None when no hashed_href_map substitution was performed.
    rendered_response_json: str | None = None
    # Identifier of the LLM key/config used for the call.
    llm_key: str | None = None
    # Model (or model group) name the call was dispatched to.
    model: str | None = None
    # Wall-clock duration of the LLM call, in seconds.
    duration_seconds: float | None = None
    # Token usage stats; callers set these to None when the provider
    # reported zero (see the `> 0 else None` guards at the call sites).
    input_tokens: int | None = None
    output_tokens: int | None = None
    reasoning_tokens: int | None = None
    cached_tokens: int | None = None
    # Estimated cost of the call; None when unknown or zero.
    llm_cost: float | None = None
class Step(BaseModel):
created_at: datetime
modified_at: datetime
@@ -55,6 +71,9 @@ class Step(BaseModel):
reasoning_token_count: int | None = None
cached_token_count: int | None = None
step_cost: float = 0
is_speculative: bool = False
speculative_original_status: StepStatus | None = None
speculative_llm_metadata: SpeculativeLLMMetadata | None = None
def validate_update(
self,
@@ -64,7 +83,7 @@ class Step(BaseModel):
) -> None:
old_status = self.status
if status and not old_status.can_update_to(status):
if status and status != old_status and not old_status.can_update_to(status):
raise ValueError(f"invalid_status_transition({old_status},{status},{self.step_id})")
if status == StepStatus.canceled:
@@ -83,6 +102,7 @@ class Step(BaseModel):
old_status not in [StepStatus.running, StepStatus.created]
and self.output is not None
and output is not None
and not (status == old_status == StepStatus.completed)
):
raise ValueError(f"cant_override_output({self.step_id})")