Pedro/fix explicit caching vertex api (#3933)
This commit is contained in:
@@ -3,6 +3,7 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
from asyncio.exceptions import CancelledError
|
||||
from datetime import UTC, datetime
|
||||
@@ -71,11 +72,14 @@ from skyvern.forge.sdk.api.files import (
|
||||
wait_for_download_finished,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager
|
||||
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
||||
from skyvern.forge.sdk.api.llm.exceptions import LLM_PROVIDER_ERROR_RETRYABLE_TASK_TYPE, LLM_PROVIDER_ERROR_TYPE
|
||||
from skyvern.forge.sdk.api.llm.ui_tars_llm_caller import UITarsLLMCaller
|
||||
from skyvern.forge.sdk.api.llm.vertex_cache_manager import get_cache_manager
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
@@ -2149,6 +2153,108 @@ class ForgeAgent:
|
||||
|
||||
return scraped_page, extract_action_prompt, use_caching
|
||||
|
||||
async def _create_vertex_cache_for_task(self, task: Task, static_prompt: str, context: SkyvernContext) -> None:
    """
    Create (or reuse) a Vertex AI explicit cache for the task's static prompt.

    Uses llm_key as cache key to enable cache sharing across tasks with the
    same model, maximizing reuse and reducing cache storage costs. On success
    the cache resource name is stored on ``context.vertex_cache_name`` so it
    can be attached to later LLM requests. Any failure is logged and
    swallowed: the task proceeds without caching.

    Args:
        task: The task to create cache for (must carry an ``llm_key``).
        static_prompt: The static prompt content to cache.
        context: The Skyvern context to store the cache name in.
    """
    # Early return if task doesn't have an llm_key.
    # This should not happen given the guard at the call site, but being defensive.
    if not task.llm_key:
        LOG.warning(
            "Cannot create Vertex AI cache without llm_key, skipping cache creation",
            task_id=task.task_id,
        )
        return

    try:
        cache_manager = get_cache_manager()

        # Use llm_key as cache_key so all tasks with the same model share the same cache.
        cache_key = f"extract-action-static-{task.llm_key}"

        # Resolve the actual Gemini model name (e.g. "gemini-2.5-flash" with a
        # decimal, not "gemini-2-5-flash") from the LLM config / llm_key.
        model_name = self._resolve_gemini_model_name(task.llm_key)

        # Offload the blocking HTTP request (requests.post) to a worker thread
        # so cache creation does not freeze the event loop.
        cache_data = await asyncio.to_thread(
            cache_manager.create_cache,
            model_name=model_name,
            static_content=static_prompt,
            cache_key=cache_key,
            ttl_seconds=3600,  # 1 hour
        )

        # Store cache resource name in context for downstream request building.
        context.vertex_cache_name = cache_data["name"]

        LOG.info(
            "Created Vertex AI cache for task",
            task_id=task.task_id,
            cache_key=cache_key,
            cache_name=cache_data["name"],
            model_name=model_name,
        )
    except Exception as e:
        # Caching is an optimization; never fail the task because of it.
        LOG.warning(
            "Failed to create Vertex AI cache, proceeding without caching",
            task_id=task.task_id,
            error=str(e),
            exc_info=True,
        )

@staticmethod
def _resolve_gemini_model_name(llm_key: str) -> str:
    """
    Best-effort resolution of the canonical Gemini model name for ``llm_key``.

    Tries, in order: the registered LLM config's ``model_name`` (direct
    Vertex configs like "vertex_ai/gemini-2.5-flash" or already-bare
    "gemini-..." names), the ``api_base`` URL of router/fallback configs,
    and finally a version/flavor inference from the llm_key itself.
    Falls back to "gemini-2.5-flash" when nothing can be extracted.
    """
    model_name = "gemini-2.5-flash"  # Default
    try:
        llm_config = LLMConfigRegistry.get_config(llm_key)
        extracted_name = None

        # Try to extract from model_name if it contains "vertex_ai/" or starts with "gemini-".
        if hasattr(llm_config, "model_name") and isinstance(llm_config.model_name, str):
            if "vertex_ai/" in llm_config.model_name:
                # Direct Vertex config: "vertex_ai/gemini-2.5-flash" -> "gemini-2.5-flash"
                extracted_name = llm_config.model_name.split("/")[-1]
            elif llm_config.model_name.startswith("gemini-"):
                # Already in correct format
                extracted_name = llm_config.model_name

        # For router/fallback configs, extract from api_base.
        if not extracted_name and hasattr(llm_config, "litellm_params") and llm_config.litellm_params:
            params = llm_config.litellm_params
            api_base = getattr(params, "api_base", None)
            if api_base and isinstance(api_base, str) and "/models/" in api_base:
                # Extract from URL: .../models/gemini-2.5-flash -> "gemini-2.5-flash"
                extracted_name = api_base.split("/models/")[-1]

        # For router configs without api_base, infer from the llm_key itself.
        if not extracted_name:
            # Extract version (e.g. VERTEX_GEMINI_1_5_FLASH -> "1_5", VERTEX_GEMINI_2.5_FLASH -> "2.5").
            # Pattern: GEMINI_{version}_{flavor} where version can use dots, underscores, or dashes.
            version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", llm_key, re.IGNORECASE)
            version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"

            # Determine flavor from the key suffix; default to plain "flash".
            if "_PRO_" in llm_key or llm_key.endswith("_PRO"):
                extracted_name = f"gemini-{version}-pro"
            elif "_FLASH_LITE_" in llm_key or llm_key.endswith("_FLASH_LITE"):
                extracted_name = f"gemini-{version}-flash-lite"
            else:
                extracted_name = f"gemini-{version}-flash"

        if extracted_name:
            model_name = extracted_name
    except Exception as e:
        LOG.debug("Failed to extract model name from config, using default", error=str(e))

    return model_name
|
||||
|
||||
async def _build_extract_action_prompt(
|
||||
self,
|
||||
task: Task,
|
||||
@@ -2243,6 +2349,11 @@ class ForgeAgent:
|
||||
# Store static prompt for caching and return dynamic prompt
|
||||
context.cached_static_prompt = static_prompt
|
||||
use_caching = True
|
||||
|
||||
# Create Vertex AI cache for Gemini models
|
||||
if task.llm_key and "GEMINI" in task.llm_key:
|
||||
await self._create_vertex_cache_for_task(task, static_prompt, context)
|
||||
|
||||
LOG.info("Using cached prompt for extract-action", task_id=task.task_id)
|
||||
return dynamic_prompt, use_caching
|
||||
|
||||
@@ -2524,6 +2635,9 @@ class ForgeAgent:
|
||||
raise TaskNotFound(task_id=task.task_id) from e
|
||||
task = refreshed_task
|
||||
|
||||
# Caches expire based on TTL (1 hour) or can be cleaned up via scheduled job
|
||||
# This allows multiple tasks with the same llm_key to share the same cache
|
||||
|
||||
# log the task status as an event
|
||||
analytics.capture("skyvern-oss-agent-task-status", {"status": task.status})
|
||||
|
||||
|
||||
@@ -287,27 +287,8 @@ class LLMAPIHandlerFactory:
|
||||
and isinstance(llm_config, LLMConfig)
|
||||
and isinstance(llm_config.model_name, str)
|
||||
):
|
||||
# Check if this is a Vertex AI model
|
||||
if "vertex_ai/" in llm_config.model_name:
|
||||
caching_system_message = {
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": context_cached_static_prompt,
|
||||
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
|
||||
}
|
||||
],
|
||||
}
|
||||
messages = [caching_system_message] + messages
|
||||
LOG.info(
|
||||
"Applied Vertex context caching",
|
||||
prompt_name=prompt_name,
|
||||
model=llm_config.model_name,
|
||||
ttl_seconds=3600,
|
||||
)
|
||||
# Check if this is an OpenAI model
|
||||
elif (
|
||||
if (
|
||||
llm_config.model_name.startswith("gpt-")
|
||||
or llm_config.model_name.startswith("o1-")
|
||||
or llm_config.model_name.startswith("o3-")
|
||||
@@ -582,27 +563,8 @@ class LLMAPIHandlerFactory:
|
||||
and isinstance(llm_config, LLMConfig)
|
||||
and isinstance(llm_config.model_name, str)
|
||||
):
|
||||
# Check if this is a Vertex AI model
|
||||
if "vertex_ai/" in llm_config.model_name:
|
||||
caching_system_message = {
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": context_cached_static_prompt,
|
||||
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
|
||||
}
|
||||
],
|
||||
}
|
||||
messages = [caching_system_message] + messages
|
||||
LOG.info(
|
||||
"Applied Vertex context caching",
|
||||
prompt_name=prompt_name,
|
||||
model=llm_config.model_name,
|
||||
ttl_seconds=3600,
|
||||
)
|
||||
# Check if this is an OpenAI model
|
||||
elif (
|
||||
if (
|
||||
llm_config.model_name.startswith("gpt-")
|
||||
or llm_config.model_name.startswith("o1-")
|
||||
or llm_config.model_name.startswith("o3-")
|
||||
@@ -626,6 +588,24 @@ class LLMAPIHandlerFactory:
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
|
||||
|
||||
# Add Vertex AI cache reference only for the intended cached prompt
|
||||
vertex_cache_attached = False
|
||||
cache_resource_name = getattr(context, "vertex_cache_name", None)
|
||||
if (
|
||||
cache_resource_name
|
||||
and "vertex_ai/" in model_name
|
||||
and prompt_name == "extract-actions"
|
||||
and getattr(context, "use_prompt_caching", False)
|
||||
):
|
||||
active_parameters["cached_content"] = cache_resource_name
|
||||
vertex_cache_attached = True
|
||||
LOG.info(
|
||||
"Adding Vertex AI cache reference to request",
|
||||
prompt_name=prompt_name,
|
||||
cache_attached=True,
|
||||
)
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(
|
||||
{
|
||||
@@ -633,6 +613,7 @@ class LLMAPIHandlerFactory:
|
||||
"messages": messages,
|
||||
# we're not using active_parameters here because it may contain sensitive information
|
||||
**parameters,
|
||||
"vertex_cache_attached": vertex_cache_attached,
|
||||
}
|
||||
).encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_REQUEST,
|
||||
@@ -641,6 +622,7 @@ class LLMAPIHandlerFactory:
|
||||
thought=thought,
|
||||
ai_suggestion=ai_suggestion,
|
||||
)
|
||||
|
||||
t_llm_request = time.perf_counter()
|
||||
try:
|
||||
# TODO (kerem): add a timeout to this call
|
||||
|
||||
203
skyvern/forge/sdk/api/llm/vertex_cache_manager.py
Normal file
203
skyvern/forge/sdk/api/llm/vertex_cache_manager.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Vertex AI Context Caching Manager.
|
||||
|
||||
This module implements the CORRECT caching pattern for Vertex AI using the /cachedContents API.
|
||||
Unlike the Anthropic-style cache_control markers, Vertex AI requires:
|
||||
1. Creating a cache object via POST to /cachedContents
|
||||
2. Getting the cache resource name
|
||||
3. Referencing that cache name in subsequent requests
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import google.auth
|
||||
import requests
|
||||
import structlog
|
||||
from google.auth.transport.requests import Request
|
||||
|
||||
from skyvern.config import settings
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class VertexCacheManager:
|
||||
"""
|
||||
Manages Vertex AI context caching using the correct /cachedContents API.
|
||||
|
||||
This provides guaranteed cache hits for static content across requests,
|
||||
unlike implicit caching which requires exact prompt matches.
|
||||
"""
|
||||
|
||||
def __init__(self, project_id: str, location: str = "global"):
|
||||
self.project_id = project_id
|
||||
self.location = location
|
||||
# Use regional endpoint for non-global locations, global endpoint for global
|
||||
if location == "global":
|
||||
self.api_endpoint = "aiplatform.googleapis.com"
|
||||
else:
|
||||
self.api_endpoint = f"{location}-aiplatform.googleapis.com"
|
||||
self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
"""Get Google Cloud access token for API calls."""
|
||||
try:
|
||||
# Try to use default credentials
|
||||
credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
||||
credentials.refresh(Request())
|
||||
return credentials.token
|
||||
except Exception as e:
|
||||
LOG.error("Failed to get access token", error=str(e))
|
||||
raise
|
||||
|
||||
def create_cache(
|
||||
self,
|
||||
model_name: str,
|
||||
static_content: str,
|
||||
cache_key: str,
|
||||
ttl_seconds: int = 3600,
|
||||
system_instruction: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a cache object using Vertex AI's /cachedContents API.
|
||||
|
||||
Args:
|
||||
model_name: Full model path (e.g., "gemini-2.5-flash")
|
||||
static_content: The static content to cache
|
||||
cache_key: Unique key to identify this cache (e.g., f"task_{task_id}")
|
||||
ttl_seconds: Time to live in seconds (default: 1 hour)
|
||||
system_instruction: Optional system instruction to include
|
||||
|
||||
Returns:
|
||||
Cache data with 'name', 'expireTime', etc.
|
||||
"""
|
||||
# Check if cache already exists for this key
|
||||
if cache_key in self._cache_registry:
|
||||
cache_data = self._cache_registry[cache_key]
|
||||
# Check if still valid
|
||||
expire_time = datetime.fromisoformat(cache_data["expireTime"].replace("Z", "+00:00"))
|
||||
if expire_time > datetime.now(expire_time.tzinfo):
|
||||
LOG.info("Reusing existing cache", cache_key=cache_key, cache_name=cache_data["name"])
|
||||
return cache_data
|
||||
else:
|
||||
LOG.info("Cache expired, creating new one", cache_key=cache_key)
|
||||
# Clean up expired cache
|
||||
try:
|
||||
self.delete_cache(cache_key)
|
||||
except Exception:
|
||||
pass # Best effort cleanup
|
||||
|
||||
url = f"https://{self.api_endpoint}/v1/projects/{self.project_id}/locations/{self.location}/cachedContents"
|
||||
|
||||
# Build the model path
|
||||
full_model_path = f"projects/{self.project_id}/locations/{self.location}/publishers/google/models/{model_name}"
|
||||
|
||||
# Create payload
|
||||
payload: dict[str, Any] = {
|
||||
"model": full_model_path,
|
||||
"contents": [{"role": "user", "parts": [{"text": static_content}]}],
|
||||
"ttl": f"{ttl_seconds}s",
|
||||
}
|
||||
|
||||
# Add system instruction if provided
|
||||
if system_instruction:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
|
||||
|
||||
headers = {"Authorization": f"Bearer {self._get_access_token()}", "Content-Type": "application/json"}
|
||||
|
||||
LOG.info(
|
||||
"Creating Vertex AI cache object",
|
||||
cache_key=cache_key,
|
||||
model=model_name,
|
||||
content_size=len(static_content),
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=30)
|
||||
|
||||
if response.status_code != 200:
|
||||
LOG.error(
|
||||
"Failed to create cache",
|
||||
cache_key=cache_key,
|
||||
status_code=response.status_code,
|
||||
response=response.text,
|
||||
)
|
||||
raise Exception(f"Cache creation failed: {response.text}")
|
||||
|
||||
cache_data = response.json()
|
||||
cache_name = cache_data["name"]
|
||||
|
||||
# Store in registry
|
||||
self._cache_registry[cache_key] = cache_data
|
||||
|
||||
LOG.info(
|
||||
"Cache created successfully",
|
||||
cache_key=cache_key,
|
||||
cache_name=cache_name,
|
||||
expires_at=cache_data.get("expireTime"),
|
||||
)
|
||||
|
||||
return cache_data
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
LOG.error("Cache creation timed out", cache_key=cache_key)
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error("Cache creation failed", cache_key=cache_key, error=str(e))
|
||||
raise
|
||||
|
||||
def delete_cache(self, cache_key: str) -> bool:
|
||||
"""Delete a cache object."""
|
||||
cache_data = self._cache_registry.get(cache_key)
|
||||
if not cache_data:
|
||||
LOG.warning("Cache not found in registry", cache_key=cache_key)
|
||||
return False
|
||||
|
||||
cache_name = cache_data["name"]
|
||||
url = f"https://{self.api_endpoint}/v1/{cache_name}"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
|
||||
LOG.info("Deleting cache", cache_key=cache_key, cache_name=cache_name)
|
||||
|
||||
try:
|
||||
response = requests.delete(url, headers=headers, timeout=10)
|
||||
|
||||
if response.status_code in (200, 204):
|
||||
# Remove from registry
|
||||
del self._cache_registry[cache_key]
|
||||
LOG.info("Cache deleted successfully", cache_key=cache_key)
|
||||
return True
|
||||
else:
|
||||
LOG.warning(
|
||||
"Failed to delete cache",
|
||||
cache_key=cache_key,
|
||||
status_code=response.status_code,
|
||||
response=response.text,
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
LOG.error("Cache deletion failed", cache_key=cache_key, error=str(e))
|
||||
return False
|
||||
|
||||
|
||||
# Module-level singleton; created lazily on first use.
_global_cache_manager: VertexCacheManager | None = None


def get_cache_manager() -> VertexCacheManager:
    """Return the process-wide VertexCacheManager, creating it on first call."""
    global _global_cache_manager

    if _global_cache_manager is not None:
        return _global_cache_manager

    project = settings.VERTEX_PROJECT_ID or "skyvern-production"
    # Default to "global" to match the model configs in cloud/__init__.py;
    # can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching).
    region = settings.VERTEX_LOCATION or "global"
    _global_cache_manager = VertexCacheManager(project, region)
    LOG.info("Created global cache manager", project_id=project, location=region)
    return _global_cache_manager
|
||||
@@ -36,6 +36,7 @@ class SkyvernContext:
|
||||
enable_parse_select_in_extract: bool = False
|
||||
use_prompt_caching: bool = False
|
||||
cached_static_prompt: str | None = None
|
||||
vertex_cache_name: str | None = None # Vertex AI cache resource name for explicit caching
|
||||
|
||||
# script run context
|
||||
script_id: str | None = None
|
||||
|
||||
Reference in New Issue
Block a user