use explicit vertex credentials for cache manager (#4039)

This commit is contained in:
pedrohsdb
2025-11-19 17:05:49 -08:00
committed by GitHub
parent d5a7485d45
commit bc6d7affd5
2 changed files with 38 additions and 11 deletions

View File

@@ -673,12 +673,6 @@ class LLMAPIHandlerFactory:
# Add Vertex AI cache reference only for the intended cached prompt
vertex_cache_attached = False
cache_resource_name = getattr(context, "vertex_cache_name", None)
LOG.info(
"Vertex cache attachment check",
cache_resource_name=cache_resource_name,
prompt_name=prompt_name,
use_prompt_caching=getattr(context, "use_prompt_caching", None) if context else None,
)
if (
cache_resource_name
and prompt_name == EXTRACT_ACTION_PROMPT_NAME

View File

@@ -8,13 +8,16 @@ Unlike the Anthropic-style cache_control markers, Vertex AI requires:
3. Referencing that cache name in subsequent requests
"""
import json
from datetime import datetime
from typing import Any
import google.auth
import requests
import structlog
from google.auth.credentials import Credentials
from google.auth.transport.requests import Request
from google.oauth2 import service_account
from skyvern.config import settings
@@ -29,7 +32,7 @@ class VertexCacheManager:
unlike implicit caching which requires exact prompt matches.
"""
def __init__(self, project_id: str, location: str = "global"):
def __init__(self, project_id: str, location: str = "global", credentials_json: str | None = None):
self.project_id = project_id
self.location = location
# Use regional endpoint for non-global locations, global endpoint for global
@@ -38,13 +41,39 @@ class VertexCacheManager:
else:
self.api_endpoint = f"{location}-aiplatform.googleapis.com"
self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data
self._scopes = ["https://www.googleapis.com/auth/cloud-platform"]
self._default_credentials = None
self._service_account_credentials = None
self._service_account_info: dict[str, Any] | None = None
if credentials_json:
try:
self._service_account_info = json.loads(credentials_json)
except Exception as exc: # noqa: BLE001
LOG.warning("Failed to parse Vertex credentials JSON, falling back to ADC", error=str(exc))
def _get_access_token(self) -> str:
"""Get Google Cloud access token for API calls."""
try:
# Try to use default credentials
credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(Request())
credentials: Credentials | None = None
if self._service_account_info:
if not self._service_account_credentials:
self._service_account_credentials = service_account.Credentials.from_service_account_info(
self._service_account_info,
scopes=self._scopes,
)
credentials = self._service_account_credentials
else:
if not self._default_credentials:
self._default_credentials, _ = google.auth.default(scopes=self._scopes)
credentials = self._default_credentials
if credentials is None:
raise RuntimeError("Unable to initialize Google credentials for Vertex cache manager")
if not credentials.valid or credentials.expired:
credentials.refresh(Request())
return credentials.token
except Exception as e:
LOG.error("Failed to get access token", error=str(e))
@@ -197,7 +226,11 @@ def get_cache_manager() -> VertexCacheManager:
# Default to "global" to match the model configs in cloud/__init__.py
# Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching)
location = settings.VERTEX_LOCATION or "global"
_global_cache_manager = VertexCacheManager(project_id, location)
_global_cache_manager = VertexCacheManager(
project_id=project_id,
location=location,
credentials_json=settings.VERTEX_CREDENTIALS,
)
LOG.info("Created global cache manager", project_id=project_id, location=location)
return _global_cache_manager