Pedro/fix explicit caching vertex api (#3933)

This commit is contained in:
pedrohsdb
2025-11-06 14:47:58 -08:00
committed by GitHub
parent d2f4e27940
commit 44528cbd38
4 changed files with 340 additions and 40 deletions

View File

@@ -287,27 +287,8 @@ class LLMAPIHandlerFactory:
and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str)
):
# Check if this is a Vertex AI model
if "vertex_ai/" in llm_config.model_name:
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied Vertex context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
ttl_seconds=3600,
)
# Check if this is an OpenAI model
elif (
if (
llm_config.model_name.startswith("gpt-")
or llm_config.model_name.startswith("o1-")
or llm_config.model_name.startswith("o3-")
@@ -582,27 +563,8 @@ class LLMAPIHandlerFactory:
and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str)
):
# Check if this is a Vertex AI model
if "vertex_ai/" in llm_config.model_name:
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied Vertex context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
ttl_seconds=3600,
)
# Check if this is an OpenAI model
elif (
if (
llm_config.model_name.startswith("gpt-")
or llm_config.model_name.startswith("o1-")
or llm_config.model_name.startswith("o3-")
@@ -626,6 +588,24 @@ class LLMAPIHandlerFactory:
)
except Exception as e:
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
# Add Vertex AI cache reference only for the intended cached prompt
vertex_cache_attached = False
cache_resource_name = getattr(context, "vertex_cache_name", None)
if (
cache_resource_name
and "vertex_ai/" in model_name
and prompt_name == "extract-actions"
and getattr(context, "use_prompt_caching", False)
):
active_parameters["cached_content"] = cache_resource_name
vertex_cache_attached = True
LOG.info(
"Adding Vertex AI cache reference to request",
prompt_name=prompt_name,
cache_attached=True,
)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
@@ -633,6 +613,7 @@ class LLMAPIHandlerFactory:
"messages": messages,
# we're not using active_parameters here because it may contain sensitive information
**parameters,
"vertex_cache_attached": vertex_cache_attached,
}
).encode("utf-8"),
artifact_type=ArtifactType.LLM_REQUEST,
@@ -641,6 +622,7 @@ class LLMAPIHandlerFactory:
thought=thought,
ai_suggestion=ai_suggestion,
)
t_llm_request = time.perf_counter()
try:
# TODO (kerem): add a timeout to this call

View File

