use explicit vertex credentials for cache manager (#4039)
This commit is contained in:
@@ -673,12 +673,6 @@ class LLMAPIHandlerFactory:
|
|||||||
# Add Vertex AI cache reference only for the intended cached prompt
|
# Add Vertex AI cache reference only for the intended cached prompt
|
||||||
vertex_cache_attached = False
|
vertex_cache_attached = False
|
||||||
cache_resource_name = getattr(context, "vertex_cache_name", None)
|
cache_resource_name = getattr(context, "vertex_cache_name", None)
|
||||||
LOG.info(
|
|
||||||
"Vertex cache attachment check",
|
|
||||||
cache_resource_name=cache_resource_name,
|
|
||||||
prompt_name=prompt_name,
|
|
||||||
use_prompt_caching=getattr(context, "use_prompt_caching", None) if context else None,
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
cache_resource_name
|
cache_resource_name
|
||||||
and prompt_name == EXTRACT_ACTION_PROMPT_NAME
|
and prompt_name == EXTRACT_ACTION_PROMPT_NAME
|
||||||
|
|||||||
@@ -8,13 +8,16 @@ Unlike the Anthropic-style cache_control markers, Vertex AI requires:
|
|||||||
3. Referencing that cache name in subsequent requests
|
3. Referencing that cache name in subsequent requests
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import google.auth
|
import google.auth
|
||||||
import requests
|
import requests
|
||||||
import structlog
|
import structlog
|
||||||
|
from google.auth.credentials import Credentials
|
||||||
from google.auth.transport.requests import Request
|
from google.auth.transport.requests import Request
|
||||||
|
from google.oauth2 import service_account
|
||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
|
|
||||||
@@ -29,7 +32,7 @@ class VertexCacheManager:
|
|||||||
unlike implicit caching which requires exact prompt matches.
|
unlike implicit caching which requires exact prompt matches.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, project_id: str, location: str = "global"):
|
def __init__(self, project_id: str, location: str = "global", credentials_json: str | None = None):
|
||||||
self.project_id = project_id
|
self.project_id = project_id
|
||||||
self.location = location
|
self.location = location
|
||||||
# Use regional endpoint for non-global locations, global endpoint for global
|
# Use regional endpoint for non-global locations, global endpoint for global
|
||||||
@@ -38,13 +41,39 @@ class VertexCacheManager:
|
|||||||
else:
|
else:
|
||||||
self.api_endpoint = f"{location}-aiplatform.googleapis.com"
|
self.api_endpoint = f"{location}-aiplatform.googleapis.com"
|
||||||
self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data
|
self._cache_registry: dict[str, dict[str, Any]] = {} # Maps cache_key -> cache_data
|
||||||
|
self._scopes = ["https://www.googleapis.com/auth/cloud-platform"]
|
||||||
|
self._default_credentials = None
|
||||||
|
self._service_account_credentials = None
|
||||||
|
self._service_account_info: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
if credentials_json:
|
||||||
|
try:
|
||||||
|
self._service_account_info = json.loads(credentials_json)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOG.warning("Failed to parse Vertex credentials JSON, falling back to ADC", error=str(exc))
|
||||||
|
|
||||||
def _get_access_token(self) -> str:
|
def _get_access_token(self) -> str:
|
||||||
"""Get Google Cloud access token for API calls."""
|
"""Get Google Cloud access token for API calls."""
|
||||||
try:
|
try:
|
||||||
# Try to use default credentials
|
credentials: Credentials | None = None
|
||||||
credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
if self._service_account_info:
|
||||||
credentials.refresh(Request())
|
if not self._service_account_credentials:
|
||||||
|
self._service_account_credentials = service_account.Credentials.from_service_account_info(
|
||||||
|
self._service_account_info,
|
||||||
|
scopes=self._scopes,
|
||||||
|
)
|
||||||
|
credentials = self._service_account_credentials
|
||||||
|
else:
|
||||||
|
if not self._default_credentials:
|
||||||
|
self._default_credentials, _ = google.auth.default(scopes=self._scopes)
|
||||||
|
credentials = self._default_credentials
|
||||||
|
|
||||||
|
if credentials is None:
|
||||||
|
raise RuntimeError("Unable to initialize Google credentials for Vertex cache manager")
|
||||||
|
|
||||||
|
if not credentials.valid or credentials.expired:
|
||||||
|
credentials.refresh(Request())
|
||||||
|
|
||||||
return credentials.token
|
return credentials.token
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.error("Failed to get access token", error=str(e))
|
LOG.error("Failed to get access token", error=str(e))
|
||||||
@@ -197,7 +226,11 @@ def get_cache_manager() -> VertexCacheManager:
|
|||||||
# Default to "global" to match the model configs in cloud/__init__.py
|
# Default to "global" to match the model configs in cloud/__init__.py
|
||||||
# Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching)
|
# Can be overridden with VERTEX_LOCATION (e.g., "us-central1" for better caching)
|
||||||
location = settings.VERTEX_LOCATION or "global"
|
location = settings.VERTEX_LOCATION or "global"
|
||||||
_global_cache_manager = VertexCacheManager(project_id, location)
|
_global_cache_manager = VertexCacheManager(
|
||||||
|
project_id=project_id,
|
||||||
|
location=location,
|
||||||
|
credentials_json=settings.VERTEX_CREDENTIALS,
|
||||||
|
)
|
||||||
LOG.info("Created global cache manager", project_id=project_id, location=location)
|
LOG.info("Created global cache manager", project_id=project_id, location=location)
|
||||||
|
|
||||||
return _global_cache_manager
|
return _global_cache_manager
|
||||||
|
|||||||
Reference in New Issue
Block a user