From bc6d7affd53ceb4155da8bb20450f4089040ffce Mon Sep 17 00:00:00 2001 From: pedrohsdb Date: Wed, 19 Nov 2025 17:05:49 -0800 Subject: [PATCH] use explicit vertex credentials for cache manager (#4039) --- .../forge/sdk/api/llm/api_handler_factory.py | 6 --- .../forge/sdk/api/llm/vertex_cache_manager.py | 43 ++++++++++++++++--- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index f16dbe3a..54c0329d 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -673,12 +673,6 @@ class LLMAPIHandlerFactory: # Add Vertex AI cache reference only for the intended cached prompt vertex_cache_attached = False cache_resource_name = getattr(context, "vertex_cache_name", None) - LOG.info( - "Vertex cache attachment check", - cache_resource_name=cache_resource_name, - prompt_name=prompt_name, - use_prompt_caching=getattr(context, "use_prompt_caching", None) if context else None, - ) if ( cache_resource_name and prompt_name == EXTRACT_ACTION_PROMPT_NAME diff --git a/skyvern/forge/sdk/api/llm/vertex_cache_manager.py b/skyvern/forge/sdk/api/llm/vertex_cache_manager.py index 5f97c3b1..a8501949 100644 --- a/skyvern/forge/sdk/api/llm/vertex_cache_manager.py +++ b/skyvern/forge/sdk/api/llm/vertex_cache_manager.py @@ -8,13 +8,16 @@ Unlike the Anthropic-style cache_control markers, Vertex AI requires: 3. Referencing that cache name in subsequent requests """ +import json from datetime import datetime from typing import Any import google.auth import requests import structlog +from google.auth.credentials import Credentials from google.auth.transport.requests import Request +from google.oauth2 import service_account from skyvern.config import settings @@ -29,7 +32,7 @@ class VertexCacheManager: unlike implicit caching which requires exact prompt matches. """ - def __init__(self, project_id: str, location: str = "global"): + def __init__(self, project_id: str, location: str = "global", credentials_json: str | None = None): self.project_id = project_id self.location = location # Use regional endpoint for non-global locations, global endpoint for global @@ -38,13 +41,39 @@ class VertexCacheManager: else: self.api_endpoint = f"{location}-aiplatform.googleapis.com" self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data + self._scopes = ["https://www.googleapis.com/auth/cloud-platform"] + self._default_credentials = None + self._service_account_credentials = None + self._service_account_info: dict[str, Any] | None = None + + if credentials_json: + try: + self._service_account_info = json.loads(credentials_json) + except Exception as exc: # noqa: BLE001 + LOG.warning("Failed to parse Vertex credentials JSON, falling back to ADC", error=str(exc)) def _get_access_token(self) -> str: """Get Google Cloud access token for API calls.""" try: - # Try to use default credentials - credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"]) - credentials.refresh(Request()) + credentials: Credentials | None = None + if self._service_account_info: + if not self._service_account_credentials: + self._service_account_credentials = service_account.Credentials.from_service_account_info( + self._service_account_info, + scopes=self._scopes, + ) + credentials = self._service_account_credentials + else: + if not self._default_credentials: + self._default_credentials, _ = google.auth.default(scopes=self._scopes) + credentials = self._default_credentials + + if credentials is None: + raise RuntimeError("Unable to initialize Google credentials for Vertex cache manager") + + if not credentials.valid or credentials.expired: + credentials.refresh(Request()) + return credentials.token except Exception as e: LOG.error("Failed to get access token", error=str(e)) @@ -197,7 +226,11 @@ def get_cache_manager() -> VertexCacheManager: # Default to "global" to match the model configs in cloud/__init__.py # Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching) location = settings.VERTEX_LOCATION or "global" - _global_cache_manager = VertexCacheManager(project_id, location) + _global_cache_manager = VertexCacheManager( + project_id=project_id, + location=location, + credentials_json=settings.VERTEX_CREDENTIALS, + ) LOG.info("Created global cache manager", project_id=project_id, location=location) return _global_cache_manager