Pedro/fix explicit caching vertex api (#3933)

This commit is contained in:
pedrohsdb
2025-11-06 14:47:58 -08:00
committed by GitHub
parent d2f4e27940
commit 44528cbd38
4 changed files with 340 additions and 40 deletions

View File

@@ -3,6 +3,7 @@ import base64
import json
import os
import random
import re
import string
from asyncio.exceptions import CancelledError
from datetime import UTC, datetime
@@ -71,11 +72,14 @@ from skyvern.forge.sdk.api.files import (
wait_for_download_finished,
)
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.exceptions import LLM_PROVIDER_ERROR_RETRYABLE_TASK_TYPE, LLM_PROVIDER_ERROR_TYPE
from skyvern.forge.sdk.api.llm.ui_tars_llm_caller import UITarsLLMCaller
from skyvern.forge.sdk.api.llm.vertex_cache_manager import get_cache_manager
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs
from skyvern.forge.sdk.models import Step, StepStatus
@@ -2149,6 +2153,108 @@ class ForgeAgent:
return scraped_page, extract_action_prompt, use_caching
async def _create_vertex_cache_for_task(self, task: Task, static_prompt: str, context: SkyvernContext) -> None:
    """
    Create a Vertex AI cache for the task's static prompt.

    Uses llm_key as cache key to enable cache sharing across tasks with the
    same model, which maximizes cache reuse and reduces cache storage costs.
    On success the cache resource name is stored on ``context.vertex_cache_name``;
    on any failure we log and continue without caching (caching is best-effort).

    Args:
        task: The task to create the cache for (must have ``llm_key`` set).
        static_prompt: The static prompt content to cache.
        context: The Skyvern context to store the cache name in.
    """
    # Early return if task doesn't have an llm_key.
    # This should not happen given the guard at the call site, but be defensive.
    if not task.llm_key:
        LOG.warning(
            "Cannot create Vertex AI cache without llm_key, skipping cache creation",
            task_id=task.task_id,
        )
        return

    try:
        cache_manager = get_cache_manager()
        # Use llm_key in the cache_key so all tasks with the same model share one cache.
        cache_key = f"extract-action-static-{task.llm_key}"
        # Resolve the actual Vertex model name (e.g. "gemini-2.5-flash" with a
        # decimal, not "gemini-2-5-flash") from the LLM config / llm_key.
        model_name = self._resolve_gemini_model_name(task.llm_key)

        # Offload the blocking HTTP request (requests.post) to a worker thread
        # so cache creation does not freeze the event loop.
        cache_data = await asyncio.to_thread(
            cache_manager.create_cache,
            model_name=model_name,
            static_content=static_prompt,
            cache_key=cache_key,
            ttl_seconds=3600,  # 1 hour
        )

        # Store the cache resource name so the LLM handler can reference it.
        context.vertex_cache_name = cache_data["name"]
        LOG.info(
            "Created Vertex AI cache for task",
            task_id=task.task_id,
            cache_key=cache_key,
            cache_name=cache_data["name"],
            model_name=model_name,
        )
    except Exception as e:
        # Best-effort: a failed cache creation must never fail the task itself.
        LOG.warning(
            "Failed to create Vertex AI cache, proceeding without caching",
            task_id=task.task_id,
            error=str(e),
            exc_info=True,
        )

@staticmethod
def _resolve_gemini_model_name(llm_key: str) -> str:
    """
    Best-effort resolution of the Vertex model name (e.g. "gemini-2.5-flash")
    for the given llm_key. Falls back to "gemini-2.5-flash" when nothing can
    be extracted from the registered LLM config or the key itself.
    """
    model_name = "gemini-2.5-flash"  # Default when extraction fails entirely
    try:
        llm_config = LLMConfigRegistry.get_config(llm_key)
        extracted_name = None

        # Direct Vertex configs expose the model name, e.g.
        # "vertex_ai/gemini-2.5-flash" -> "gemini-2.5-flash".
        if hasattr(llm_config, "model_name") and isinstance(llm_config.model_name, str):
            if "vertex_ai/" in llm_config.model_name:
                extracted_name = llm_config.model_name.split("/")[-1]
            elif llm_config.model_name.startswith("gemini-"):
                # Already in the correct format.
                extracted_name = llm_config.model_name

        # Router/fallback configs: the model may only appear in the api_base
        # URL, e.g. ".../models/gemini-2.5-flash" -> "gemini-2.5-flash".
        if not extracted_name and hasattr(llm_config, "litellm_params") and llm_config.litellm_params:
            params = llm_config.litellm_params
            api_base = getattr(params, "api_base", None)
            if api_base and isinstance(api_base, str) and "/models/" in api_base:
                extracted_name = api_base.split("/models/")[-1]

        # Last resort: infer version and flavor from the llm_key itself.
        # Pattern: GEMINI_{version}_{flavor} where version may use dots,
        # underscores, or dashes (e.g. VERTEX_GEMINI_1_5_FLASH -> "1.5").
        if not extracted_name:
            version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", llm_key, re.IGNORECASE)
            version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"
            if "_PRO_" in llm_key or llm_key.endswith("_PRO"):
                extracted_name = f"gemini-{version}-pro"
            elif "_FLASH_LITE_" in llm_key or llm_key.endswith("_FLASH_LITE"):
                extracted_name = f"gemini-{version}-flash-lite"
            else:
                # Default to the flash flavor.
                extracted_name = f"gemini-{version}-flash"

        if extracted_name:
            model_name = extracted_name
    except Exception as e:
        LOG.debug("Failed to extract model name from config, using default", error=str(e))
    return model_name
async def _build_extract_action_prompt(
self,
task: Task,
@@ -2243,6 +2349,11 @@ class ForgeAgent:
# Store static prompt for caching and return dynamic prompt
context.cached_static_prompt = static_prompt
use_caching = True
# Create Vertex AI cache for Gemini models
if task.llm_key and "GEMINI" in task.llm_key:
await self._create_vertex_cache_for_task(task, static_prompt, context)
LOG.info("Using cached prompt for extract-action", task_id=task.task_id)
return dynamic_prompt, use_caching
@@ -2524,6 +2635,9 @@ class ForgeAgent:
raise TaskNotFound(task_id=task.task_id) from e
task = refreshed_task
# Caches expire based on TTL (1 hour) or can be cleaned up via scheduled job
# This allows multiple tasks with the same llm_key to share the same cache
# log the task status as an event
analytics.capture("skyvern-oss-agent-task-status", {"status": task.status})

View File

@@ -287,27 +287,8 @@ class LLMAPIHandlerFactory:
and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str)
):
# Check if this is a Vertex AI model
if "vertex_ai/" in llm_config.model_name:
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied Vertex context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
ttl_seconds=3600,
)
# Check if this is an OpenAI model
elif (
if (
llm_config.model_name.startswith("gpt-")
or llm_config.model_name.startswith("o1-")
or llm_config.model_name.startswith("o3-")
@@ -582,27 +563,8 @@ class LLMAPIHandlerFactory:
and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str)
):
# Check if this is a Vertex AI model
if "vertex_ai/" in llm_config.model_name:
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied Vertex context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
ttl_seconds=3600,
)
# Check if this is an OpenAI model
elif (
if (
llm_config.model_name.startswith("gpt-")
or llm_config.model_name.startswith("o1-")
or llm_config.model_name.startswith("o3-")
@@ -626,6 +588,24 @@ class LLMAPIHandlerFactory:
)
except Exception as e:
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
# Add Vertex AI cache reference only for the intended cached prompt
vertex_cache_attached = False
cache_resource_name = getattr(context, "vertex_cache_name", None)
if (
cache_resource_name
and "vertex_ai/" in model_name
and prompt_name == "extract-actions"
and getattr(context, "use_prompt_caching", False)
):
active_parameters["cached_content"] = cache_resource_name
vertex_cache_attached = True
LOG.info(
"Adding Vertex AI cache reference to request",
prompt_name=prompt_name,
cache_attached=True,
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
@@ -633,6 +613,7 @@ class LLMAPIHandlerFactory:
"messages": messages,
# we're not using active_parameters here because it may contain sensitive information
**parameters,
"vertex_cache_attached": vertex_cache_attached,
}
).encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
@@ -641,6 +622,7 @@ class LLMAPIHandlerFactory:
thought=thought,
ai_suggestion=ai_suggestion,
)
t_llm_request = time.perf_counter()
try:
# TODO (kerem): add a timeout to this call

