diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index e37e0dc6..05e18d29 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -1,5 +1,6 @@ import asyncio import base64 +import hashlib import json import os import random @@ -2467,8 +2468,35 @@ class ForgeAgent: return scraped_page, extract_action_prompt, use_caching + @staticmethod + def _build_extract_action_cache_variant( + verification_code_check: bool, + has_magic_link_page: bool, + complete_criterion: str | None, + ) -> str: + """ + Build a short-but-unique cache variant identifier so extract-action prompts that + differ meaningfully (OTP, magic link flows, complete criteria) do not reuse the + same Vertex cache object. + """ + variant_parts: list[str] = [] + if verification_code_check: + variant_parts.append("vc") + if has_magic_link_page: + variant_parts.append("ml") + if complete_criterion: + normalized = " ".join(complete_criterion.split()) + digest = hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:6] + variant_parts.append(f"cc{digest}") + return "-".join(variant_parts) if variant_parts else "std" + async def _create_vertex_cache_for_task( - self, task: Task, static_prompt: str, context: SkyvernContext, llm_key_override: str | None + self, + task: Task, + static_prompt: str, + context: SkyvernContext, + llm_key_override: str | None, + prompt_variant: str | None = None, ) -> None: """ Create a Vertex AI cache for the task's static prompt. @@ -2479,9 +2507,9 @@ class ForgeAgent: task: The task to create cache for static_prompt: The static prompt content to cache context: The Skyvern context to store the cache name in + llm_key_override: Optional override when we explicitly pick an LLM key + prompt_variant: Cache variant identifier (std/vc/ml/etc.) 
""" - # Early return if task doesn't have an llm_key - # This should not happen given the guard at the call site, but being defensive resolved_llm_key = llm_key_override or task.llm_key if not resolved_llm_key: @@ -2491,17 +2519,20 @@ class ForgeAgent: ) return + cache_variant = prompt_variant or "std" + try: LOG.info( "Attempting Vertex AI cache creation", task_id=task.task_id, llm_key=resolved_llm_key, + cache_variant=cache_variant, ) cache_manager = get_cache_manager() - # Use llm_key as cache_key so all tasks with the same model share the same cache - # This maximizes cache reuse and reduces cache storage costs - cache_key = f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}-{resolved_llm_key}" + variant_suffix = f"-{cache_variant}" if cache_variant else "" + + cache_key = f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}{variant_suffix}-{resolved_llm_key}" # Get the actual model name from LLM config to ensure correct format # (e.g., "gemini-2.5-flash" with decimal, not "gemini-2-5-flash") @@ -2565,8 +2596,10 @@ class ForgeAgent: ttl_seconds=3600, # 1 hour ) - # Store cache resource name in context + # Store cache metadata in context context.vertex_cache_name = cache_data["name"] + context.vertex_cache_key = cache_key + context.vertex_cache_variant = cache_variant LOG.info( "Created Vertex AI cache for task", @@ -2574,6 +2607,7 @@ class ForgeAgent: cache_key=cache_key, cache_name=cache_data["name"], model_name=model_name, + cache_variant=cache_variant, ) except Exception as e: LOG.warning( @@ -2653,7 +2687,7 @@ class ForgeAgent: # Check if prompt caching is enabled for extract-action use_caching = False - prompt_caching_settings = LLMAPIHandlerFactory._prompt_caching_settings or {} + prompt_caching_settings = await self._get_prompt_caching_settings(context) effective_llm_key = task.llm_key if not effective_llm_key: handler_for_key = LLMAPIHandlerFactory.get_override_llm_api_handler( @@ -2701,6 +2735,11 @@ class ForgeAgent: "parse_select_feature_enabled": 
async def _get_prompt_caching_settings(self, context: SkyvernContext) -> dict[str, bool]:
    """
    Work out which prompts have caching enabled for the current run.

    Resolution order:
      1. An explicit override installed through
         LLMAPIHandlerFactory.set_prompt_caching_settings() (mostly used by
         scripts/tests) always wins.
      2. A result already memoized on ``context`` is reused as-is.
      3. Otherwise the "PROMPT_CACHING_OPTIMIZATION" experiment is evaluated and
         the outcome is stored on ``context`` so later calls skip the lookup.

    An empty dict (caching disabled) is stored on the context when the run or
    organization id is missing, when the experiment lookup raises, or when the
    experiment is off.
    """
    factory_override = LLMAPIHandlerFactory._prompt_caching_settings
    if factory_override is not None:
        return factory_override

    memoized = context.prompt_caching_settings
    if memoized is not None:
        return memoized

    experiment_subject = context.run_id or context.workflow_run_id or context.task_id
    org_id = context.organization_id

    # Record "disabled" up front; it is only replaced if the experiment says otherwise.
    context.prompt_caching_settings = {}

    if not (experiment_subject and org_id):
        return context.prompt_caching_settings

    try:
        is_enabled = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
            "PROMPT_CACHING_OPTIMIZATION",
            experiment_subject,
            properties={"organization_id": org_id},
        )
    except Exception as exc:
        LOG.warning(
            "Failed to evaluate prompt caching experiment; defaulting to disabled",
            distinct_id=experiment_subject,
            organization_id=org_id,
            error=str(exc),
        )
        return context.prompt_caching_settings

    if not is_enabled:
        return context.prompt_caching_settings

    context.prompt_caching_settings = {
        EXTRACT_ACTION_PROMPT_NAME: True,
        EXTRACT_ACTION_TEMPLATE: True,
    }
    LOG.info(
        "Prompt caching optimization enabled",
        distinct_id=experiment_subject,
        organization_id=org_id,
    )
    return context.prompt_caching_settings