@@ -0,0 +1,203 @@
"""
Vertex AI Context Caching Manager.
This module implements the CORRECT caching pattern for Vertex AI using the /cachedContents API.
Unlike the Anthropic-style cache_control markers, Vertex AI requires:
1. Creating a cache object via POST to /cachedContents
2. Getting the cache resource name
3. Referencing that cache name in subsequent requests
"""
from datetime import datetime
from typing import Any
import google.auth
import requests
import structlog
from google.auth.transport.requests import Request
from skyvern.config import settings
LOG = structlog.get_logger()
class VertexCacheManager:
    """
    Manages Vertex AI context caching using the correct /cachedContents API.

    This provides guaranteed cache hits for static content across requests,
    unlike implicit caching which requires exact prompt matches.

    NOTE(review): the in-memory registry is per-instance and not synchronized;
    assumes single-process, single-threaded access — confirm against callers.
    """

    def __init__(self, project_id: str, location: str = "global") -> None:
        # project_id: GCP project that owns the cached contents.
        # location: Vertex AI location; "global" selects the global API endpoint.
        self.project_id = project_id
        self.location = location
        # Use regional endpoint for non-global locations, global endpoint for global
        if location == "global":
            self.api_endpoint = "aiplatform.googleapis.com"
        else:
            self.api_endpoint = f"{location}-aiplatform.googleapis.com"
        # Maps cache_key -> cache_data (the raw /cachedContents response, which
        # includes at least "name" and "expireTime").
        self._cache_registry: dict[str, dict[str, Any]] = {}

    def _get_access_token(self) -> str:
        """Get Google Cloud access token for API calls.

        Uses Application Default Credentials with the cloud-platform scope and
        forces a refresh so the returned token is current.

        Raises:
            Exception: re-raises whatever google.auth raised after logging it.
        """
        try:
            # Try to use default credentials
            credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
            credentials.refresh(Request())
            # NOTE(review): credentials.token is typed Optional in google-auth;
            # after a successful refresh it should be set — confirm.
            return credentials.token
        except Exception as e:
            LOG.error("Failed to get access token", error=str(e))
            raise

    def create_cache(
        self,
        model_name: str,
        static_content: str,
        cache_key: str,
        ttl_seconds: int = 3600,
        system_instruction: str | None = None,
    ) -> dict[str, Any]:
        """
        Create a cache object using Vertex AI's /cachedContents API.

        Reuses an existing unexpired cache registered under ``cache_key``;
        otherwise best-effort deletes the expired one and creates a new cache.

        Args:
            model_name: Full model path (e.g., "gemini-2.5-flash")
            static_content: The static content to cache
            cache_key: Unique key to identify this cache (e.g., f"task_{task_id}")
            ttl_seconds: Time to live in seconds (default: 1 hour)
            system_instruction: Optional system instruction to include

        Returns:
            Cache data with 'name', 'expireTime', etc.

        Raises:
            Exception: if the API responds with a non-200 status, or on
                request/auth failure (including requests.exceptions.Timeout).
        """
        # Check if cache already exists for this key
        if cache_key in self._cache_registry:
            cache_data = self._cache_registry[cache_key]
            # Check if still valid. expireTime is RFC 3339 with a trailing "Z";
            # fromisoformat needs an explicit "+00:00" offset.
            expire_time = datetime.fromisoformat(cache_data["expireTime"].replace("Z", "+00:00"))
            if expire_time > datetime.now(expire_time.tzinfo):
                LOG.info("Reusing existing cache", cache_key=cache_key, cache_name=cache_data["name"])
                return cache_data
            else:
                LOG.info("Cache expired, creating new one", cache_key=cache_key)
                # Clean up expired cache
                try:
                    self.delete_cache(cache_key)
                except Exception:
                    pass  # Best effort cleanup

        url = f"https://{self.api_endpoint}/v1/projects/{self.project_id}/locations/{self.location}/cachedContents"
        # Build the model path
        full_model_path = f"projects/{self.project_id}/locations/{self.location}/publishers/google/models/{model_name}"

        # Create payload. Static content is cached as a single user turn; TTL is
        # expressed as a Duration string ("<seconds>s") per the Vertex API.
        payload: dict[str, Any] = {
            "model": full_model_path,
            "contents": [{"role": "user", "parts": [{"text": static_content}]}],
            "ttl": f"{ttl_seconds}s",
        }
        # Add system instruction if provided
        if system_instruction:
            payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}

        headers = {"Authorization": f"Bearer {self._get_access_token()}", "Content-Type": "application/json"}
        LOG.info(
            "Creating Vertex AI cache object",
            cache_key=cache_key,
            model=model_name,
            content_size=len(static_content),
            ttl_seconds=ttl_seconds,
        )
        try:
            response = requests.post(url, headers=headers, json=payload, timeout=30)
            if response.status_code != 200:
                LOG.error(
                    "Failed to create cache",
                    cache_key=cache_key,
                    status_code=response.status_code,
                    response=response.text,
                )
                raise Exception(f"Cache creation failed: {response.text}")

            cache_data = response.json()
            cache_name = cache_data["name"]
            # Store in registry (overwrites any stale entry for this key)
            self._cache_registry[cache_key] = cache_data
            LOG.info(
                "Cache created successfully",
                cache_key=cache_key,
                cache_name=cache_name,
                expires_at=cache_data.get("expireTime"),
            )
            return cache_data
        except requests.exceptions.Timeout:
            LOG.error("Cache creation timed out", cache_key=cache_key)
            raise
        except Exception as e:
            LOG.error("Cache creation failed", cache_key=cache_key, error=str(e))
            raise

    def delete_cache(self, cache_key: str) -> bool:
        """Delete a cache object.

        Looks the cache up in the local registry, issues a DELETE against its
        fully-qualified resource name, and removes the registry entry on
        success.

        Returns:
            True if the server confirmed deletion (200/204); False if the key
            is unknown, the server refused, or the request failed. Never
            raises.
        """
        cache_data = self._cache_registry.get(cache_key)
        if not cache_data:
            LOG.warning("Cache not found in registry", cache_key=cache_key)
            return False

        # cache_name is already a full resource path
        # ("projects/.../locations/.../cachedContents/..."), so it is appended
        # to the versioned endpoint directly.
        cache_name = cache_data["name"]
        url = f"https://{self.api_endpoint}/v1/{cache_name}"
        headers = {
            "Authorization": f"Bearer {self._get_access_token()}",
        }
        LOG.info("Deleting cache", cache_key=cache_key, cache_name=cache_name)
        try:
            response = requests.delete(url, headers=headers, timeout=10)
            if response.status_code in (200, 204):
                # Remove from registry
                del self._cache_registry[cache_key]
                LOG.info("Cache deleted successfully", cache_key=cache_key)
                return True
            else:
                LOG.warning(
                    "Failed to delete cache",
                    cache_key=cache_key,
                    status_code=response.status_code,
                    response=response.text,
                )
                return False
        except Exception as e:
            LOG.error("Cache deletion failed", cache_key=cache_key, error=str(e))
            return False
# Process-wide singleton; created lazily on first use.
_global_cache_manager: VertexCacheManager | None = None


def get_cache_manager() -> VertexCacheManager:
    """Return the shared :class:`VertexCacheManager`, creating it on first call.

    Project and location come from settings, falling back to
    "skyvern-production" / "global" when unset. Not synchronized: concurrent
    first calls could each build a manager (last one wins).
    """
    global _global_cache_manager
    if _global_cache_manager is not None:
        return _global_cache_manager

    resolved_project = settings.VERTEX_PROJECT_ID or "skyvern-production"
    # Default to "global" to match the model configs in cloud/__init__.py.
    # Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching).
    resolved_location = settings.VERTEX_LOCATION or "global"

    _global_cache_manager = VertexCacheManager(resolved_project, resolved_location)
    LOG.info("Created global cache manager", project_id=resolved_project, location=resolved_location)
    return _global_cache_manager