Pedro/fix explicit caching vertex api (#3933)

pedrohsdb authored 2025-11-06 14:47:58 -08:00, committed by GitHub
parent d2f4e27940
commit 44528cbd38
4 changed files with 340 additions and 40 deletions


@@ -3,6 +3,7 @@ import base64
import json
import os
import random
import re
import string
from asyncio.exceptions import CancelledError
from datetime import UTC, datetime
@@ -71,11 +72,14 @@ from skyvern.forge.sdk.api.files import (
    wait_for_download_finished,
)
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.exceptions import LLM_PROVIDER_ERROR_RETRYABLE_TASK_TYPE, LLM_PROVIDER_ERROR_TYPE
from skyvern.forge.sdk.api.llm.ui_tars_llm_caller import UITarsLLMCaller
from skyvern.forge.sdk.api.llm.vertex_cache_manager import get_cache_manager
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs
from skyvern.forge.sdk.models import Step, StepStatus
@@ -2149,6 +2153,108 @@ class ForgeAgent:
        return scraped_page, extract_action_prompt, use_caching

    async def _create_vertex_cache_for_task(self, task: Task, static_prompt: str, context: SkyvernContext) -> None:
        """
        Create a Vertex AI cache for the task's static prompt.

        Uses llm_key as cache key to enable cache sharing across tasks with the same model.

        Args:
            task: The task to create cache for
            static_prompt: The static prompt content to cache
            context: The Skyvern context to store the cache name in
        """
        # Early return if task doesn't have an llm_key
        # This should not happen given the guard at the call site, but being defensive
        if not task.llm_key:
            LOG.warning(
                "Cannot create Vertex AI cache without llm_key, skipping cache creation",
                task_id=task.task_id,
            )
            return

        try:
            cache_manager = get_cache_manager()

            # Use llm_key as cache_key so all tasks with the same model share the same cache
            # This maximizes cache reuse and reduces cache storage costs
            cache_key = f"extract-action-static-{task.llm_key}"

            # Get the actual model name from LLM config to ensure correct format
            # (e.g., "gemini-2.5-flash" with decimal, not "gemini-2-5-flash")
            model_name = "gemini-2.5-flash"  # Default
            try:
                llm_config = LLMConfigRegistry.get_config(task.llm_key)
                extracted_name = None

                # Try to extract from model_name if it contains "vertex_ai/" or starts with "gemini-"
                if hasattr(llm_config, "model_name") and isinstance(llm_config.model_name, str):
                    if "vertex_ai/" in llm_config.model_name:
                        # Direct Vertex config: "vertex_ai/gemini-2.5-flash" -> "gemini-2.5-flash"
                        extracted_name = llm_config.model_name.split("/")[-1]
                    elif llm_config.model_name.startswith("gemini-"):
                        # Already in correct format
                        extracted_name = llm_config.model_name

                # For router/fallback configs, extract from api_base or infer from key name
                if not extracted_name and hasattr(llm_config, "litellm_params") and llm_config.litellm_params:
                    params = llm_config.litellm_params
                    api_base = getattr(params, "api_base", None)
                    if api_base and isinstance(api_base, str) and "/models/" in api_base:
                        # Extract from URL: .../models/gemini-2.5-flash -> "gemini-2.5-flash"
                        extracted_name = api_base.split("/models/")[-1]

                # For router configs without api_base, infer from the llm_key itself
                if not extracted_name:
                    # Extract version from llm_key (e.g., VERTEX_GEMINI_1_5_FLASH -> "1_5" or VERTEX_GEMINI_2.5_FLASH -> "2.5")
                    # Pattern: GEMINI_{version}_{flavor} where version can use dots, underscores, or dashes
                    version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", task.llm_key, re.IGNORECASE)
                    version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"

                    # Determine flavor
                    if "_PRO_" in task.llm_key or task.llm_key.endswith("_PRO"):
                        extracted_name = f"gemini-{version}-pro"
                    elif "_FLASH_LITE_" in task.llm_key or task.llm_key.endswith("_FLASH_LITE"):
                        extracted_name = f"gemini-{version}-flash-lite"
                    else:
                        # Default to flash flavor
                        extracted_name = f"gemini-{version}-flash"

                if extracted_name:
                    model_name = extracted_name
            except Exception as e:
                LOG.debug("Failed to extract model name from config, using default", error=str(e))

            # Create cache for this task
            # Use asyncio.to_thread to offload blocking HTTP request (requests.post)
            # This prevents freezing the event loop during cache creation
            cache_data = await asyncio.to_thread(
                cache_manager.create_cache,
                model_name=model_name,
                static_content=static_prompt,
                cache_key=cache_key,
                ttl_seconds=3600,  # 1 hour
            )

            # Store cache resource name in context
            context.vertex_cache_name = cache_data["name"]

            LOG.info(
                "Created Vertex AI cache for task",
                task_id=task.task_id,
                cache_key=cache_key,
                cache_name=cache_data["name"],
                model_name=model_name,
            )
        except Exception as e:
            LOG.warning(
                "Failed to create Vertex AI cache, proceeding without caching",
                task_id=task.task_id,
                error=str(e),
                exc_info=True,
            )

    async def _build_extract_action_prompt(
        self,
        task: Task,
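Note on the hunk above: the llm_key fallback only kicks in when neither the config's model_name nor its api_base yields a usable Gemini model id. A standalone sketch of that fallback, using a few illustrative keys (the key names here are examples, not an exhaustive list of Skyvern's registered llm_keys):

import re

def infer_gemini_model_name(llm_key: str) -> str:
    # Mirrors the llm_key fallback in _create_vertex_cache_for_task above
    version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", llm_key, re.IGNORECASE)
    version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"
    if "_PRO_" in llm_key or llm_key.endswith("_PRO"):
        return f"gemini-{version}-pro"
    if "_FLASH_LITE_" in llm_key or llm_key.endswith("_FLASH_LITE"):
        return f"gemini-{version}-flash-lite"
    return f"gemini-{version}-flash"

# Example keys (illustrative only):
assert infer_gemini_model_name("VERTEX_GEMINI_2.5_FLASH") == "gemini-2.5-flash"
assert infer_gemini_model_name("VERTEX_GEMINI_1_5_PRO") == "gemini-1.5-pro"
assert infer_gemini_model_name("VERTEX_GEMINI_2.5_FLASH_LITE") == "gemini-2.5-flash-lite"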
@@ -2243,6 +2349,11 @@ class ForgeAgent:
        # Store static prompt for caching and return dynamic prompt
        context.cached_static_prompt = static_prompt
        use_caching = True

        # Create Vertex AI cache for Gemini models
        if task.llm_key and "GEMINI" in task.llm_key:
            await self._create_vertex_cache_for_task(task, static_prompt, context)

        LOG.info("Using cached prompt for extract-action", task_id=task.task_id)
        return dynamic_prompt, use_caching
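The vertex_cache_manager module itself is not part of this diff; the comments above only say that create_cache issues a blocking requests.post, which suggests it wraps Vertex AI's cachedContents REST endpoint. A minimal sketch of such a call, assuming application-default credentials and placeholder project/location values (the real create_cache also takes a cache_key for reuse, which this sketch omits):

import requests
import google.auth
import google.auth.transport.requests

def create_cached_content(project_id: str, location: str, model_name: str, static_prompt: str, ttl_seconds: int = 3600) -> dict:
    # Assumption: application-default credentials are configured for the Vertex project
    creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
    creds.refresh(google.auth.transport.requests.Request())

    url = f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/cachedContents"
    body = {
        "model": f"projects/{project_id}/locations/{location}/publishers/google/models/{model_name}",
        "contents": [{"role": "user", "parts": [{"text": static_prompt}]}],
        "ttl": f"{ttl_seconds}s",
    }
    resp = requests.post(url, headers={"Authorization": f"Bearer {creds.token}"}, json=body, timeout=30)
    resp.raise_for_status()
    # The response carries "name": "projects/.../locations/.../cachedContents/<id>",
    # which is what the agent stores in context.vertex_cache_name
    return resp.json()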
@@ -2524,6 +2635,9 @@ class ForgeAgent:
            raise TaskNotFound(task_id=task.task_id) from e
        task = refreshed_task

        # Caches expire based on TTL (1 hour) or can be cleaned up via scheduled job
        # This allows multiple tasks with the same llm_key to share the same cache

        # log the task status as an event
        analytics.capture("skyvern-oss-agent-task-status", {"status": task.status})
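Per the TTL comment above, no explicit cleanup is required for a cache to disappear after an hour, but a scheduled job could delete one early using the resource name stored in context.vertex_cache_name. A rough sketch of such a delete, assuming the same Vertex AI REST surface as the creation sketch (delete_vertex_cache and access_token are illustrative names, not Skyvern APIs):

import requests

def delete_vertex_cache(cache_name: str, location: str, access_token: str) -> None:
    # cache_name is the full resource name, e.g. "projects/<p>/locations/<l>/cachedContents/<id>"
    url = f"https://{location}-aiplatform.googleapis.com/v1beta1/{cache_name}"
    resp = requests.delete(url, headers={"Authorization": f"Bearer {access_token}"}, timeout=30)
    resp.raise_for_status()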