use explicit vertex credentials for cache manager (#4039)

This commit is contained in:
pedrohsdb
2025-11-19 17:05:49 -08:00
committed by GitHub
parent d5a7485d45
commit bc6d7affd5
2 changed files with 38 additions and 11 deletions

View File

@@ -673,12 +673,6 @@ class LLMAPIHandlerFactory:
# Add Vertex AI cache reference only for the intended cached prompt # Add Vertex AI cache reference only for the intended cached prompt
vertex_cache_attached = False vertex_cache_attached = False
cache_resource_name = getattr(context, "vertex_cache_name", None) cache_resource_name = getattr(context, "vertex_cache_name", None)
LOG.info(
"Vertex cache attachment check",
cache_resource_name=cache_resource_name,
prompt_name=prompt_name,
use_prompt_caching=getattr(context, "use_prompt_caching", None) if context else None,
)
if ( if (
cache_resource_name cache_resource_name
and prompt_name == EXTRACT_ACTION_PROMPT_NAME and prompt_name == EXTRACT_ACTION_PROMPT_NAME

View File

@@ -8,13 +8,16 @@ Unlike the Anthropic-style cache_control markers, Vertex AI requires:
3. Referencing that cache name in subsequent requests 3. Referencing that cache name in subsequent requests
""" """
import json
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
import google.auth import google.auth
import requests import requests
import structlog import structlog
from google.auth.credentials import Credentials
from google.auth.transport.requests import Request from google.auth.transport.requests import Request
from google.oauth2 import service_account
from skyvern.config import settings from skyvern.config import settings
@@ -29,7 +32,7 @@ class VertexCacheManager:
unlike implicit caching which requires exact prompt matches. unlike implicit caching which requires exact prompt matches.
""" """
def __init__(self, project_id: str, location: str = "global"): def __init__(self, project_id: str, location: str = "global", credentials_json: str | None = None):
self.project_id = project_id self.project_id = project_id
self.location = location self.location = location
# Use regional endpoint for non-global locations, global endpoint for global # Use regional endpoint for non-global locations, global endpoint for global
@@ -38,13 +41,39 @@ class VertexCacheManager:
else: else:
self.api_endpoint = f"{location}-aiplatform.googleapis.com" self.api_endpoint = f"{location}-aiplatform.googleapis.com"
self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data
self._scopes = ["https://www.googleapis.com/auth/cloud-platform"]
self._default_credentials = None
self._service_account_credentials = None
self._service_account_info: dict[str, Any] | None = None
if credentials_json:
try:
self._service_account_info = json.loads(credentials_json)
except Exception as exc: # noqa: BLE001
LOG.warning("Failed to parse Vertex credentials JSON, falling back to ADC", error=str(exc))
def _get_access_token(self) -> str: def _get_access_token(self) -> str:
"""Get Google Cloud access token for API calls.""" """Get Google Cloud access token for API calls."""
try: try:
# Try to use default credentials credentials: Credentials | None = None
credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"]) if self._service_account_info:
credentials.refresh(Request()) if not self._service_account_credentials:
self._service_account_credentials = service_account.Credentials.from_service_account_info(
self._service_account_info,
scopes=self._scopes,
)
credentials = self._service_account_credentials
else:
if not self._default_credentials:
self._default_credentials, _ = google.auth.default(scopes=self._scopes)
credentials = self._default_credentials
if credentials is None:
raise RuntimeError("Unable to initialize Google credentials for Vertex cache manager")
if not credentials.valid or credentials.expired:
credentials.refresh(Request())
return credentials.token return credentials.token
except Exception as e: except Exception as e:
LOG.error("Failed to get access token", error=str(e)) LOG.error("Failed to get access token", error=str(e))
@@ -197,7 +226,11 @@ def get_cache_manager() -> VertexCacheManager:
# Default to "global" to match the model configs in cloud/__init__.py # Default to "global" to match the model configs in cloud/__init__.py
# Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching) # Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching)
location = settings.VERTEX_LOCATION or "global" location = settings.VERTEX_LOCATION or "global"
_global_cache_manager = VertexCacheManager(project_id, location) _global_cache_manager = VertexCacheManager(
project_id=project_id,
location=location,
credentials_json=settings.VERTEX_CREDENTIALS,
)
LOG.info("Created global cache manager", project_id=project_id, location=location) LOG.info("Created global cache manager", project_id=project_id, location=location)
return _global_cache_manager return _global_cache_manager