From b7e28b075cfcaf1eb153c97f8807ff21a7dce367 Mon Sep 17 00:00:00 2001 From: pedrohsdb Date: Thu, 13 Nov 2025 17:18:32 -0800 Subject: [PATCH] parallelize goal check within task (#3997) --- skyvern/exceptions.py | 5 + skyvern/forge/agent.py | 651 ++++++++++++------ .../forge/sdk/api/llm/api_handler_factory.py | 326 +++++---- skyvern/forge/sdk/core/skyvern_context.py | 1 + skyvern/forge/sdk/models.py | 22 +- 5 files changed, 675 insertions(+), 330 deletions(-) diff --git a/skyvern/exceptions.py b/skyvern/exceptions.py index 8addc366..bc22ea7d 100644 --- a/skyvern/exceptions.py +++ b/skyvern/exceptions.py @@ -82,6 +82,11 @@ class MissingElement(SkyvernException): ) +class MissingExtractActionsResponse(SkyvernException): + def __init__(self) -> None: + super().__init__("extract-actions response missing") + + class MultipleElementsFound(SkyvernException): def __init__(self, num: int, selector: str | None = None, element_id: str | None = None): super().__init__( diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 5850d2ec..e78d22ce 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -6,6 +6,7 @@ import random import re import string from asyncio.exceptions import CancelledError +from dataclasses import dataclass from datetime import UTC, datetime from pathlib import Path from typing import Any, Tuple, cast @@ -48,6 +49,7 @@ from skyvern.exceptions import ( InvalidTaskStatusTransition, InvalidWorkflowTaskURLState, MissingBrowserStatePage, + MissingExtractActionsResponse, NoTOTPVerificationCodeFound, ScrapingFailed, SkyvernException, @@ -81,7 +83,7 @@ from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.db.enums import TaskType from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs -from skyvern.forge.sdk.models import Step, StepStatus +from skyvern.forge.sdk.models import SpeculativeLLMMetadata, Step, StepStatus from skyvern.forge.sdk.schemas.files import FileInfo from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, TaskStatus @@ -136,6 +138,15 @@ EXTRACT_ACTION_PROMPT_NAME = "extract-actions" EXTRACT_ACTION_CACHE_KEY_PREFIX = f"{EXTRACT_ACTION_TEMPLATE}-static" +@dataclass +class SpeculativePlan: + scraped_page: ScrapedPage + extract_action_prompt: str + use_caching: bool + llm_json_response: dict[str, Any] | None + llm_metadata: SpeculativeLLMMetadata | None = None + + class ActionLinkedNode: def __init__(self, action: Action) -> None: self.action = action @@ -915,19 +926,35 @@ class ForgeAgent: organization=organization, task=task, step=step, browser_state=browser_state ) - ( - scraped_page, - extract_action_prompt, - use_caching, - ) = await self.build_and_record_step_prompt( - task, - step, - browser_state, - engine, - ) + speculative_plan: SpeculativePlan | None = None + reuse_speculative_llm_response = False + speculative_llm_metadata: SpeculativeLLMMetadata | None = None + if context: + speculative_plan = context.speculative_plans.pop(step.step_id, None) + + if speculative_plan: + step.is_speculative = False + scraped_page = speculative_plan.scraped_page + extract_action_prompt = speculative_plan.extract_action_prompt + use_caching = speculative_plan.use_caching + json_response = speculative_plan.llm_json_response + reuse_speculative_llm_response = json_response is not None + speculative_llm_metadata = speculative_plan.llm_metadata + 
else: + ( + scraped_page, + extract_action_prompt, + use_caching, + ) = await self.build_and_record_step_prompt( + task, + step, + browser_state, + engine, + ) + json_response = None + detailed_agent_step_output.scraped_page = scraped_page detailed_agent_step_output.extract_action_prompt = extract_action_prompt - json_response = None actions: list[Action] if engine == RunEngine.openai_cua: @@ -986,12 +1013,20 @@ class ForgeAgent: if context: context.use_prompt_caching = True - json_response = await llm_api_handler( - prompt=extract_action_prompt, - prompt_name="extract-actions", - step=step, - screenshots=scraped_page.screenshots, - ) + if not reuse_speculative_llm_response: + json_response = await llm_api_handler( + prompt=extract_action_prompt, + prompt_name="extract-actions", + step=step, + screenshots=scraped_page.screenshots, + ) + else: + LOG.debug( + "Using speculative extract-actions response", + step_id=step.step_id, + ) + if json_response is None: + raise MissingExtractActionsResponse() try: otp_json_response, otp_actions = await self.handle_potential_OTP_actions( task, step, scraped_page, browser_state, json_response @@ -1035,6 +1070,14 @@ class ForgeAgent: ) ] + if reuse_speculative_llm_response and speculative_llm_metadata: + await self._persist_speculative_llm_metadata( + step, + speculative_llm_metadata, + screenshots=scraped_page.screenshots, + ) + speculative_llm_metadata = None + detailed_agent_step_output.actions = actions if len(actions) == 0: LOG.info( @@ -1308,6 +1351,7 @@ class ForgeAgent: break task_completes_on_download = task_block and task_block.complete_on_download and task.workflow_run_id + enable_parallel_verification = False if ( not has_decisive_action and not task_completes_on_download @@ -1385,6 +1429,8 @@ class ForgeAgent: status=StepStatus.completed, output=detailed_agent_step_output.to_agent_step_output(), ) + if enable_parallel_verification: + completed_step.speculative_original_status = StepStatus.completed return completed_step, detailed_agent_step_output.get_clean_detailed_output() except CancelledError: LOG.exception( @@ -1748,51 +1794,229 @@ class ForgeAgent: return draw_boxes - async def _pre_scrape_for_next_step( + async def _speculate_next_step_plan( self, task: Task, - step: Step, + current_step: Step, + next_step: Step, browser_state: BrowserState, engine: RunEngine, - ) -> ScrapedPage | None: - """ - Pre-scrape the page for the next step while verification is running. - This is the expensive operation (5-10 seconds) that we want to run in parallel. 
- """ - try: - max_screenshot_number = settings.MAX_NUM_SCREENSHOTS - draw_boxes = True - scroll = True - if engine in CUA_ENGINES: - max_screenshot_number = 1 - draw_boxes = False - scroll = False - - # Check PostHog feature flag to skip screenshot annotations - draw_boxes = await self._should_skip_screenshot_annotations(task, draw_boxes) - - scraped_page = await scrape_website( - browser_state, - task.url, - app.AGENT_FUNCTION.cleanup_element_tree_factory(task=task, step=step), - scrape_exclude=app.scrape_exclude, - max_screenshot_number=max_screenshot_number, - draw_boxes=draw_boxes, - scroll=scroll, - ) + ) -> SpeculativePlan | None: + if engine in CUA_ENGINES: LOG.info( - "Pre-scraped page for next step in parallel with verification", - step_id=step.step_id, - num_elements=len(scraped_page.elements) if scraped_page else 0, + "Skipping speculative extract-actions for CUA engine", + step_id=current_step.step_id, + task_id=task.task_id, + ) + return None + + try: + next_step.is_speculative = True + + scraped_page, extract_action_prompt, use_caching = await self.build_and_record_step_prompt( + task, + next_step, + browser_state, + engine, + persist_artifacts=False, + ) + + llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler( + task.llm_key, + default=app.LLM_API_HANDLER, + ) + + llm_json_response = await llm_api_handler( + prompt=extract_action_prompt, + prompt_name="extract-actions", + step=next_step, + screenshots=scraped_page.screenshots, + ) + + LOG.info( + "Speculative extract-actions completed", + current_step_id=current_step.step_id, + synthetic_step_id=next_step.step_id, + ) + + metadata_copy = None + if next_step.speculative_llm_metadata is not None: + metadata_copy = next_step.speculative_llm_metadata.model_copy() + next_step.speculative_llm_metadata = None + next_step.is_speculative = False + + return SpeculativePlan( + scraped_page=scraped_page, + extract_action_prompt=extract_action_prompt, + use_caching=use_caching, + llm_json_response=llm_json_response, + llm_metadata=metadata_copy, ) - return scraped_page except Exception: LOG.warning( - "Failed to pre-scrape for next step, will re-scrape on next step execution", + "Failed to run speculative extract-actions", + step_id=current_step.step_id, + exc_info=True, + ) + next_step.is_speculative = False + return None + + async def _persist_speculative_llm_metadata( + self, + step: Step, + metadata: SpeculativeLLMMetadata, + *, + screenshots: list[bytes] | None = None, + ) -> None: + if not metadata: + return + + LOG.debug("Persisting speculative LLM metadata") + + if metadata.prompt: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=metadata.prompt.encode("utf-8"), + artifact_type=ArtifactType.LLM_PROMPT, + screenshots=screenshots, + step=step, + ) + + if metadata.llm_request_json: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=metadata.llm_request_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_REQUEST, + step=step, + ) + + if metadata.llm_response_json: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=metadata.llm_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE, + step=step, + ) + + if metadata.parsed_response_json: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=metadata.parsed_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_PARSED, + step=step, + ) + + if metadata.rendered_response_json: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=metadata.rendered_response_json.encode("utf-8"), + 
artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, + step=step, + ) + + incremental_cost = metadata.llm_cost if metadata.llm_cost and metadata.llm_cost > 0 else None + incremental_input_tokens = ( + metadata.input_tokens if metadata.input_tokens and metadata.input_tokens > 0 else None + ) + incremental_output_tokens = ( + metadata.output_tokens if metadata.output_tokens and metadata.output_tokens > 0 else None + ) + incremental_reasoning_tokens = ( + metadata.reasoning_tokens if metadata.reasoning_tokens and metadata.reasoning_tokens > 0 else None + ) + incremental_cached_tokens = ( + metadata.cached_tokens if metadata.cached_tokens and metadata.cached_tokens > 0 else None + ) + + if ( + incremental_cost is not None + or incremental_input_tokens is not None + or incremental_output_tokens is not None + or incremental_reasoning_tokens is not None + or incremental_cached_tokens is not None + ): + await app.DATABASE.update_step( + task_id=step.task_id, + step_id=step.step_id, + organization_id=step.organization_id, + incremental_cost=incremental_cost, + incremental_input_tokens=incremental_input_tokens, + incremental_output_tokens=incremental_output_tokens, + incremental_reasoning_tokens=incremental_reasoning_tokens, + incremental_cached_tokens=incremental_cached_tokens, + ) + + if incremental_input_tokens: + step.input_token_count += incremental_input_tokens + if incremental_output_tokens: + step.output_token_count += incremental_output_tokens + if incremental_reasoning_tokens: + step.reasoning_token_count = (step.reasoning_token_count or 0) + incremental_reasoning_tokens + if incremental_cached_tokens: + step.cached_token_count = (step.cached_token_count or 0) + incremental_cached_tokens + if incremental_cost: + step.step_cost += incremental_cost + + step.speculative_llm_metadata = None + + async def _persist_speculative_metadata_for_discarded_plan( + self, + step: Step, + speculative_task: asyncio.Future[SpeculativePlan | None], + *, + cancel_step: bool = False, + ) -> None: + try: + plan = await asyncio.shield(speculative_task) + except CancelledError: + LOG.debug( + "Speculative extract-actions cancelled before metadata persistence", + step_id=step.step_id, + ) + step.is_speculative = False + if cancel_step: + await self._cancel_speculative_step(step) + return + except Exception: + LOG.debug( + "Speculative extract-actions failed before metadata persistence", + step_id=step.step_id, + exc_info=True, + ) + step.is_speculative = False + if cancel_step: + await self._cancel_speculative_step(step) + return + + if not plan or not plan.llm_metadata: + step.is_speculative = False + if cancel_step: + await self._cancel_speculative_step(step) + return + + try: + await self._persist_speculative_llm_metadata( + step, + plan.llm_metadata, + ) + step.is_speculative = False + if cancel_step: + await self._cancel_speculative_step(step) + except Exception: + LOG.warning( + "Failed to persist speculative llm metadata for discarded plan", + step_id=step.step_id, + exc_info=True, + ) + + async def _cancel_speculative_step(self, step: Step) -> None: + if step.status == StepStatus.canceled: + return + try: + updated_step = await self.update_step(step, status=StepStatus.canceled) + step.status = updated_step.status + step.is_speculative = False + except Exception: + LOG.warning( + "Failed to cancel speculative step", step_id=step.step_id, exc_info=True, ) - return None async def complete_verify( self, page: Page, scraped_page: ScrapedPage, task: Task, step: Step, task_block: BaseTaskBlock | None = None @@ 
-2099,6 +2323,8 @@ class ForgeAgent: step: Step, browser_state: BrowserState, engine: RunEngine, + *, + persist_artifacts: bool = True, ) -> tuple[ScrapedPage, str, bool]: # Check if we have pre-scraped data from parallel verification optimization context = skyvern_context.current() @@ -2178,11 +2404,12 @@ class ForgeAgent: extract_action_prompt = "" use_caching = False - await app.ARTIFACT_MANAGER.create_artifact( - step=step, - artifact_type=ArtifactType.HTML_SCRAPE, - data=scraped_page.html.encode(), - ) + if persist_artifacts: + await app.ARTIFACT_MANAGER.create_artifact( + step=step, + artifact_type=ArtifactType.HTML_SCRAPE, + data=scraped_page.html.encode(), + ) LOG.info( "Scraped website", step_order=step.order, @@ -2191,6 +2418,7 @@ class ForgeAgent: url=task.url, ) # TODO: we only use HTML element for now, introduce a way to switch in the future + enable_speed_optimizations = getattr(context, "enable_speed_optimizations", False) element_tree_format = ElementTreeFormat.HTML # OPTIMIZATION: Use economy tree (skip SVGs) when ENABLE_SPEED_OPTIMIZATIONS is enabled @@ -2248,31 +2476,32 @@ class ForgeAgent: expire_verification_code=True, ) - await app.ARTIFACT_MANAGER.create_artifact( - step=step, - artifact_type=ArtifactType.VISIBLE_ELEMENTS_ID_CSS_MAP, - data=json.dumps(scraped_page.id_to_css_dict, indent=2).encode(), - ) - await app.ARTIFACT_MANAGER.create_artifact( - step=step, - artifact_type=ArtifactType.VISIBLE_ELEMENTS_ID_FRAME_MAP, - data=json.dumps(scraped_page.id_to_frame_dict, indent=2).encode(), - ) - await app.ARTIFACT_MANAGER.create_artifact( - step=step, - artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE, - data=json.dumps(scraped_page.element_tree, indent=2).encode(), - ) - await app.ARTIFACT_MANAGER.create_artifact( - step=step, - artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE_TRIMMED, - data=json.dumps(scraped_page.element_tree_trimmed, indent=2).encode(), - ) - await app.ARTIFACT_MANAGER.create_artifact( - step=step, - artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE_IN_PROMPT, - data=element_tree_in_prompt.encode(), - ) + if persist_artifacts: + await app.ARTIFACT_MANAGER.create_artifact( + step=step, + artifact_type=ArtifactType.VISIBLE_ELEMENTS_ID_CSS_MAP, + data=json.dumps(scraped_page.id_to_css_dict, indent=2).encode(), + ) + await app.ARTIFACT_MANAGER.create_artifact( + step=step, + artifact_type=ArtifactType.VISIBLE_ELEMENTS_ID_FRAME_MAP, + data=json.dumps(scraped_page.id_to_frame_dict, indent=2).encode(), + ) + await app.ARTIFACT_MANAGER.create_artifact( + step=step, + artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE, + data=json.dumps(scraped_page.element_tree, indent=2).encode(), + ) + await app.ARTIFACT_MANAGER.create_artifact( + step=step, + artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE_TRIMMED, + data=json.dumps(scraped_page.element_tree_trimmed, indent=2).encode(), + ) + await app.ARTIFACT_MANAGER.create_artifact( + step=step, + artifact_type=ArtifactType.VISIBLE_ELEMENTS_TREE_IN_PROMPT, + data=element_tree_in_prompt.encode(), + ) return scraped_page, extract_action_prompt, use_caching @@ -2480,6 +2709,16 @@ class ForgeAgent: task_llm_key=task.llm_key, effective_llm_key=effective_llm_key, ) + enable_speed_optimizations = context.enable_speed_optimizations + element_tree_format = ElementTreeFormat.HTML + if enable_speed_optimizations: + if step.retry_index == 0: + elements_for_prompt = scraped_page.build_economy_elements_tree(element_tree_format) + else: + elements_for_prompt = scraped_page.build_element_tree(element_tree_format) + else: + 
elements_for_prompt = scraped_page.build_element_tree(element_tree_format) + if template == EXTRACT_ACTION_TEMPLATE and cache_enabled: try: # Try to load split templates for caching @@ -2501,7 +2740,11 @@ class ForgeAgent: "has_magic_link_page": context.has_magic_link_page(task.task_id), } static_prompt = prompt_engine.load_prompt(f"{template}-static", **prompt_kwargs) - dynamic_prompt = prompt_engine.load_prompt(f"{template}-dynamic", **prompt_kwargs) + dynamic_prompt = prompt_engine.load_prompt( + f"{template}-dynamic", + elements=elements_for_prompt, + **prompt_kwargs, + ) # Store static prompt for caching and continue sending it alongside the dynamic section. # Vertex explicit caching expects the static content to still be present in the request so the @@ -3250,12 +3493,11 @@ class ForgeAgent: the standard flow would have called check_user_goal_complete in agent_step). """ LOG.info( - "Starting parallel user goal verification optimization", + "Starting parallel user goal verification with speculative extract-actions", step_id=step.step_id, task_id=task.task_id, ) - # Task 1: Verify user goal (typically 2-5 seconds) verification_task = asyncio.create_task( self.check_user_goal_complete( page=page, @@ -3267,18 +3509,31 @@ class ForgeAgent: name=f"verify_goal_{step.step_id}", ) - # Task 2: Pre-scrape for next step (typically 5-10 seconds) - pre_scrape_task = asyncio.create_task( - self._pre_scrape_for_next_step( + next_step = await app.DATABASE.create_step( + task_id=task.task_id, + order=step.order + 1, + retry_index=0, + organization_id=task.organization_id, + ) + + LOG.debug( + "Waiting before launching speculative plan", + step_id=step.step_id, + task_id=task.task_id, + ) + await asyncio.sleep(1.0) + + speculative_task = asyncio.create_task( + self._speculate_next_step_plan( task=task, - step=step, + current_step=step, + next_step=next_step, browser_state=browser_state, engine=engine, ), - name=f"pre_scrape_{step.step_id}", + name=f"speculate_next_step_{step.step_id}", ) - # Wait for verification to complete first (faster of the two) try: complete_action = await verification_task except Exception: @@ -3290,25 +3545,15 @@ class ForgeAgent: complete_action = None if complete_action is not None: - # Goal achieved or should terminate! 
Cancel the pre-scraping task - is_terminate = isinstance(complete_action, TerminateAction) - LOG.info( - "Parallel verification: goal achieved or termination required, cancelling pre-scraping", - step_id=step.step_id, - task_id=task.task_id, - is_terminate=is_terminate, + asyncio.create_task( + self._persist_speculative_metadata_for_discarded_plan( + next_step, + speculative_task, + cancel_step=True, + ) ) - pre_scrape_task.cancel() - try: - await pre_scrape_task # Clean up the cancelled task - except asyncio.CancelledError: - LOG.debug("Pre-scraping cancelled successfully", step_id=step.step_id) - except Exception: - LOG.debug("Pre-scraping task cleanup failed", step_id=step.step_id, exc_info=True) - working_page = page - if working_page is None: - working_page = await browser_state.must_get_working_page() + working_page = page or await browser_state.must_get_working_page() if step.output is None: step.output = AgentStepOutput(action_results=[], actions_and_results=[], errors=[]) @@ -3333,21 +3578,27 @@ class ForgeAgent: if isinstance(persisted_action, DecisiveAction) and persisted_action.errors: step.output.errors.extend(persisted_action.errors) - if is_terminate: - # Mark task as terminated/failed - # Note: This requires the USE_TERMINATION_AWARE_COMPLETE_VERIFICATION experiment to be enabled + if isinstance(persisted_action, TerminateAction): LOG.warning( - "Parallel verification: termination required, marking task as terminated (termination-aware experiment)", + "Parallel verification: termination required, marking task as terminated", step_id=step.step_id, task_id=task.task_id, reasoning=complete_action.reasoning, ) - last_step = await self.update_step(step, output=step.output, is_last=True) + final_status = step.speculative_original_status or StepStatus.completed + step.speculative_original_status = None + step.status = final_status + last_step = await self.update_step( + step, + status=final_status, + output=step.output, + is_last=True, + ) task_errors = None - if isinstance(persisted_action, TerminateAction) and persisted_action.errors: + if persisted_action.errors: task_errors = [error.model_dump() for error in persisted_action.errors] failure_reason = persisted_action.reasoning - if isinstance(persisted_action, TerminateAction) and persisted_action.errors: + if persisted_action.errors: failure_reason = "; ".join(error.reasoning for error in persisted_action.errors) await self.update_task( task, @@ -3356,102 +3607,108 @@ class ForgeAgent: errors=task_errors, ) return True, last_step, None - else: - # Mark task as complete - # Note: Step is already marked as completed by agent_step - # We don't add the complete action to the step output since the step is already finalized - LOG.info( - "Parallel verification: goal achieved, marking task as complete", - step_id=step.step_id, - task_id=task.task_id, - ) - last_step = await self.update_step(step, output=step.output, is_last=True) - extracted_information = await self.get_extracted_information_for_task(task) - await self.update_task( - task, - status=TaskStatus.completed, - extracted_information=extracted_information, - ) - return True, last_step, None - else: - # Goal not achieved - wait for pre-scraping to complete + LOG.info( - "Parallel verification: goal not achieved, using pre-scraped data for next step", + "Parallel verification: goal achieved, marking task as complete", step_id=step.step_id, task_id=task.task_id, ) + final_status = step.speculative_original_status or StepStatus.completed + step.speculative_original_status = 
None + step.status = final_status + last_step = await self.update_step( + step, + status=final_status, + output=step.output, + is_last=True, + ) + extracted_information = await self.get_extracted_information_for_task(task) + await self.update_task( + task, + status=TaskStatus.completed, + extracted_information=extracted_information, + ) + return True, last_step, None - try: - pre_scraped_page = await pre_scrape_task - except Exception: - LOG.warning( - "Pre-scraping failed, next step will re-scrape", - step_id=step.step_id, - exc_info=True, - ) - pre_scraped_page = None + LOG.info( + "Parallel verification: goal not achieved, awaiting speculative extract-actions", + step_id=step.step_id, + task_id=task.task_id, + ) - # Check max steps before creating next step - context = skyvern_context.current() - override_max_steps_per_run = context.max_steps_override if context else None - max_steps_per_run = ( - override_max_steps_per_run - or task.max_steps_per_run - or organization.max_steps_per_run - or settings.MAX_STEPS_PER_RUN + try: + speculative_plan = await speculative_task + except CancelledError: + LOG.debug("Speculative extract-actions cancelled after verification finished", step_id=step.step_id) + speculative_plan = None + except Exception: + LOG.warning( + "Speculative extract-actions failed, next step will run sequentially", + step_id=step.step_id, + exc_info=True, + ) + speculative_plan = None + + context = skyvern_context.current() + override_max_steps_per_run = context.max_steps_override if context else None + max_steps_per_run = ( + override_max_steps_per_run + or task.max_steps_per_run + or organization.max_steps_per_run + or settings.MAX_STEPS_PER_RUN + ) + + if step.order + 1 >= max_steps_per_run: + LOG.info( + "Step completed but max steps reached, marking task as failed", + step_order=step.order, + step_retry=step.retry_index, + max_steps=max_steps_per_run, + ) + final_status = step.speculative_original_status or StepStatus.completed + step.speculative_original_status = None + step.status = final_status + last_step = await self.update_step( + step, + status=final_status, + output=step.output, + is_last=True, ) - if step.order + 1 >= max_steps_per_run: - LOG.info( - "Step completed but max steps reached, marking task as failed", - step_order=step.order, - step_retry=step.retry_index, - max_steps=max_steps_per_run, - ) - last_step = await self.update_step(step, is_last=True) + generated_failure_reason = await self.summary_failure_reason_for_max_steps( + organization=organization, + task=task, + step=step, + page=page, + ) + failure_reason = f"Reached the maximum steps ({max_steps_per_run}). Possible failure reasons: {generated_failure_reason.reasoning}" + errors = [ReachMaxStepsError().model_dump()] + [ + error.model_dump() for error in generated_failure_reason.errors + ] - generated_failure_reason = await self.summary_failure_reason_for_max_steps( - organization=organization, - task=task, - step=step, - page=page, - ) - failure_reason = f"Reached the maximum steps ({max_steps_per_run}). 
Possible failure reasons: {generated_failure_reason.reasoning}" - errors = [ReachMaxStepsError().model_dump()] + [ - error.model_dump() for error in generated_failure_reason.errors - ] + await self._cancel_speculative_step(next_step) - await self.update_task( - task, - status=TaskStatus.failed, - failure_reason=failure_reason, - errors=errors, - ) - return False, last_step, None + await self.update_task( + task, + status=TaskStatus.failed, + failure_reason=failure_reason, + errors=errors, + ) + return False, last_step, None - # Create next step - next_step = await app.DATABASE.create_step( - task_id=task.task_id, - order=step.order + 1, - retry_index=0, - organization_id=task.organization_id, + if speculative_plan: + context = skyvern_context.ensure_context() + context.speculative_plans[next_step.step_id] = speculative_plan + LOG.info( + "Stored speculative extract-actions plan for next step", + current_step_id=step.step_id, + next_step_id=next_step.step_id, ) - # Store pre-scraped data in context for next step to use - if pre_scraped_page: - context = skyvern_context.ensure_context() - context.next_step_pre_scraped_data = { - "step_id": next_step.step_id, - "scraped_page": pre_scraped_page, - "timestamp": datetime.now(UTC), - } - LOG.info( - "Stored pre-scraped data for next step", - step_id=next_step.step_id, - num_elements=len(pre_scraped_page.elements), - ) + step.status = step.speculative_original_status or StepStatus.completed + step.speculative_original_status = None - return None, None, next_step + return None, None, next_step async def handle_failed_step(self, organization: Organization, task: Task, step: Step) -> Step | None: max_retries_per_step = ( diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index 0b5fe5bb..d29b3c67 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -29,7 +29,7 @@ from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context -from skyvern.forge.sdk.models import Step +from skyvern.forge.sdk.models import SpeculativeLLMMetadata, Step from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought from skyvern.forge.sdk.trace import TraceManager @@ -260,7 +260,8 @@ class LLMAPIHandlerFactory: ) context = skyvern_context.current() - if context and len(context.hashed_href_map) > 0: + is_speculative_step = step.is_speculative if step else False + if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step: await app.ARTIFACT_MANAGER.create_llm_artifact( data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"), artifact_type=ArtifactType.HASHED_HREF_MAP, @@ -270,14 +271,16 @@ class LLMAPIHandlerFactory: ai_suggestion=ai_suggestion, ) - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=prompt.encode("utf-8"), - artifact_type=ArtifactType.LLM_PROMPT, - screenshots=screenshots, - step=step, - task_v2=task_v2, - thought=thought, - ) + llm_prompt_value = prompt + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=llm_prompt_value.encode("utf-8"), + artifact_type=ArtifactType.LLM_PROMPT, + screenshots=screenshots, + step=step, + task_v2=task_v2, + 
thought=thought, + ) # Build messages and apply caching in one step messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix) @@ -330,21 +333,22 @@ class LLMAPIHandlerFactory: cache_attached=True, ) - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=json.dumps( - { - "model": llm_key, - "messages": messages, - **parameters, - "vertex_cache_attached": vertex_cache_attached, - } - ).encode("utf-8"), - artifact_type=ArtifactType.LLM_REQUEST, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) + llm_request_payload = { + "model": llm_key, + "messages": messages, + **parameters, + "vertex_cache_attached": vertex_cache_attached, + } + llm_request_json = json.dumps(llm_request_payload) + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=llm_request_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_REQUEST, + step=step, + task_v2=task_v2, + thought=thought, + ai_suggestion=ai_suggestion, + ) try: response = await router.acompletion( model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters @@ -382,14 +386,16 @@ class LLMAPIHandlerFactory: ) raise LLMProviderError(llm_key) from e - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=response.model_dump_json(indent=2).encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) + llm_response_json = response.model_dump_json(indent=2) + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=llm_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE, + step=step, + task_v2=task_v2, + thought=thought, + ai_suggestion=ai_suggestion, + ) prompt_tokens = 0 completion_tokens = 0 reasoning_tokens = 0 @@ -424,7 +430,7 @@ class LLMAPIHandlerFactory: # Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage if cached_tokens == 0: cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0 - if step: + if step and not is_speculative_step: await app.DATABASE.update_step( task_id=step.task_id, step_id=step.step_id, @@ -446,28 +452,33 @@ class LLMAPIHandlerFactory: cached_token_count=cached_tokens if cached_tokens > 0 else None, ) parsed_response = parse_api_response(response, llm_config.add_assistant_prefix) - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=json.dumps(parsed_response, indent=2).encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_PARSED, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) - - if context and len(context.hashed_href_map) > 0: - llm_content = json.dumps(parsed_response) - rendered_content = Template(llm_content).render(context.hashed_href_map) - parsed_response = json.loads(rendered_content) + parsed_response_json = json.dumps(parsed_response, indent=2) + if step and not is_speculative_step: await app.ARTIFACT_MANAGER.create_llm_artifact( - data=json.dumps(parsed_response, indent=2).encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, + data=parsed_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_PARSED, step=step, task_v2=task_v2, thought=thought, ai_suggestion=ai_suggestion, ) + rendered_response_json = None + if context and len(context.hashed_href_map) > 0: + llm_content = json.dumps(parsed_response) + rendered_content = Template(llm_content).render(context.hashed_href_map) + parsed_response = 
json.loads(rendered_content) + rendered_response_json = json.dumps(parsed_response, indent=2) + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=rendered_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, + step=step, + task_v2=task_v2, + thought=thought, + ai_suggestion=ai_suggestion, + ) + # Track LLM API handler duration, token counts, and cost organization_id = organization_id or ( step.organization_id if step else (thought.organization_id if thought else None) @@ -489,6 +500,23 @@ class LLMAPIHandlerFactory: llm_cost=llm_cost if llm_cost > 0 else None, ) + if step and is_speculative_step: + step.speculative_llm_metadata = SpeculativeLLMMetadata( + prompt=llm_prompt_value, + llm_request_json=llm_request_json, + llm_response_json=llm_response_json, + parsed_response_json=parsed_response_json, + rendered_response_json=rendered_response_json, + llm_key=llm_key, + model=main_model_group, + duration_seconds=duration_seconds, + input_tokens=prompt_tokens if prompt_tokens > 0 else None, + output_tokens=completion_tokens if completion_tokens > 0 else None, + reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, + cached_tokens=cached_tokens if cached_tokens > 0 else None, + llm_cost=llm_cost if llm_cost > 0 else None, + ) + return parsed_response llm_api_handler_with_router_and_fallback.llm_key = llm_key # type: ignore[attr-defined] @@ -547,7 +575,8 @@ class LLMAPIHandlerFactory: ) context = skyvern_context.current() - if context and len(context.hashed_href_map) > 0: + is_speculative_step = step.is_speculative if step else False + if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step: await app.ARTIFACT_MANAGER.create_llm_artifact( data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"), artifact_type=ArtifactType.HASHED_HREF_MAP, @@ -557,15 +586,17 @@ class LLMAPIHandlerFactory: ai_suggestion=ai_suggestion, ) - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=prompt.encode("utf-8"), - artifact_type=ArtifactType.LLM_PROMPT, - screenshots=screenshots, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) + llm_prompt_value = prompt + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=llm_prompt_value.encode("utf-8"), + artifact_type=ArtifactType.LLM_PROMPT, + screenshots=screenshots, + step=step, + task_v2=task_v2, + thought=thought, + ai_suggestion=ai_suggestion, + ) if not llm_config.supports_vision: screenshots = None @@ -630,22 +661,23 @@ class LLMAPIHandlerFactory: cache_attached=True, ) - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=json.dumps( - { - "model": model_name, - "messages": messages, - # we're not using active_parameters here because it may contain sensitive information - **parameters, - "vertex_cache_attached": vertex_cache_attached, - } - ).encode("utf-8"), - artifact_type=ArtifactType.LLM_REQUEST, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) + llm_request_payload = { + "model": model_name, + "messages": messages, + # we're not using active_parameters here because it may contain sensitive information + **parameters, + "vertex_cache_attached": vertex_cache_attached, + } + llm_request_json = json.dumps(llm_request_payload) + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=llm_request_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_REQUEST, + step=step, + 
task_v2=task_v2, + thought=thought, + ai_suggestion=ai_suggestion, + ) t_llm_request = time.perf_counter() try: @@ -692,14 +724,16 @@ class LLMAPIHandlerFactory: ) raise LLMProviderError(llm_key) from e - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=response.model_dump_json(indent=2).encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) + llm_response_json = response.model_dump_json(indent=2) + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=llm_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE, + step=step, + task_v2=task_v2, + thought=thought, + ai_suggestion=ai_suggestion, + ) prompt_tokens = 0 completion_tokens = 0 @@ -912,7 +946,8 @@ class LLMCaller: active_parameters.update(self.llm_config.litellm_params) # type: ignore context = skyvern_context.current() - if context and len(context.hashed_href_map) > 0: + is_speculative_step = step.is_speculative if step else False + if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step: await app.ARTIFACT_MANAGER.create_llm_artifact( data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"), artifact_type=ArtifactType.HASHED_HREF_MAP, @@ -939,7 +974,8 @@ class LLMCaller: tool["display_width_px"] = target_dimension["width"] screenshots = resize_screenshots(screenshots, target_dimension) - if prompt: + llm_prompt_value = prompt or "" + if prompt and step and not is_speculative_step: await app.ARTIFACT_MANAGER.create_llm_artifact( data=prompt.encode("utf-8"), artifact_type=ArtifactType.LLM_PROMPT, @@ -971,21 +1007,22 @@ class LLMCaller: screenshots, message_pattern=message_pattern, ) - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=json.dumps( - { - "model": self.llm_config.model_name, - "messages": messages, - # we're not using active_parameters here because it may contain sensitive information - **parameters, - } - ).encode("utf-8"), - artifact_type=ArtifactType.LLM_REQUEST, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) + llm_request_payload = { + "model": self.llm_config.model_name, + "messages": messages, + # we're not using active_parameters here because it may contain sensitive information + **parameters, + } + llm_request_json = json.dumps(llm_request_payload) + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=llm_request_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_REQUEST, + step=step, + task_v2=task_v2, + thought=thought, + ai_suggestion=ai_suggestion, + ) t_llm_request = time.perf_counter() try: response = await self._dispatch_llm_call( @@ -1019,17 +1056,19 @@ class LLMCaller: LOG.exception("LLM request failed unexpectedly", llm_key=self.llm_key) raise LLMProviderError(self.llm_key) from e - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=response.model_dump_json(indent=2).encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) + llm_response_json = response.model_dump_json(indent=2) + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=llm_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE, + step=step, + task_v2=task_v2, + thought=thought, + ai_suggestion=ai_suggestion, + ) call_stats = await self.get_call_stats(response) - if step: + if step and not is_speculative_step: 
await app.DATABASE.update_step( task_id=step.task_id, step_id=step.step_id, @@ -1051,6 +1090,34 @@ class LLMCaller: thought_cost=call_stats.llm_cost, ) + parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix) + parsed_response_json = json.dumps(parsed_response, indent=2) + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=parsed_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_PARSED, + step=step, + task_v2=task_v2, + thought=thought, + ai_suggestion=ai_suggestion, + ) + + rendered_response_json = None + if context and len(context.hashed_href_map) > 0: + llm_content = json.dumps(parsed_response) + rendered_content = Template(llm_content).render(context.hashed_href_map) + parsed_response = json.loads(rendered_content) + rendered_response_json = json.dumps(parsed_response, indent=2) + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=rendered_response_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, + step=step, + task_v2=task_v2, + thought=thought, + ai_suggestion=ai_suggestion, + ) + organization_id = organization_id or ( step.organization_id if step else (thought.organization_id if thought else None) ) @@ -1071,32 +1138,27 @@ class LLMCaller: cached_tokens=call_stats.cached_tokens if call_stats and call_stats.cached_tokens else None, llm_cost=call_stats.llm_cost if call_stats and call_stats.llm_cost else None, ) + + if step and is_speculative_step: + step.speculative_llm_metadata = SpeculativeLLMMetadata( + prompt=llm_prompt_value, + llm_request_json=llm_request_json, + llm_response_json=llm_response_json, + parsed_response_json=parsed_response_json, + rendered_response_json=rendered_response_json, + llm_key=self.llm_key, + model=self.llm_config.model_name, + duration_seconds=duration_seconds, + input_tokens=call_stats.input_tokens, + output_tokens=call_stats.output_tokens, + reasoning_tokens=call_stats.reasoning_tokens, + cached_tokens=call_stats.cached_tokens, + llm_cost=call_stats.llm_cost, + ) + if raw_response: return response.model_dump(exclude_none=True) - parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix) - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=json.dumps(parsed_response, indent=2).encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_PARSED, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) - - if context and len(context.hashed_href_map) > 0: - llm_content = json.dumps(parsed_response) - rendered_content = Template(llm_content).render(context.hashed_href_map) - parsed_response = json.loads(rendered_content) - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=json.dumps(parsed_response, indent=2).encode("utf-8"), - artifact_type=ArtifactType.LLM_RESPONSE_RENDERED, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) - return parsed_response def get_screenshot_resize_target_dimension(self, window_dimension: Resolution | None) -> Resolution: diff --git a/skyvern/forge/sdk/core/skyvern_context.py b/skyvern/forge/sdk/core/skyvern_context.py index 6a4ec557..40256ded 100644 --- a/skyvern/forge/sdk/core/skyvern_context.py +++ b/skyvern/forge/sdk/core/skyvern_context.py @@ -62,6 +62,7 @@ class SkyvernContext: # parallel verification optimization # stores pre-scraped data for next step to avoid re-scraping next_step_pre_scraped_data: dict[str, Any] | None = None + speculative_plans: dict[str, Any] 
= field(default_factory=dict) """ Example output value: diff --git a/skyvern/forge/sdk/models.py b/skyvern/forge/sdk/models.py index e792c779..de0730ed 100644 --- a/skyvern/forge/sdk/models.py +++ b/skyvern/forge/sdk/models.py @@ -39,6 +39,22 @@ class StepStatus(StrEnum): return self in status_is_terminal +class SpeculativeLLMMetadata(BaseModel): + prompt: str + llm_request_json: str + llm_response_json: str | None = None + parsed_response_json: str | None = None + rendered_response_json: str | None = None + llm_key: str | None = None + model: str | None = None + duration_seconds: float | None = None + input_tokens: int | None = None + output_tokens: int | None = None + reasoning_tokens: int | None = None + cached_tokens: int | None = None + llm_cost: float | None = None + + class Step(BaseModel): created_at: datetime modified_at: datetime @@ -55,6 +71,9 @@ class Step(BaseModel): reasoning_token_count: int | None = None cached_token_count: int | None = None step_cost: float = 0 + is_speculative: bool = False + speculative_original_status: StepStatus | None = None + speculative_llm_metadata: SpeculativeLLMMetadata | None = None def validate_update( self, @@ -64,7 +83,7 @@ class Step(BaseModel): ) -> None: old_status = self.status - if status and not old_status.can_update_to(status): + if status and status != old_status and not old_status.can_update_to(status): raise ValueError(f"invalid_status_transition({old_status},{status},{self.step_id})") if status == StepStatus.canceled: @@ -83,6 +102,7 @@ class Step(BaseModel): old_status not in [StepStatus.running, StepStatus.created] and self.output is not None and output is not None + and not (status == old_status == StepStatus.completed) ): raise ValueError(f"cant_override_output({self.step_id})")
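
Notes on the pattern (illustrative sketches; any name not taken from the diff above is invented):

The heart of this change is running the user-goal verification LLM call and a speculative extract-actions plan for the next step concurrently instead of sequentially. A minimal, self-contained sketch of that pattern, with verify_goal, plan_next_step, and Plan standing in for the Skyvern internals:

    import asyncio
    from dataclasses import dataclass


    @dataclass
    class Plan:
        prompt: str
        llm_response: dict | None


    async def verify_goal() -> bool:
        await asyncio.sleep(0.1)  # stands in for the goal-check LLM call
        return False              # False: goal not met yet, keep going


    async def plan_next_step() -> Plan:
        await asyncio.sleep(0.2)  # stands in for scrape + extract-actions LLM call
        return Plan(prompt="extract-actions ...", llm_response={"actions": []})


    async def run_step(speculative_plans: dict[str, Plan]) -> None:
        verification = asyncio.create_task(verify_goal(), name="verify_goal")
        speculation = asyncio.create_task(plan_next_step(), name="speculate_next_step")

        if await verification:
            speculation.cancel()  # goal already complete: the plan is not needed
            return

        try:
            # Goal not met: adopt the plan so the next step skips its own
            # scrape + extract-actions round trip.
            speculative_plans["next_step"] = await speculation
        except Exception:
            pass  # on failure the next step falls back to the sequential path


    asyncio.run(run_step({}))

If the goal turns out to be met, the speculative work is wasted but harmless; if it is not, the next step starts with a ready-made plan, which is where the per-step latency saving comes from.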
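Because the speculative call runs against a step that may never execute, the patch keeps the LLM handlers from writing artifacts or charging the step while is_speculative is set, and buffers everything in SpeculativeLLMMetadata for a later flush. A toy version of that deferred-persistence idea, using a plain dict in place of the artifact manager and database (field names loosely mirror the patch; the rest is invented):

    from dataclasses import dataclass


    @dataclass
    class BufferedLLMMetadata:
        prompt: str
        response_json: str | None = None
        cost: float | None = None


    @dataclass
    class Step:
        step_id: str
        is_speculative: bool = False
        buffered: BufferedLLMMetadata | None = None
        step_cost: float = 0.0


    def record_llm_call(step: Step, prompt: str, response_json: str, cost: float, store: dict) -> None:
        if step.is_speculative:
            # Buffer only; nothing is persisted and the step is not charged yet.
            step.buffered = BufferedLLMMetadata(prompt=prompt, response_json=response_json, cost=cost)
        else:
            store[step.step_id] = {"prompt": prompt, "response": response_json}
            step.step_cost += cost


    def flush_buffered(step: Step, store: dict) -> None:
        # Called once the speculative plan is actually adopted by the next step.
        if step.buffered is None:
            return
        store[step.step_id] = {"prompt": step.buffered.prompt, "response": step.buffered.response_json}
        if step.buffered.cost:
            step.step_cost += step.buffered.cost
        step.buffered = None


    store: dict = {}
    step = Step(step_id="step_1", is_speculative=True)
    record_llm_call(step, "extract-actions ...", "{}", 0.001, store)
    assert store == {} and step.step_cost == 0.0  # untouched while speculative
    flush_buffered(step, store)
    assert "step_1" in store and step.step_cost == 0.001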
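When verification decides the goal is already complete, the plan is discarded, but the tokens and cost it consumed still have to be recorded; _persist_speculative_metadata_for_discarded_plan does this in a background task and awaits asyncio.shield(...) so the underlying work is not cancelled along with its caller. A reduced sketch of that bookkeeping, with speculative_llm_call and persist_cost as stand-ins:

    import asyncio


    async def speculative_llm_call() -> dict:
        await asyncio.sleep(0.1)
        return {"cost": 0.002, "input_tokens": 1200, "output_tokens": 90}


    async def persist_cost(usage: dict) -> None:
        print("persisting", usage)  # stands in for the incremental update_step write


    async def persist_discarded(task: asyncio.Task) -> None:
        try:
            # shield() keeps the underlying LLM task alive even if this
            # bookkeeping coroutine is itself cancelled.
            usage = await asyncio.shield(task)
        except (asyncio.CancelledError, Exception):
            return  # the call never finished, so there is nothing to record
        await persist_cost(usage)


    async def main() -> None:
        task = asyncio.create_task(speculative_llm_call())
        # Verification said the goal is complete: discard the plan, but still
        # persist the tokens/cost it already consumed, in the background.
        bookkeeper = asyncio.create_task(persist_discarded(task))
        await bookkeeper


    asyncio.run(main())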
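Finally, the models.py change appears to relax validate_update so that re-asserting an unchanged status (for example completed -> completed when the parallel-verification path finalizes a step the main loop already marked completed) is treated as an idempotent no-op rather than an invalid transition. A minimal model of the patched condition; this StepStatus and can_update_to are deliberately simplified, not the real enum:

    from enum import Enum


    class StepStatus(str, Enum):
        created = "created"
        running = "running"
        completed = "completed"

        def can_update_to(self, status: "StepStatus") -> bool:
            # Terminal states normally accept no further transitions.
            return self is not StepStatus.completed


    def validate_update(old: StepStatus, new: StepStatus | None) -> None:
        # The patched condition: same-status updates bypass the transition check.
        if new and new != old and not old.can_update_to(new):
            raise ValueError(f"invalid_status_transition({old},{new})")


    validate_update(StepStatus.completed, StepStatus.completed)  # now a no-op
    try:
        validate_update(StepStatus.completed, StepStatus.running)
    except ValueError as exc:
        print(exc)  # still rejected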