async def _log_hashed_href_map_artifacts_if_needed(
    context: SkyvernContext | None,
    step: Step | None,
    task_v2: TaskV2 | None,
    thought: Thought | None,
    ai_suggestion: AISuggestion | None,
    *,
    is_speculative_step: bool,
) -> None:
    """
    Persist the context's hashed-href map as a HASHED_HREF_MAP artifact.

    Nothing is written unless there is a context with a non-empty
    ``hashed_href_map``, a step to attach the artifact to, and the step is
    not speculative.
    """
    if not context or not context.hashed_href_map:
        return
    if not step or is_speculative_step:
        return

    serialized_map = json.dumps(context.hashed_href_map, indent=2).encode("utf-8")
    await app.ARTIFACT_MANAGER.create_llm_artifact(
        data=serialized_map,
        artifact_type=ArtifactType.HASHED_HREF_MAP,
        step=step,
        task_v2=task_v2,
        thought=thought,
        ai_suggestion=ai_suggestion,
    )


def _log_vertex_cache_hit_if_needed(
    context: SkyvernContext | None,
    prompt_name: str,
    llm_identifier: str,
    cached_tokens: int,
) -> None:
    """
    Emit a "Vertex cache hit" log line for extract-action calls.

    The line is logged only when ``cached_tokens`` is positive, the prompt is
    the extract-action prompt, and the context carries a Vertex cache name.
    """
    if cached_tokens <= 0:
        return
    if prompt_name != EXTRACT_ACTION_PROMPT_NAME:
        return
    if not context or not context.vertex_cache_name:
        return

    LOG.info(
        "Vertex cache hit",
        prompt_name=prompt_name,
        llm_key=llm_identifier,
        cached_tokens=cached_tokens,
        cache_name=context.vertex_cache_name,
        cache_key=context.vertex_cache_key,
        cache_variant=context.vertex_cache_variant,
    )
