Pedro/fix explicit caching vertex api (#3933)
This commit is contained in:
@@ -3,6 +3,7 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import re
|
||||||
import string
|
import string
|
||||||
from asyncio.exceptions import CancelledError
|
from asyncio.exceptions import CancelledError
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
@@ -71,11 +72,14 @@ from skyvern.forge.sdk.api.files import (
|
|||||||
wait_for_download_finished,
|
wait_for_download_finished,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager
|
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager
|
||||||
|
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
||||||
from skyvern.forge.sdk.api.llm.exceptions import LLM_PROVIDER_ERROR_RETRYABLE_TASK_TYPE, LLM_PROVIDER_ERROR_TYPE
|
from skyvern.forge.sdk.api.llm.exceptions import LLM_PROVIDER_ERROR_RETRYABLE_TASK_TYPE, LLM_PROVIDER_ERROR_TYPE
|
||||||
from skyvern.forge.sdk.api.llm.ui_tars_llm_caller import UITarsLLMCaller
|
from skyvern.forge.sdk.api.llm.ui_tars_llm_caller import UITarsLLMCaller
|
||||||
|
from skyvern.forge.sdk.api.llm.vertex_cache_manager import get_cache_manager
|
||||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||||
from skyvern.forge.sdk.core import skyvern_context
|
from skyvern.forge.sdk.core import skyvern_context
|
||||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
|
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
|
||||||
|
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||||
from skyvern.forge.sdk.db.enums import TaskType
|
from skyvern.forge.sdk.db.enums import TaskType
|
||||||
from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs
|
from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs
|
||||||
from skyvern.forge.sdk.models import Step, StepStatus
|
from skyvern.forge.sdk.models import Step, StepStatus
|
||||||
@@ -2149,6 +2153,108 @@ class ForgeAgent:
|
|||||||
|
|
||||||
return scraped_page, extract_action_prompt, use_caching
|
return scraped_page, extract_action_prompt, use_caching
|
||||||
|
|
||||||
|
async def _create_vertex_cache_for_task(self, task: Task, static_prompt: str, context: SkyvernContext) -> None:
    """
    Create a Vertex AI cache for the task's static prompt.

    Uses llm_key as cache key to enable cache sharing across tasks with the same model.

    Args:
        task: The task to create cache for
        static_prompt: The static prompt content to cache
        context: The Skyvern context to store the cache name in
    """
    # Early return if task doesn't have an llm_key.
    # This should not happen given the guard at the call site, but being defensive.
    if not task.llm_key:
        LOG.warning(
            "Cannot create Vertex AI cache without llm_key, skipping cache creation",
            task_id=task.task_id,
        )
        return

    try:
        cache_manager = get_cache_manager()

        # Use llm_key as cache_key so all tasks with the same model share the same cache.
        # This maximizes cache reuse and reduces cache storage costs.
        cache_key = f"extract-action-static-{task.llm_key}"

        # Resolve the actual Vertex model name (e.g. "gemini-2.5-flash" with a decimal,
        # not "gemini-2-5-flash"); extraction heuristics live in a dedicated helper.
        model_name = self._resolve_gemini_model_name(task.llm_key)

        # Create the cache for this task.
        # Use asyncio.to_thread to offload the blocking HTTP request (requests.post)
        # so the event loop is not frozen during cache creation.
        cache_data = await asyncio.to_thread(
            cache_manager.create_cache,
            model_name=model_name,
            static_content=static_prompt,
            cache_key=cache_key,
            ttl_seconds=3600,  # 1 hour
        )

        # Store the cache resource name in context so later LLM calls can reference it.
        context.vertex_cache_name = cache_data["name"]

        LOG.info(
            "Created Vertex AI cache for task",
            task_id=task.task_id,
            cache_key=cache_key,
            cache_name=cache_data["name"],
            model_name=model_name,
        )
    except Exception as e:
        # Cache creation is best-effort: on any failure the task proceeds uncached.
        LOG.warning(
            "Failed to create Vertex AI cache, proceeding without caching",
            task_id=task.task_id,
            error=str(e),
            exc_info=True,
        )

@staticmethod
def _resolve_gemini_model_name(llm_key: str) -> str:
    """
    Best-effort resolution of the Vertex model name (e.g. "gemini-2.5-flash") for *llm_key*.

    Tries, in order: the LLM config's model_name, the router config's api_base URL,
    and finally version/flavor inference from the key itself. Falls back to
    "gemini-2.5-flash" when nothing usable can be extracted.
    """
    model_name = "gemini-2.5-flash"  # Default
    try:
        llm_config = LLMConfigRegistry.get_config(llm_key)
        extracted_name = None

        # Try to extract from model_name if it contains "vertex_ai/" or starts with "gemini-".
        if hasattr(llm_config, "model_name") and isinstance(llm_config.model_name, str):
            if "vertex_ai/" in llm_config.model_name:
                # Direct Vertex config: "vertex_ai/gemini-2.5-flash" -> "gemini-2.5-flash"
                extracted_name = llm_config.model_name.split("/")[-1]
            elif llm_config.model_name.startswith("gemini-"):
                # Already in correct format
                extracted_name = llm_config.model_name

        # For router/fallback configs, extract from api_base.
        if not extracted_name and hasattr(llm_config, "litellm_params") and llm_config.litellm_params:
            params = llm_config.litellm_params
            api_base = getattr(params, "api_base", None)
            if api_base and isinstance(api_base, str) and "/models/" in api_base:
                # Extract from URL: .../models/gemini-2.5-flash -> "gemini-2.5-flash"
                extracted_name = api_base.split("/models/")[-1]

        # For router configs without api_base, infer from the llm_key itself.
        if not extracted_name:
            # Extract version from llm_key (e.g., VERTEX_GEMINI_1_5_FLASH -> "1_5"
            # or VERTEX_GEMINI_2.5_FLASH -> "2.5").
            # Pattern: GEMINI_{version}_{flavor} where version can use dots, underscores, or dashes.
            version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", llm_key, re.IGNORECASE)
            version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"

            # Determine flavor.
            if "_PRO_" in llm_key or llm_key.endswith("_PRO"):
                extracted_name = f"gemini-{version}-pro"
            elif "_FLASH_LITE_" in llm_key or llm_key.endswith("_FLASH_LITE"):
                extracted_name = f"gemini-{version}-flash-lite"
            else:
                # Default to flash flavor
                extracted_name = f"gemini-{version}-flash"

        if extracted_name:
            model_name = extracted_name
    except Exception as e:
        LOG.debug("Failed to extract model name from config, using default", error=str(e))
    return model_name
||||||
async def _build_extract_action_prompt(
|
async def _build_extract_action_prompt(
|
||||||
self,
|
self,
|
||||||
task: Task,
|
task: Task,
|
||||||
@@ -2243,6 +2349,11 @@ class ForgeAgent:
|
|||||||
# Store static prompt for caching and return dynamic prompt
|
# Store static prompt for caching and return dynamic prompt
|
||||||
context.cached_static_prompt = static_prompt
|
context.cached_static_prompt = static_prompt
|
||||||
use_caching = True
|
use_caching = True
|
||||||
|
|
||||||
|
# Create Vertex AI cache for Gemini models
|
||||||
|
if task.llm_key and "GEMINI" in task.llm_key:
|
||||||
|
await self._create_vertex_cache_for_task(task, static_prompt, context)
|
||||||
|
|
||||||
LOG.info("Using cached prompt for extract-action", task_id=task.task_id)
|
LOG.info("Using cached prompt for extract-action", task_id=task.task_id)
|
||||||
return dynamic_prompt, use_caching
|
return dynamic_prompt, use_caching
|
||||||
|
|
||||||
@@ -2524,6 +2635,9 @@ class ForgeAgent:
|
|||||||
raise TaskNotFound(task_id=task.task_id) from e
|
raise TaskNotFound(task_id=task.task_id) from e
|
||||||
task = refreshed_task
|
task = refreshed_task
|
||||||
|
|
||||||
|
# Caches expire based on TTL (1 hour) or can be cleaned up via scheduled job
|
||||||
|
# This allows multiple tasks with the same llm_key to share the same cache
|
||||||
|
|
||||||
# log the task status as an event
|
# log the task status as an event
|
||||||
analytics.capture("skyvern-oss-agent-task-status", {"status": task.status})
|
analytics.capture("skyvern-oss-agent-task-status", {"status": task.status})
|
||||||
|
|
||||||
|
|||||||
@@ -287,27 +287,8 @@ class LLMAPIHandlerFactory:
|
|||||||
and isinstance(llm_config, LLMConfig)
|
and isinstance(llm_config, LLMConfig)
|
||||||
and isinstance(llm_config.model_name, str)
|
and isinstance(llm_config.model_name, str)
|
||||||
):
|
):
|
||||||
# Check if this is a Vertex AI model
|
|
||||||
if "vertex_ai/" in llm_config.model_name:
|
|
||||||
caching_system_message = {
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": context_cached_static_prompt,
|
|
||||||
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
messages = [caching_system_message] + messages
|
|
||||||
LOG.info(
|
|
||||||
"Applied Vertex context caching",
|
|
||||||
prompt_name=prompt_name,
|
|
||||||
model=llm_config.model_name,
|
|
||||||
ttl_seconds=3600,
|
|
||||||
)
|
|
||||||
# Check if this is an OpenAI model
|
# Check if this is an OpenAI model
|
||||||
elif (
|
if (
|
||||||
llm_config.model_name.startswith("gpt-")
|
llm_config.model_name.startswith("gpt-")
|
||||||
or llm_config.model_name.startswith("o1-")
|
or llm_config.model_name.startswith("o1-")
|
||||||
or llm_config.model_name.startswith("o3-")
|
or llm_config.model_name.startswith("o3-")
|
||||||
@@ -582,27 +563,8 @@ class LLMAPIHandlerFactory:
|
|||||||
and isinstance(llm_config, LLMConfig)
|
and isinstance(llm_config, LLMConfig)
|
||||||
and isinstance(llm_config.model_name, str)
|
and isinstance(llm_config.model_name, str)
|
||||||
):
|
):
|
||||||
# Check if this is a Vertex AI model
|
|
||||||
if "vertex_ai/" in llm_config.model_name:
|
|
||||||
caching_system_message = {
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": context_cached_static_prompt,
|
|
||||||
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
messages = [caching_system_message] + messages
|
|
||||||
LOG.info(
|
|
||||||
"Applied Vertex context caching",
|
|
||||||
prompt_name=prompt_name,
|
|
||||||
model=llm_config.model_name,
|
|
||||||
ttl_seconds=3600,
|
|
||||||
)
|
|
||||||
# Check if this is an OpenAI model
|
# Check if this is an OpenAI model
|
||||||
elif (
|
if (
|
||||||
llm_config.model_name.startswith("gpt-")
|
llm_config.model_name.startswith("gpt-")
|
||||||
or llm_config.model_name.startswith("o1-")
|
or llm_config.model_name.startswith("o1-")
|
||||||
or llm_config.model_name.startswith("o3-")
|
or llm_config.model_name.startswith("o3-")
|
||||||
@@ -626,6 +588,24 @@ class LLMAPIHandlerFactory:
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
|
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
|
||||||
|
|
||||||
|
# Add Vertex AI cache reference only for the intended cached prompt
|
||||||
|
vertex_cache_attached = False
|
||||||
|
cache_resource_name = getattr(context, "vertex_cache_name", None)
|
||||||
|
if (
|
||||||
|
cache_resource_name
|
||||||
|
and "vertex_ai/" in model_name
|
||||||
|
and prompt_name == "extract-actions"
|
||||||
|
and getattr(context, "use_prompt_caching", False)
|
||||||
|
):
|
||||||
|
active_parameters["cached_content"] = cache_resource_name
|
||||||
|
vertex_cache_attached = True
|
||||||
|
LOG.info(
|
||||||
|
"Adding Vertex AI cache reference to request",
|
||||||
|
prompt_name=prompt_name,
|
||||||
|
cache_attached=True,
|
||||||
|
)
|
||||||
|
|
||||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
data=json.dumps(
|
data=json.dumps(
|
||||||
{
|
{
|
||||||
@@ -633,6 +613,7 @@ class LLMAPIHandlerFactory:
|
|||||||
"messages": messages,
|
"messages": messages,
|
||||||
# we're not using active_parameters here because it may contain sensitive information
|
# we're not using active_parameters here because it may contain sensitive information
|
||||||
**parameters,
|
**parameters,
|
||||||
|
"vertex_cache_attached": vertex_cache_attached,
|
||||||
}
|
}
|
||||||
).encode("utf-8"),
|
).encode("utf-8"),
|
||||||
artifact_type=ArtifactType.LLM_REQUEST,
|
artifact_type=ArtifactType.LLM_REQUEST,
|
||||||
@@ -641,6 +622,7 @@ class LLMAPIHandlerFactory:
|
|||||||
thought=thought,
|
thought=thought,
|
||||||
ai_suggestion=ai_suggestion,
|
ai_suggestion=ai_suggestion,
|
||||||
)
|
)
|
||||||
|
|
||||||
t_llm_request = time.perf_counter()
|
t_llm_request = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
# TODO (kerem): add a timeout to this call
|
# TODO (kerem): add a timeout to this call
|
||||||
|
|||||||
203
skyvern/forge/sdk/api/llm/vertex_cache_manager.py
Normal file
203
skyvern/forge/sdk/api/llm/vertex_cache_manager.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""
|
||||||
|
Vertex AI Context Caching Manager.
|
||||||
|
|
||||||
|
This module implements the CORRECT caching pattern for Vertex AI using the /cachedContents API.
|
||||||
|
Unlike the Anthropic-style cache_control markers, Vertex AI requires:
|
||||||
|
1. Creating a cache object via POST to /cachedContents
|
||||||
|
2. Getting the cache resource name
|
||||||
|
3. Referencing that cache name in subsequent requests
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import google.auth
|
||||||
|
import requests
|
||||||
|
import structlog
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
|
||||||
|
from skyvern.config import settings
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class VertexCacheManager:
|
||||||
|
"""
|
||||||
|
Manages Vertex AI context caching using the correct /cachedContents API.
|
||||||
|
|
||||||
|
This provides guaranteed cache hits for static content across requests,
|
||||||
|
unlike implicit caching which requires exact prompt matches.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, project_id: str, location: str = "global"):
|
||||||
|
self.project_id = project_id
|
||||||
|
self.location = location
|
||||||
|
# Use regional endpoint for non-global locations, global endpoint for global
|
||||||
|
if location == "global":
|
||||||
|
self.api_endpoint = "aiplatform.googleapis.com"
|
||||||
|
else:
|
||||||
|
self.api_endpoint = f"{location}-aiplatform.googleapis.com"
|
||||||
|
self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data
|
||||||
|
|
||||||
|
def _get_access_token(self) -> str:
|
||||||
|
"""Get Google Cloud access token for API calls."""
|
||||||
|
try:
|
||||||
|
# Try to use default credentials
|
||||||
|
credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
||||||
|
credentials.refresh(Request())
|
||||||
|
return credentials.token
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error("Failed to get access token", error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
def create_cache(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
static_content: str,
|
||||||
|
cache_key: str,
|
||||||
|
ttl_seconds: int = 3600,
|
||||||
|
system_instruction: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create a cache object using Vertex AI's /cachedContents API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Full model path (e.g., "gemini-2.5-flash")
|
||||||
|
static_content: The static content to cache
|
||||||
|
cache_key: Unique key to identify this cache (e.g., f"task_{task_id}")
|
||||||
|
ttl_seconds: Time to live in seconds (default: 1 hour)
|
||||||
|
system_instruction: Optional system instruction to include
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cache data with 'name', 'expireTime', etc.
|
||||||
|
"""
|
||||||
|
# Check if cache already exists for this key
|
||||||
|
if cache_key in self._cache_registry:
|
||||||
|
cache_data = self._cache_registry[cache_key]
|
||||||
|
# Check if still valid
|
||||||
|
expire_time = datetime.fromisoformat(cache_data["expireTime"].replace("Z", "+00:00"))
|
||||||
|
if expire_time > datetime.now(expire_time.tzinfo):
|
||||||
|
LOG.info("Reusing existing cache", cache_key=cache_key, cache_name=cache_data["name"])
|
||||||
|
return cache_data
|
||||||
|
else:
|
||||||
|
LOG.info("Cache expired, creating new one", cache_key=cache_key)
|
||||||
|
# Clean up expired cache
|
||||||
|
try:
|
||||||
|
self.delete_cache(cache_key)
|
||||||
|
except Exception:
|
||||||
|
pass # Best effort cleanup
|
||||||
|
|
||||||
|
url = f"https://{self.api_endpoint}/v1/projects/{self.project_id}/locations/{self.location}/cachedContents"
|
||||||
|
|
||||||
|
# Build the model path
|
||||||
|
full_model_path = f"projects/{self.project_id}/locations/{self.location}/publishers/google/models/{model_name}"
|
||||||
|
|
||||||
|
# Create payload
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"model": full_model_path,
|
||||||
|
"contents": [{"role": "user", "parts": [{"text": static_content}]}],
|
||||||
|
"ttl": f"{ttl_seconds}s",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add system instruction if provided
|
||||||
|
if system_instruction:
|
||||||
|
payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {self._get_access_token()}", "Content-Type": "application/json"}
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"Creating Vertex AI cache object",
|
||||||
|
cache_key=cache_key,
|
||||||
|
model=model_name,
|
||||||
|
content_size=len(static_content),
|
||||||
|
ttl_seconds=ttl_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(url, headers=headers, json=payload, timeout=30)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
LOG.error(
|
||||||
|
"Failed to create cache",
|
||||||
|
cache_key=cache_key,
|
||||||
|
status_code=response.status_code,
|
||||||
|
response=response.text,
|
||||||
|
)
|
||||||
|
raise Exception(f"Cache creation failed: {response.text}")
|
||||||
|
|
||||||
|
cache_data = response.json()
|
||||||
|
cache_name = cache_data["name"]
|
||||||
|
|
||||||
|
# Store in registry
|
||||||
|
self._cache_registry[cache_key] = cache_data
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"Cache created successfully",
|
||||||
|
cache_key=cache_key,
|
||||||
|
cache_name=cache_name,
|
||||||
|
expires_at=cache_data.get("expireTime"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cache_data
|
||||||
|
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
LOG.error("Cache creation timed out", cache_key=cache_key)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error("Cache creation failed", cache_key=cache_key, error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
def delete_cache(self, cache_key: str) -> bool:
|
||||||
|
"""Delete a cache object."""
|
||||||
|
cache_data = self._cache_registry.get(cache_key)
|
||||||
|
if not cache_data:
|
||||||
|
LOG.warning("Cache not found in registry", cache_key=cache_key)
|
||||||
|
return False
|
||||||
|
|
||||||
|
cache_name = cache_data["name"]
|
||||||
|
url = f"https://{self.api_endpoint}/v1/{cache_name}"
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self._get_access_token()}",
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG.info("Deleting cache", cache_key=cache_key, cache_name=cache_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.delete(url, headers=headers, timeout=10)
|
||||||
|
|
||||||
|
if response.status_code in (200, 204):
|
||||||
|
# Remove from registry
|
||||||
|
del self._cache_registry[cache_key]
|
||||||
|
LOG.info("Cache deleted successfully", cache_key=cache_key)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
"Failed to delete cache",
|
||||||
|
cache_key=cache_key,
|
||||||
|
status_code=response.status_code,
|
||||||
|
response=response.text,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error("Cache deletion failed", cache_key=cache_key, error=str(e))
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Global cache manager instance
_global_cache_manager: VertexCacheManager | None = None


def get_cache_manager() -> VertexCacheManager:
    """Get or create the global cache manager instance."""
    global _global_cache_manager

    # Already initialized: return the shared instance.
    if _global_cache_manager is not None:
        return _global_cache_manager

    project_id = settings.VERTEX_PROJECT_ID or "skyvern-production"
    # Default to "global" to match the model configs in cloud/__init__.py
    # Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching)
    location = settings.VERTEX_LOCATION or "global"
    _global_cache_manager = VertexCacheManager(project_id, location)
    LOG.info("Created global cache manager", project_id=project_id, location=location)

    return _global_cache_manager
|
||||||
@@ -36,6 +36,7 @@ class SkyvernContext:
|
|||||||
enable_parse_select_in_extract: bool = False
|
enable_parse_select_in_extract: bool = False
|
||||||
use_prompt_caching: bool = False
|
use_prompt_caching: bool = False
|
||||||
cached_static_prompt: str | None = None
|
cached_static_prompt: str | None = None
|
||||||
|
vertex_cache_name: str | None = None # Vertex AI cache resource name for explicit caching
|
||||||
|
|
||||||
# script run context
|
# script run context
|
||||||
script_id: str | None = None
|
script_id: str | None = None
|
||||||
|
|||||||
Reference in New Issue
Block a user