View File

@@ -0,0 +1,203 @@
"""
Vertex AI Context Caching Manager.
This module implements the CORRECT caching pattern for Vertex AI using the /cachedContents API.
Unlike the Anthropic-style cache_control markers, Vertex AI requires:
1. Creating a cache object via POST to /cachedContents
2. Getting the cache resource name
3. Referencing that cache name in subsequent requests
"""
from datetime import datetime
from typing import Any
import google.auth
import requests
import structlog
from google.auth.transport.requests import Request
from skyvern.config import settings
LOG = structlog.get_logger()
class VertexCacheManager:
"""
Manages Vertex AI context caching using the correct /cachedContents API.
This provides guaranteed cache hits for static content across requests,
unlike implicit caching which requires exact prompt matches.
"""
def __init__(self, project_id: str, location: str = "global"):
self.project_id = project_id
self.location = location
# Use regional endpoint for non-global locations, global endpoint for global
if location == "global":
self.api_endpoint = "aiplatform.googleapis.com"
else:
self.api_endpoint = f"{location}-aiplatform.googleapis.com"
self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data
def _get_access_token(self) -> str:
"""Get Google Cloud access token for API calls."""
try:
# Try to use default credentials
credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(Request())
return credentials.token
except Exception as e:
LOG.error("Failed to get access token", error=str(e))
raise
def create_cache(
self,
model_name: str,
static_content: str,
cache_key: str,
ttl_seconds: int = 3600,
system_instruction: str | None = None,
) -> dict[str, Any]:
"""
Create a cache object using Vertex AI's /cachedContents API.
Args:
model_name: Full model path (e.g., "gemini-2.5-flash")
static_content: The static content to cache
cache_key: Unique key to identify this cache (e.g., f"task_{task_id}")
ttl_seconds: Time to live in seconds (default: 1 hour)
system_instruction: Optional system instruction to include
Returns:
Cache data with 'name', 'expireTime', etc.
"""
# Check if cache already exists for this key
if cache_key in self._cache_registry:
cache_data = self._cache_registry[cache_key]
# Check if still valid
expire_time = datetime.fromisoformat(cache_data["expireTime"].replace("Z", "+00:00"))
if expire_time > datetime.now(expire_time.tzinfo):
LOG.info("Reusing existing cache", cache_key=cache_key, cache_name=cache_data["name"])
return cache_data
else:
LOG.info("Cache expired, creating new one", cache_key=cache_key)
# Clean up expired cache
try:
self.delete_cache(cache_key)
except Exception:
pass # Best effort cleanup
url = f"https://{self.api_endpoint}/v1/projects/{self.project_id}/locations/{self.location}/cachedContents"
# Build the model path
full_model_path = f"projects/{self.project_id}/locations/{self.location}/publishers/google/models/{model_name}"
# Create payload
payload: dict[str, Any] = {
"model": full_model_path,
"contents": [{"role": "user", "parts": [{"text": static_content}]}],
"ttl": f"{ttl_seconds}s",
}
# Add system instruction if provided
if system_instruction:
payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
headers = {"Authorization": f"Bearer {self._get_access_token()}", "Content-Type": "application/json"}
LOG.info(
"Creating Vertex AI cache object",
cache_key=cache_key,
model=model_name,
content_size=len(static_content),
ttl_seconds=ttl_seconds,
)
try:
response = requests.post(url, headers=headers, json=payload, timeout=30)
if response.status_code != 200:
LOG.error(
"Failed to create cache",
cache_key=cache_key,
status_code=response.status_code,
response=response.text,
)
raise Exception(f"Cache creation failed: {response.text}")
cache_data = response.json()
cache_name = cache_data["name"]
# Store in registry
self._cache_registry[cache_key] = cache_data
LOG.info(
"Cache created successfully",
cache_key=cache_key,
cache_name=cache_name,
expires_at=cache_data.get("expireTime"),
)
return cache_data
except requests.exceptions.Timeout:
LOG.error("Cache creation timed out", cache_key=cache_key)
raise
except Exception as e:
LOG.error("Cache creation failed", cache_key=cache_key, error=str(e))
raise
def delete_cache(self, cache_key: str) -> bool:
"""Delete a cache object."""
cache_data = self._cache_registry.get(cache_key)
if not cache_data:
LOG.warning("Cache not found in registry", cache_key=cache_key)
return False
cache_name = cache_data["name"]
url = f"https://{self.api_endpoint}/v1/{cache_name}"
headers = {
"Authorization": f"Bearer {self._get_access_token()}",
}
LOG.info("Deleting cache", cache_key=cache_key, cache_name=cache_name)
try:
response = requests.delete(url, headers=headers, timeout=10)
if response.status_code in (200, 204):
# Remove from registry
del self._cache_registry[cache_key]
LOG.info("Cache deleted successfully", cache_key=cache_key)
return True
else:
LOG.warning(
"Failed to delete cache",
cache_key=cache_key,
status_code=response.status_code,
response=response.text,
)
return False
except Exception as e:
LOG.error("Cache deletion failed", cache_key=cache_key, error=str(e))
return False
# Process-wide singleton, created lazily on first use.
_global_cache_manager: VertexCacheManager | None = None


def get_cache_manager() -> VertexCacheManager:
    """Return the global VertexCacheManager, constructing it on first call."""
    global _global_cache_manager
    if _global_cache_manager is not None:
        return _global_cache_manager
    project_id = settings.VERTEX_PROJECT_ID or "skyvern-production"
    # Default to "global" to match the model configs in cloud/__init__.py.
    # Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching).
    location = settings.VERTEX_LOCATION or "global"
    _global_cache_manager = VertexCacheManager(project_id, location)
    LOG.info("Created global cache manager", project_id=project_id, location=location)
    return _global_cache_manager

View File

@@ -36,6 +36,7 @@ class SkyvernContext:
enable_parse_select_in_extract: bool = False
use_prompt_caching: bool = False
cached_static_prompt: str | None = None
vertex_cache_name: str | None = None # Vertex AI cache resource name for explicit caching
# script run context
script_id: str | None = None