restore vertex cache credentials (#4050)
This commit is contained in:
@@ -673,12 +673,6 @@ class LLMAPIHandlerFactory:
|
||||
# Add Vertex AI cache reference only for the intended cached prompt
|
||||
vertex_cache_attached = False
|
||||
cache_resource_name = getattr(context, "vertex_cache_name", None)
|
||||
LOG.info(
|
||||
"Vertex cache attachment check",
|
||||
cache_resource_name=cache_resource_name,
|
||||
prompt_name=prompt_name,
|
||||
use_prompt_caching=getattr(context, "use_prompt_caching", None) if context else None,
|
||||
)
|
||||
if (
|
||||
cache_resource_name
|
||||
and prompt_name == EXTRACT_ACTION_PROMPT_NAME
|
||||
|
||||
@@ -8,13 +8,16 @@ Unlike the Anthropic-style cache_control markers, Vertex AI requires:
|
||||
3. Referencing that cache name in subsequent requests
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import google.auth
|
||||
import requests
|
||||
import structlog
|
||||
from google.auth.credentials import Credentials
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
|
||||
from skyvern.config import settings
|
||||
|
||||
@@ -29,7 +32,7 @@ class VertexCacheManager:
|
||||
unlike implicit caching which requires exact prompt matches.
|
||||
"""
|
||||
|
||||
def __init__(self, project_id: str, location: str = "global"):
|
||||
def __init__(self, project_id: str, location: str = "global", credentials_json: str | None = None):
|
||||
self.project_id = project_id
|
||||
self.location = location
|
||||
# Use regional endpoint for non-global locations, global endpoint for global
|
||||
@@ -38,13 +41,39 @@ class VertexCacheManager:
|
||||
else:
|
||||
self.api_endpoint = f"{location}-aiplatform.googleapis.com"
|
||||
self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data
|
||||
self._scopes = ["https://www.googleapis.com/auth/cloud-platform"]
|
||||
self._default_credentials = None
|
||||
self._service_account_credentials = None
|
||||
self._service_account_info: dict[str, Any] | None = None
|
||||
|
||||
if credentials_json:
|
||||
try:
|
||||
self._service_account_info = json.loads(credentials_json)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOG.warning("Failed to parse Vertex credentials JSON, falling back to ADC", error=str(exc))
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
"""Get Google Cloud access token for API calls."""
|
||||
try:
|
||||
# Try to use default credentials
|
||||
credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
||||
credentials.refresh(Request())
|
||||
credentials: Credentials | None = None
|
||||
if self._service_account_info:
|
||||
if not self._service_account_credentials:
|
||||
self._service_account_credentials = service_account.Credentials.from_service_account_info(
|
||||
self._service_account_info,
|
||||
scopes=self._scopes,
|
||||
)
|
||||
credentials = self._service_account_credentials
|
||||
else:
|
||||
if not self._default_credentials:
|
||||
self._default_credentials, _ = google.auth.default(scopes=self._scopes)
|
||||
credentials = self._default_credentials
|
||||
|
||||
if credentials is None:
|
||||
raise RuntimeError("Unable to initialize Google credentials for Vertex cache manager")
|
||||
|
||||
if not credentials.valid or credentials.expired:
|
||||
credentials.refresh(Request())
|
||||
|
||||
return credentials.token
|
||||
except Exception as e:
|
||||
LOG.error("Failed to get access token", error=str(e))
|
||||
@@ -197,7 +226,11 @@ def get_cache_manager() -> VertexCacheManager:
|
||||
# Default to "global" to match the model configs in cloud/__init__.py
|
||||
# Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching)
|
||||
location = settings.VERTEX_LOCATION or "global"
|
||||
_global_cache_manager = VertexCacheManager(project_id, location)
|
||||
_global_cache_manager = VertexCacheManager(
|
||||
project_id=project_id,
|
||||
location=location,
|
||||
credentials_json=settings.VERTEX_CREDENTIALS,
|
||||
)
|
||||
LOG.info("Created global cache manager", project_id=project_id, location=location)
|
||||
|
||||
return _global_cache_manager
|
||||
|
||||
Reference in New Issue
Block a user