exact model choice and skip experimentation reroutes. return LLMAPIHandlerFactory.get_llm_api_handler(override_llm_key) except Exception: LOG.warning( @@ -320,17 +362,17 @@ class LLMAPIHandlerFactory: context = skyvern_context.current() is_speculative_step = step.is_speculative if step else False - if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"), - artifact_type=ArtifactType.HASHED_HREF_MAP, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) + await _log_hashed_href_map_artifacts_if_needed( + context, + step, + task_v2, + thought, + ai_suggestion, + is_speculative_step=is_speculative_step, + ) llm_prompt_value = prompt + if step and not is_speculative_step: await app.ARTIFACT_MANAGER.create_llm_artifact( data=llm_prompt_value.encode("utf-8"), @@ -343,6 +385,25 @@ class LLMAPIHandlerFactory: # Build messages and apply caching in one step messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix) + async def _log_llm_request_artifact(model_label: str, vertex_cache_attached_flag: bool) -> str: + llm_request_payload = { + "model": model_label, + "messages": messages, + **parameters, + "vertex_cache_attached": vertex_cache_attached_flag, + } + llm_request_json = json.dumps(llm_request_payload) + if step and not is_speculative_step: + await app.ARTIFACT_MANAGER.create_llm_artifact( + data=llm_request_json.encode("utf-8"), + artifact_type=ArtifactType.LLM_REQUEST, + step=step, + task_v2=task_v2, + thought=thought, + ai_suggestion=ai_suggestion, + ) + return llm_request_json + # Inject context caching system message when available try: context_cached_static_prompt = getattr(context, "cached_static_prompt", None) @@ -377,70 +438,96 @@ class LLMAPIHandlerFactory: except Exception as e: LOG.warning("Failed to apply context caching system message", 
error=str(e), exc_info=True) - vertex_cache_attached = False cache_resource_name = getattr(context, "vertex_cache_name", None) + cache_variant = getattr(context, "vertex_cache_variant", None) primary_model_dict = _get_primary_model_dict(router, main_model_group) - - # Add cached_content to primary model's litellm_params (not global parameters) - # This ensures it's only passed to the Gemini primary, not to fallback models. - # By setting it in the model-specific litellm_params, LiteLLM will only include it - # when calling the primary model. When falling back to GPT-5, the fallback model's - # litellm_params won't have cached_content, so it won't be sent. - if ( - cache_resource_name + should_attach_vertex_cache = bool( + cache_resource_name is not None and prompt_name == EXTRACT_ACTION_PROMPT_NAME and getattr(context, "use_prompt_caching", False) + and main_model_group and "gemini" in main_model_group.lower() and primary_model_dict is not None - ): - litellm_params = primary_model_dict.setdefault("litellm_params", {}) - litellm_params["cached_content"] = cache_resource_name - vertex_cache_attached = True + ) + + model_used = main_model_group + llm_request_json = "" + + async def _call_primary_with_vertex_cache( + cache_name: str, + cache_variant_name: str | None, + ) -> tuple[ModelResponse, str, str]: + if primary_model_dict is None: + raise ValueError("Primary router model missing configuration") + litellm_params = copy.deepcopy(primary_model_dict.get("litellm_params") or {}) + if not litellm_params: + raise ValueError("Primary router model missing litellm_params") + active_params = copy.deepcopy(litellm_params) + active_params.update(parameters) + active_params["cached_content"] = cache_name + request_model = active_params.pop("model", primary_model_dict.get("model_name", main_model_group)) LOG.info( - "Adding Vertex AI cache reference to primary model in router", + "Adding Vertex AI cache reference to primary Gemini request", prompt_name=prompt_name, 
primary_model=main_model_group, fallback_model=llm_config.fallback_model_group, + cache_name=cache_name, + cache_key=getattr(context, "vertex_cache_key", None), + cache_variant=cache_variant_name, ) - elif primary_model_dict and "litellm_params" in primary_model_dict: - if primary_model_dict["litellm_params"].pop("cached_content", None): - LOG.info( - "Removed Vertex AI cache reference from primary model in router", - prompt_name=prompt_name, - primary_model=main_model_group, - ) + request_payload_json = await _log_llm_request_artifact(request_model, True) + response = await litellm.acompletion( + model=request_model, + messages=messages, + timeout=settings.LLM_CONFIG_TIMEOUT, + drop_params=True, + **active_params, + ) + return response, request_model, request_payload_json - llm_request_payload = { - "model": llm_key, - "messages": messages, - **parameters, - "vertex_cache_attached": vertex_cache_attached, - } - llm_request_json = json.dumps(llm_request_payload) - if step and not is_speculative_step: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=llm_request_json.encode("utf-8"), - artifact_type=ArtifactType.LLM_REQUEST, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) - model_used = main_model_group - try: + async def _call_router_without_cache() -> tuple[ModelResponse, str]: + request_payload_json = await _log_llm_request_artifact(llm_key, False) response = await router.acompletion( - model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters + model=main_model_group, + messages=messages, + timeout=settings.LLM_CONFIG_TIMEOUT, + **parameters, ) - response_model = response.model or main_model_group - model_used = response_model - if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group): - LOG.info( - "LLM router fallback succeeded", - llm_key=llm_key, - prompt_name=prompt_name, - primary_model=main_model_group, - fallback_model=response_model, - ) + return 
response, request_payload_json + + try: + response: ModelResponse | None = None + if should_attach_vertex_cache and cache_resource_name: + try: + response, direct_model_used, llm_request_json = await _call_primary_with_vertex_cache( + cache_resource_name, + cache_variant, + ) + model_used = response.model or direct_model_used + except CancelledError: + raise + except Exception as cache_error: + LOG.warning( + "Vertex cache primary call failed, retrying via router", + prompt_name=prompt_name, + error=str(cache_error), + cache_name=cache_resource_name, + cache_variant=cache_variant, + ) + response = None + + if response is None: + response, llm_request_json = await _call_router_without_cache() + response_model = response.model or main_model_group + model_used = response_model + if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group): + LOG.info( + "LLM router fallback succeeded", + llm_key=llm_key, + prompt_name=prompt_name, + primary_model=main_model_group, + fallback_model=response_model, + ) except litellm.exceptions.APIError as e: raise LLMProviderErrorRetryableTask(llm_key) from e except litellm.exceptions.ContextWindowExceededError as e: @@ -611,7 +698,10 @@ class LLMAPIHandlerFactory: return llm_api_handler_with_router_and_fallback @staticmethod - def get_llm_api_handler(llm_key: str, base_parameters: dict[str, Any] | None = None) -> LLMAPIHandler: + def get_llm_api_handler( + llm_key: str, + base_parameters: dict[str, Any] | None = None, + ) -> LLMAPIHandler: try: llm_config = LLMConfigRegistry.get_config(llm_key) except InvalidLLMConfigError: @@ -668,15 +758,14 @@ class LLMAPIHandlerFactory: context = skyvern_context.current() is_speculative_step = step.is_speculative if step else False - if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"), - 
artifact_type=ArtifactType.HASHED_HREF_MAP, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) + await _log_hashed_href_map_artifacts_if_needed( + context, + step, + task_v2, + thought, + ai_suggestion, + is_speculative_step=is_speculative_step, + ) llm_prompt_value = prompt if step and not is_speculative_step: @@ -746,6 +835,9 @@ class LLMAPIHandlerFactory: "Adding Vertex AI cache reference to request", prompt_name=prompt_name, cache_attached=True, + cache_name=cache_resource_name, + cache_key=getattr(context, "vertex_cache_key", None), + cache_variant=getattr(context, "vertex_cache_variant", None), ) elif "cached_content" in active_parameters: removed_cache = active_parameters.pop("cached_content", None) @@ -754,6 +846,9 @@ class LLMAPIHandlerFactory: "Removed Vertex AI cache reference from request", prompt_name=prompt_name, cache_was_attached=True, + cache_name=cache_resource_name, + cache_key=getattr(context, "vertex_cache_key", None), + cache_variant=getattr(context, "vertex_cache_variant", None), ) llm_request_payload = { @@ -863,6 +958,8 @@ class LLMAPIHandlerFactory: if cached_tokens == 0: cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0 + _log_vertex_cache_hit_if_needed(context, prompt_name, model_name, cached_tokens) + if step: await app.DATABASE.update_step( task_id=step.task_id, @@ -1041,15 +1138,14 @@ class LLMCaller: context = skyvern_context.current() is_speculative_step = step.is_speculative if step else False - if context and len(context.hashed_href_map) > 0 and step and not is_speculative_step: - await app.ARTIFACT_MANAGER.create_llm_artifact( - data=json.dumps(context.hashed_href_map, indent=2).encode("utf-8"), - artifact_type=ArtifactType.HASHED_HREF_MAP, - step=step, - task_v2=task_v2, - thought=thought, - ai_suggestion=ai_suggestion, - ) + await _log_hashed_href_map_artifacts_if_needed( + context, + step, + task_v2, + thought, + ai_suggestion, + 
is_speculative_step=is_speculative_step, + ) if screenshots and self.screenshot_scaling_enabled: target_dimension = self.get_screenshot_resize_target_dimension(window_dimension) diff --git a/skyvern/forge/sdk/core/skyvern_context.py b/skyvern/forge/sdk/core/skyvern_context.py index 40256ded..8f3edc47 100644 --- a/skyvern/forge/sdk/core/skyvern_context.py +++ b/skyvern/forge/sdk/core/skyvern_context.py @@ -37,6 +37,9 @@ class SkyvernContext: use_prompt_caching: bool = False cached_static_prompt: str | None = None vertex_cache_name: str | None = None # Vertex AI cache resource name for explicit caching + vertex_cache_key: str | None = None # Logical cache key (includes variant + llm key) + vertex_cache_variant: str | None = None # Variant identifier used when creating the cache + prompt_caching_settings: dict[str, bool] | None = None enable_speed_optimizations: bool = False # script run context diff --git a/skyvern/forge/sdk/trace/experiment_utils.py b/skyvern/forge/sdk/trace/experiment_utils.py index d4ad6d61..5b6ad72d 100644 --- a/skyvern/forge/sdk/trace/experiment_utils.py +++ b/skyvern/forge/sdk/trace/experiment_utils.py @@ -66,7 +66,7 @@ async def collect_experiment_metadata( "LLM_NAME", "LLM_SECONDARY_NAME", # Add more experiment flags as needed - "PROMPT_CACHING_ENABLED", + "PROMPT_CACHING_OPTIMIZATION", "THINKING_BUDGET_OPTIMIZATION", ]