From 44528cbd380348e6a72490b0b4208ef3d6466271 Mon Sep 17 00:00:00 2001 From: pedrohsdb Date: Thu, 6 Nov 2025 14:47:58 -0800 Subject: [PATCH] Pedro/fix explicit caching vertex api (#3933) --- skyvern/forge/agent.py | 114 ++++++++++ .../forge/sdk/api/llm/api_handler_factory.py | 62 ++---- .../forge/sdk/api/llm/vertex_cache_manager.py | 203 ++++++++++++++++++ skyvern/forge/sdk/core/skyvern_context.py | 1 + 4 files changed, 340 insertions(+), 40 deletions(-) create mode 100644 skyvern/forge/sdk/api/llm/vertex_cache_manager.py diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 34a9795b..fd35a557 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -3,6 +3,7 @@ import base64 import json import os import random +import re import string from asyncio.exceptions import CancelledError from datetime import UTC, datetime @@ -71,11 +72,14 @@ from skyvern.forge.sdk.api.files import ( wait_for_download_finished, ) from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager +from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry from skyvern.forge.sdk.api.llm.exceptions import LLM_PROVIDER_ERROR_RETRYABLE_TASK_TYPE, LLM_PROVIDER_ERROR_TYPE from skyvern.forge.sdk.api.llm.ui_tars_llm_caller import UITarsLLMCaller +from skyvern.forge.sdk.api.llm.vertex_cache_manager import get_cache_manager from skyvern.forge.sdk.artifact.models import ArtifactType from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature +from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.db.enums import TaskType from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs from skyvern.forge.sdk.models import Step, StepStatus @@ -2149,6 +2153,108 @@ class ForgeAgent: return scraped_page, extract_action_prompt, use_caching + async def _create_vertex_cache_for_task(self, task: Task, static_prompt: str, context: SkyvernContext) -> None: + """ + Create a Vertex AI cache for the task's static prompt. + + Uses llm_key as cache key to enable cache sharing across tasks with the same model. + + Args: + task: The task to create cache for + static_prompt: The static prompt content to cache + context: The Skyvern context to store the cache name in + """ + # Early return if task doesn't have an llm_key + # This should not happen given the guard at the call site, but being defensive + if not task.llm_key: + LOG.warning( + "Cannot create Vertex AI cache without llm_key, skipping cache creation", + task_id=task.task_id, + ) + return + + try: + cache_manager = get_cache_manager() + + # Use llm_key as cache_key so all tasks with the same model share the same cache + # This maximizes cache reuse and reduces cache storage costs + cache_key = f"extract-action-static-{task.llm_key}" + + # Get the actual model name from LLM config to ensure correct format + # (e.g., "gemini-2.5-flash" with decimal, not "gemini-2-5-flash") + model_name = "gemini-2.5-flash" # Default + + try: + llm_config = LLMConfigRegistry.get_config(task.llm_key) + extracted_name = None + + # Try to extract from model_name if it contains "vertex_ai/" or starts with "gemini-" + if hasattr(llm_config, "model_name") and isinstance(llm_config.model_name, str): + if "vertex_ai/" in llm_config.model_name: + # Direct Vertex config: "vertex_ai/gemini-2.5-flash" -> "gemini-2.5-flash" + extracted_name = llm_config.model_name.split("/")[-1] + elif llm_config.model_name.startswith("gemini-"): + # Already in correct format + extracted_name = llm_config.model_name + + # For router/fallback configs, extract from api_base or infer from key name + if not extracted_name and hasattr(llm_config, "litellm_params") and llm_config.litellm_params: + params = llm_config.litellm_params + api_base = getattr(params, "api_base", None) + if api_base and isinstance(api_base, str) and "/models/" in api_base: + # Extract from URL: .../models/gemini-2.5-flash -> "gemini-2.5-flash" + extracted_name = api_base.split("/models/")[-1] + + # For router configs without api_base, infer from the llm_key itself + if not extracted_name: + # Extract version from llm_key (e.g., VERTEX_GEMINI_1_5_FLASH -> "1_5" or VERTEX_GEMINI_2.5_FLASH -> "2.5") + # Pattern: GEMINI_{version}_{flavor} where version can use dots, underscores, or dashes + version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", task.llm_key, re.IGNORECASE) + version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5" + + # Determine flavor + if "_PRO_" in task.llm_key or task.llm_key.endswith("_PRO"): + extracted_name = f"gemini-{version}-pro" + elif "_FLASH_LITE_" in task.llm_key or task.llm_key.endswith("_FLASH_LITE"): + extracted_name = f"gemini-{version}-flash-lite" + else: + # Default to flash flavor + extracted_name = f"gemini-{version}-flash" + + if extracted_name: + model_name = extracted_name + except Exception as e: + LOG.debug("Failed to extract model name from config, using default", error=str(e)) + + # Create cache for this task + # Use asyncio.to_thread to offload blocking HTTP request (requests.post) + # This prevents freezing the event loop during cache creation + cache_data = await asyncio.to_thread( + cache_manager.create_cache, + model_name=model_name, + static_content=static_prompt, + cache_key=cache_key, + ttl_seconds=3600, # 1 hour + ) + + # Store cache resource name in context + context.vertex_cache_name = cache_data["name"] + + LOG.info( + "Created Vertex AI cache for task", + task_id=task.task_id, + cache_key=cache_key, + cache_name=cache_data["name"], + model_name=model_name, + ) + except Exception as e: + LOG.warning( + "Failed to create Vertex AI cache, proceeding without caching", + task_id=task.task_id, + error=str(e), + exc_info=True, + ) + async def _build_extract_action_prompt( self, task: Task, @@ -2243,6 +2349,11 @@ class ForgeAgent: # Store static prompt for caching and return dynamic prompt context.cached_static_prompt = static_prompt use_caching = True + + # Create Vertex AI cache for Gemini models + if task.llm_key and "GEMINI" in task.llm_key: + await self._create_vertex_cache_for_task(task, static_prompt, context) + LOG.info("Using cached prompt for extract-action", task_id=task.task_id) return dynamic_prompt, use_caching @@ -2524,6 +2635,9 @@ class ForgeAgent: raise TaskNotFound(task_id=task.task_id) from e task = refreshed_task + # Caches expire based on TTL (1 hour) or can be cleaned up via scheduled job + # This allows multiple tasks with the same llm_key to share the same cache + # log the task status as an event analytics.capture("skyvern-oss-agent-task-status", {"status": task.status}) diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index 027777aa..8860a2e2 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -287,27 +287,8 @@ class LLMAPIHandlerFactory: and isinstance(llm_config, LLMConfig) and isinstance(llm_config.model_name, str) ): - # Check if this is a Vertex AI model - if "vertex_ai/" in llm_config.model_name: - caching_system_message = { - "role": "system", - "content": [ - { - "type": "text", - "text": context_cached_static_prompt, - "cache_control": {"type": "ephemeral", "ttl": "3600s"}, - } - ], - } - messages = [caching_system_message] + messages - LOG.info( - "Applied Vertex context caching", - prompt_name=prompt_name, - model=llm_config.model_name, - ttl_seconds=3600, - ) # Check if this is an OpenAI model - elif ( + if ( llm_config.model_name.startswith("gpt-") or llm_config.model_name.startswith("o1-") or llm_config.model_name.startswith("o3-") @@ -582,27 +563,8 @@ class LLMAPIHandlerFactory: and isinstance(llm_config, LLMConfig) and isinstance(llm_config.model_name, str) ): - # Check if this is a Vertex AI model - if "vertex_ai/" in llm_config.model_name: - caching_system_message = { - "role": "system", - "content": [ - { - "type": "text", - "text": context_cached_static_prompt, - "cache_control": {"type": "ephemeral", "ttl": "3600s"}, - } - ], - } - messages = [caching_system_message] + messages - LOG.info( - "Applied Vertex context caching", - prompt_name=prompt_name, - model=llm_config.model_name, - ttl_seconds=3600, - ) # Check if this is an OpenAI model - elif ( + if ( llm_config.model_name.startswith("gpt-") or llm_config.model_name.startswith("o1-") or llm_config.model_name.startswith("o3-") @@ -626,6 +588,24 @@ class LLMAPIHandlerFactory: ) except Exception as e: LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True) + + # Add Vertex AI cache reference only for the intended cached prompt + vertex_cache_attached = False + cache_resource_name = getattr(context, "vertex_cache_name", None) + if ( + cache_resource_name + and "vertex_ai/" in model_name + and prompt_name == "extract-actions" + and getattr(context, "use_prompt_caching", False) + ): + active_parameters["cached_content"] = cache_resource_name + vertex_cache_attached = True + LOG.info( + "Adding Vertex AI cache reference to request", + prompt_name=prompt_name, + cache_attached=True, + ) + await app.ARTIFACT_MANAGER.create_llm_artifact( data=json.dumps( { @@ -633,6 +613,7 @@ class LLMAPIHandlerFactory: "messages": messages, # we're not using active_parameters here because it may contain sensitive information **parameters, + "vertex_cache_attached": vertex_cache_attached, } ).encode("utf-8"), artifact_type=ArtifactType.LLM_REQUEST, @@ -641,6 +622,7 @@ class LLMAPIHandlerFactory: thought=thought, ai_suggestion=ai_suggestion, ) + t_llm_request = time.perf_counter() try: # TODO (kerem): add a timeout to this call diff --git a/skyvern/forge/sdk/api/llm/vertex_cache_manager.py b/skyvern/forge/sdk/api/llm/vertex_cache_manager.py new file mode 100644 index 00000000..5f97c3b1 --- /dev/null +++ b/skyvern/forge/sdk/api/llm/vertex_cache_manager.py @@ -0,0 +1,203 @@ +""" +Vertex AI Context Caching Manager. + +This module implements the CORRECT caching pattern for Vertex AI using the /cachedContents API. +Unlike the Anthropic-style cache_control markers, Vertex AI requires: +1. Creating a cache object via POST to /cachedContents +2. Getting the cache resource name +3. Referencing that cache name in subsequent requests +""" + +from datetime import datetime +from typing import Any + +import google.auth +import requests +import structlog +from google.auth.transport.requests import Request + +from skyvern.config import settings + +LOG = structlog.get_logger() + + +class VertexCacheManager: + """ + Manages Vertex AI context caching using the correct /cachedContents API. + + This provides guaranteed cache hits for static content across requests, + unlike implicit caching which requires exact prompt matches. + """ + + def __init__(self, project_id: str, location: str = "global"): + self.project_id = project_id + self.location = location + # Use regional endpoint for non-global locations, global endpoint for global + if location == "global": + self.api_endpoint = "aiplatform.googleapis.com" + else: + self.api_endpoint = f"{location}-aiplatform.googleapis.com" + self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data + + def _get_access_token(self) -> str: + """Get Google Cloud access token for API calls.""" + try: + # Try to use default credentials + credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"]) + credentials.refresh(Request()) + return credentials.token + except Exception as e: + LOG.error("Failed to get access token", error=str(e)) + raise + + def create_cache( + self, + model_name: str, + static_content: str, + cache_key: str, + ttl_seconds: int = 3600, + system_instruction: str | None = None, + ) -> dict[str, Any]: + """ + Create a cache object using Vertex AI's /cachedContents API. + + Args: + model_name: Full model path (e.g., "gemini-2.5-flash") + static_content: The static content to cache + cache_key: Unique key to identify this cache (e.g., f"task_{task_id}") + ttl_seconds: Time to live in seconds (default: 1 hour) + system_instruction: Optional system instruction to include + + Returns: + Cache data with 'name', 'expireTime', etc. + """ + # Check if cache already exists for this key + if cache_key in self._cache_registry: + cache_data = self._cache_registry[cache_key] + # Check if still valid + expire_time = datetime.fromisoformat(cache_data["expireTime"].replace("Z", "+00:00")) + if expire_time > datetime.now(expire_time.tzinfo): + LOG.info("Reusing existing cache", cache_key=cache_key, cache_name=cache_data["name"]) + return cache_data + else: + LOG.info("Cache expired, creating new one", cache_key=cache_key) + # Clean up expired cache + try: + self.delete_cache(cache_key) + except Exception: + pass # Best effort cleanup + + url = f"https://{self.api_endpoint}/v1/projects/{self.project_id}/locations/{self.location}/cachedContents" + + # Build the model path + full_model_path = f"projects/{self.project_id}/locations/{self.location}/publishers/google/models/{model_name}" + + # Create payload + payload: dict[str, Any] = { + "model": full_model_path, + "contents": [{"role": "user", "parts": [{"text": static_content}]}], + "ttl": f"{ttl_seconds}s", + } + + # Add system instruction if provided + if system_instruction: + payload["systemInstruction"] = {"parts": [{"text": system_instruction}]} + + headers = {"Authorization": f"Bearer {self._get_access_token()}", "Content-Type": "application/json"} + + LOG.info( + "Creating Vertex AI cache object", + cache_key=cache_key, + model=model_name, + content_size=len(static_content), + ttl_seconds=ttl_seconds, + ) + + try: + response = requests.post(url, headers=headers, json=payload, timeout=30) + + if response.status_code != 200: + LOG.error( + "Failed to create cache", + cache_key=cache_key, + status_code=response.status_code, + response=response.text, + ) + raise Exception(f"Cache creation failed: {response.text}") + + cache_data = response.json() + cache_name = cache_data["name"] + + # Store in registry + self._cache_registry[cache_key] = cache_data + + LOG.info( + "Cache created successfully", + cache_key=cache_key, + cache_name=cache_name, + expires_at=cache_data.get("expireTime"), + ) + + return cache_data + + except requests.exceptions.Timeout: + LOG.error("Cache creation timed out", cache_key=cache_key) + raise + except Exception as e: + LOG.error("Cache creation failed", cache_key=cache_key, error=str(e)) + raise + + def delete_cache(self, cache_key: str) -> bool: + """Delete a cache object.""" + cache_data = self._cache_registry.get(cache_key) + if not cache_data: + LOG.warning("Cache not found in registry", cache_key=cache_key) + return False + + cache_name = cache_data["name"] + url = f"https://{self.api_endpoint}/v1/{cache_name}" + + headers = { + "Authorization": f"Bearer {self._get_access_token()}", + } + + LOG.info("Deleting cache", cache_key=cache_key, cache_name=cache_name) + + try: + response = requests.delete(url, headers=headers, timeout=10) + + if response.status_code in (200, 204): + # Remove from registry + del self._cache_registry[cache_key] + LOG.info("Cache deleted successfully", cache_key=cache_key) + return True + else: + LOG.warning( + "Failed to delete cache", + cache_key=cache_key, + status_code=response.status_code, + response=response.text, + ) + return False + except Exception as e: + LOG.error("Cache deletion failed", cache_key=cache_key, error=str(e)) + return False + + +# Global cache manager instance +_global_cache_manager: VertexCacheManager | None = None + + +def get_cache_manager() -> VertexCacheManager: + """Get or create the global cache manager instance.""" + global _global_cache_manager + + if _global_cache_manager is None: + project_id = settings.VERTEX_PROJECT_ID or "skyvern-production" + # Default to "global" to match the model configs in cloud/__init__.py + # Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching) + location = settings.VERTEX_LOCATION or "global" + _global_cache_manager = VertexCacheManager(project_id, location) + LOG.info("Created global cache manager", project_id=project_id, location=location) + + return _global_cache_manager diff --git a/skyvern/forge/sdk/core/skyvern_context.py b/skyvern/forge/sdk/core/skyvern_context.py index b318a2ae..637813cf 100644 --- a/skyvern/forge/sdk/core/skyvern_context.py +++ b/skyvern/forge/sdk/core/skyvern_context.py @@ -36,6 +36,7 @@ class SkyvernContext: enable_parse_select_in_extract: bool = False use_prompt_caching: bool = False cached_static_prompt: str | None = None + vertex_cache_name: str | None = None # Vertex AI cache resource name for explicit caching # script run context script_id: str | None = None