Pedro/fix explicit caching vertex api (#3933)
This commit is contained in:
@@ -3,6 +3,7 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
from asyncio.exceptions import CancelledError
|
||||
from datetime import UTC, datetime
|
||||
@@ -71,11 +72,14 @@ from skyvern.forge.sdk.api.files import (
|
||||
wait_for_download_finished,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory, LLMCaller, LLMCallerManager
|
||||
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
||||
from skyvern.forge.sdk.api.llm.exceptions import LLM_PROVIDER_ERROR_RETRYABLE_TASK_TYPE, LLM_PROVIDER_ERROR_TYPE
|
||||
from skyvern.forge.sdk.api.llm.ui_tars_llm_caller import UITarsLLMCaller
|
||||
from skyvern.forge.sdk.api.llm.vertex_cache_manager import get_cache_manager
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_signature
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.db.enums import TaskType
|
||||
from skyvern.forge.sdk.log_artifacts import save_step_logs, save_task_logs
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
@@ -2149,6 +2153,108 @@ class ForgeAgent:
|
||||
|
||||
return scraped_page, extract_action_prompt, use_caching
|
||||
|
||||
async def _create_vertex_cache_for_task(self, task: Task, static_prompt: str, context: SkyvernContext) -> None:
    """
    Create a Vertex AI cache for the task's static prompt (best effort).

    The cache key is derived from ``task.llm_key`` so that every task using the
    same LLM key shares one cache entry. On success the cache resource name is
    stored on ``context.vertex_cache_name``. Any failure is logged as a warning
    and swallowed so the caller can proceed without caching.

    Args:
        task: The task to create the cache for; ``task.llm_key`` and ``task.task_id`` are read.
        static_prompt: The static prompt content to cache.
        context: The Skyvern context that receives the cache resource name.
    """
    # Defensive guard: the call site is expected to check llm_key already.
    if not task.llm_key:
        LOG.warning(
            "Cannot create Vertex AI cache without llm_key, skipping cache creation",
            task_id=task.task_id,
        )
        return

    try:
        cache_manager = get_cache_manager()

        # Keyed by llm_key (not task id) so tasks on the same model reuse one cache.
        cache_key = f"extract-action-static-{task.llm_key}"

        # The cache must be created with the real model id (e.g. "gemini-2.5-flash").
        model_name = self._resolve_vertex_model_name(task.llm_key)

        # create_cache is called in a worker thread so a synchronous request inside
        # it does not stall the event loop.
        cache_data = await asyncio.to_thread(
            cache_manager.create_cache,
            model_name=model_name,
            static_content=static_prompt,
            cache_key=cache_key,
            ttl_seconds=3600,  # 1 hour
        )

        cache_name = cache_data["name"]
        context.vertex_cache_name = cache_name

        LOG.info(
            "Created Vertex AI cache for task",
            task_id=task.task_id,
            cache_key=cache_key,
            cache_name=cache_name,
            model_name=model_name,
        )
    except Exception as e:
        # Caching is an optimisation only; never fail the task because of it.
        LOG.warning(
            "Failed to create Vertex AI cache, proceeding without caching",
            task_id=task.task_id,
            error=str(e),
            exc_info=True,
        )

def _resolve_vertex_model_name(self, llm_key: str) -> str:
    """
    Resolve the Gemini model id (e.g. ``"gemini-2.5-flash"``) for an LLM key.

    Resolution order:
        1. ``model_name`` of the registered LLM config, when it contains
           ``"vertex_ai/"`` or already starts with ``"gemini-"``.
        2. The segment following ``/models/`` in ``litellm_params.api_base``.
        3. Inference from the spelling of ``llm_key`` itself.

    Returns ``"gemini-2.5-flash"`` if the lookup raises for any reason.
    """
    default_model_name = "gemini-2.5-flash"

    try:
        llm_config = LLMConfigRegistry.get_config(llm_key)
        extracted_name = None

        configured_name = getattr(llm_config, "model_name", None)
        if isinstance(configured_name, str):
            if "vertex_ai/" in configured_name:
                # "vertex_ai/gemini-2.5-flash" -> "gemini-2.5-flash"
                extracted_name = configured_name.split("/")[-1]
            elif configured_name.startswith("gemini-"):
                extracted_name = configured_name

        if not extracted_name:
            params = getattr(llm_config, "litellm_params", None)
            if params:
                # NOTE(review): litellm_params may be a plain mapping rather than an
                # object with attributes; support both so api_base is actually found.
                if isinstance(params, dict):
                    api_base = params.get("api_base")
                else:
                    api_base = getattr(params, "api_base", None)

                if isinstance(api_base, str) and "/models/" in api_base:
                    # ".../models/gemini-2.5-flash:generateContent?x=1" -> "gemini-2.5-flash"
                    tail = api_base.split("/models/")[-1]
                    extracted_name = re.split(r"[:/?#]", tail, maxsplit=1)[0]

        if not extracted_name:
            extracted_name = self._infer_gemini_model_name_from_llm_key(llm_key)

        return extracted_name or default_model_name
    except Exception as e:
        LOG.debug("Failed to extract model name from config, using default", error=str(e))
        return default_model_name

@staticmethod
def _infer_gemini_model_name_from_llm_key(llm_key: str) -> str:
    """
    Infer a Gemini model id purely from the spelling of an LLM key.

    Examples: ``VERTEX_GEMINI_1_5_FLASH`` -> ``gemini-1.5-flash``,
    ``VERTEX_GEMINI_2.5_PRO`` -> ``gemini-2.5-pro``. The version defaults to
    ``2.5`` and the flavor to ``flash`` when they cannot be recognised.
    """
    # Version: GEMINI_{major}{sep}{minor}, where sep may be ".", "_" or "-".
    version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", llm_key, re.IGNORECASE)
    version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"

    # Normalise case and separators so flavor detection accepts the same
    # spellings as the version pattern above.
    normalized_key = llm_key.upper().replace("-", "_")

    if "_PRO_" in normalized_key or normalized_key.endswith("_PRO"):
        return f"gemini-{version}-pro"
    if "_FLASH_LITE_" in normalized_key or normalized_key.endswith("_FLASH_LITE"):
        return f"gemini-{version}-flash-lite"
    # Default to the flash flavor.
    return f"gemini-{version}-flash"
|
||||
|
||||
async def _build_extract_action_prompt(
|
||||
self,
|
||||
task: Task,
|
||||
@@ -2243,6 +2349,11 @@ class ForgeAgent:
|
||||
# Store static prompt for caching and return dynamic prompt
|
||||
context.cached_static_prompt = static_prompt
|
||||
use_caching = True
|
||||
|
||||
# Create Vertex AI cache for Gemini models
|
||||
if task.llm_key and "GEMINI" in task.llm_key:
|
||||
await self._create_vertex_cache_for_task(task, static_prompt, context)
|
||||
|
||||
LOG.info("Using cached prompt for extract-action", task_id=task.task_id)
|
||||
return dynamic_prompt, use_caching
|
||||
|
||||
@@ -2524,6 +2635,9 @@ class ForgeAgent:
|
||||
raise TaskNotFound(task_id=task.task_id) from e
|
||||
task = refreshed_task
|
||||
|
||||
# Caches expire based on TTL (1 hour) or can be cleaned up via scheduled job
|
||||
# This allows multiple tasks with the same llm_key to share the same cache
|
||||
|
||||
# log the task status as an event
|
||||
analytics.capture("skyvern-oss-agent-task-status", {"status": task.status})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user