Pedro/vertex cache minimal fix (#3981)

pedrohsdb committed 2025-11-12 10:40:52 -08:00 (committed by GitHub)
parent e8e8481f78
commit d88ca1ca27
3 changed files with 114 additions and 40 deletions


@@ -132,6 +132,10 @@ from skyvern.webeye.utils.page import SkyvernFrame
LOG = structlog.get_logger()
EXTRACT_ACTION_TEMPLATE = "extract-action"
EXTRACT_ACTION_PROMPT_NAME = "extract-actions"
EXTRACT_ACTION_CACHE_KEY_PREFIX = f"{EXTRACT_ACTION_TEMPLATE}-static"
class ActionLinkedNode:
def __init__(self, action: Action) -> None:
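For orientation, a minimal sketch of how the new prefix composes into a per-model cache key (the helper name is illustrative; the format mirrors the cache_key construction later in this diff):

```python
# Sketch: cache keys are keyed by model, not by task, so every task running on the
# same Gemini model reuses one Vertex cache entry for the static extract-action prompt.
EXTRACT_ACTION_TEMPLATE = "extract-action"
EXTRACT_ACTION_CACHE_KEY_PREFIX = f"{EXTRACT_ACTION_TEMPLATE}-static"

def build_cache_key(resolved_llm_key: str) -> str:
    # Mirrors the cache_key construction in _create_vertex_cache_for_task below.
    return f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}-{resolved_llm_key}"

assert build_cache_key("VERTEX_GEMINI_2.5_FLASH") == "extract-action-static-VERTEX_GEMINI_2.5_FLASH"
```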
@@ -2272,7 +2276,9 @@ class ForgeAgent:
return scraped_page, extract_action_prompt, use_caching
async def _create_vertex_cache_for_task(self, task: Task, static_prompt: str, context: SkyvernContext) -> None:
async def _create_vertex_cache_for_task(
self, task: Task, static_prompt: str, context: SkyvernContext, llm_key_override: str | None
) -> None:
"""
Create a Vertex AI cache for the task's static prompt.
@@ -2285,7 +2291,9 @@ class ForgeAgent:
"""
# Early return if task doesn't have an llm_key
# This should not happen given the guard at the call site, but being defensive
if not task.llm_key:
resolved_llm_key = llm_key_override or task.llm_key
if not resolved_llm_key:
LOG.warning(
"Cannot create Vertex AI cache without llm_key, skipping cache creation",
task_id=task.task_id,
@@ -2293,18 +2301,23 @@ class ForgeAgent:
return
try:
LOG.info(
"Attempting Vertex AI cache creation",
task_id=task.task_id,
llm_key=resolved_llm_key,
)
cache_manager = get_cache_manager()
# Use llm_key as cache_key so all tasks with the same model share the same cache
# This maximizes cache reuse and reduces cache storage costs
cache_key = f"extract-action-static-{task.llm_key}"
cache_key = f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}-{resolved_llm_key}"
# Get the actual model name from LLM config to ensure correct format
# (e.g., "gemini-2.5-flash" with decimal, not "gemini-2-5-flash")
model_name = "gemini-2.5-flash" # Default
try:
llm_config = LLMConfigRegistry.get_config(task.llm_key)
llm_config = LLMConfigRegistry.get_config(resolved_llm_key)
extracted_name = None
# Try to extract from model_name if it contains "vertex_ai/" or starts with "gemini-"
@@ -2328,13 +2341,13 @@ class ForgeAgent:
if not extracted_name:
# Extract version from llm_key (e.g., VERTEX_GEMINI_1_5_FLASH -> "1_5" or VERTEX_GEMINI_2.5_FLASH -> "2.5")
# Pattern: GEMINI_{version}_{flavor} where version can use dots, underscores, or dashes
version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", task.llm_key, re.IGNORECASE)
version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", resolved_llm_key, re.IGNORECASE)
version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"
# Determine flavor
if "_PRO_" in task.llm_key or task.llm_key.endswith("_PRO"):
if "_PRO_" in resolved_llm_key or resolved_llm_key.endswith("_PRO"):
extracted_name = f"gemini-{version}-pro"
elif "_FLASH_LITE_" in task.llm_key or task.llm_key.endswith("_FLASH_LITE"):
elif "_FLASH_LITE_" in resolved_llm_key or resolved_llm_key.endswith("_FLASH_LITE"):
extracted_name = f"gemini-{version}-flash-lite"
else:
# Default to flash flavor
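Pulled out as a standalone helper, the fallback derivation in this hunk looks roughly like the sketch below (the helper name is illustrative; the regex and flavor rules follow the diff):

```python
import re

def model_name_from_llm_key(resolved_llm_key: str) -> str:
    """Derive a Gemini model name like 'gemini-2.5-flash' from keys such as
    VERTEX_GEMINI_2.5_FLASH or VERTEX_GEMINI_1_5_PRO. Mirrors the fallback path
    taken when the LLM config does not expose a usable model_name."""
    version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", resolved_llm_key, re.IGNORECASE)
    version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"
    if "_PRO_" in resolved_llm_key or resolved_llm_key.endswith("_PRO"):
        return f"gemini-{version}-pro"
    if "_FLASH_LITE_" in resolved_llm_key or resolved_llm_key.endswith("_FLASH_LITE"):
        return f"gemini-{version}-flash-lite"
    return f"gemini-{version}-flash"  # default flavor

assert model_name_from_llm_key("VERTEX_GEMINI_1_5_PRO") == "gemini-1.5-pro"
assert model_name_from_llm_key("VERTEX_GEMINI_2.5_FLASH_LITE") == "gemini-2.5-flash-lite"
```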
@@ -2345,6 +2358,11 @@ class ForgeAgent:
except Exception as e:
LOG.debug("Failed to extract model name from config, using default", error=str(e))
# Normalize model name to the canonical Vertex identifier (e.g., gemini-2.5-pro)
match = re.search(r"(gemini-\d+(?:\.\d+)?-(?:flash-lite|flash|pro))", model_name, re.IGNORECASE)
if match:
model_name = match.group(1).lower()
# Create cache for this task
# Use asyncio.to_thread to offload blocking HTTP request (requests.post)
# This prevents freezing the event loop during cache creation
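Two pieces of this hunk in isolation: the normalization regex that collapses whatever was extracted into a canonical Vertex identifier, and the asyncio.to_thread pattern that keeps the blocking cache-creation HTTP call off the event loop. create_cache below is a stand-in, not Skyvern's actual cache-manager API:

```python
import asyncio
import re

def normalize_model_name(model_name: str) -> str:
    # Collapse variants like "vertex_ai/gemini-2.5-pro-preview" down to "gemini-2.5-pro".
    match = re.search(r"(gemini-\d+(?:\.\d+)?-(?:flash-lite|flash|pro))", model_name, re.IGNORECASE)
    return match.group(1).lower() if match else model_name

def create_cache(cache_key: str, model_name: str, content: str) -> None:
    # Placeholder for a blocking call (e.g. requests.post to the Vertex caching endpoint).
    ...

async def create_cache_without_blocking(cache_key: str, model_name: str, content: str) -> None:
    # to_thread runs the blocking function in a worker thread, keeping the event loop responsive.
    await asyncio.to_thread(create_cache, cache_key, model_name, content)

assert normalize_model_name("vertex_ai/gemini-2.5-flash-lite-001") == "gemini-2.5-flash-lite"
```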
@@ -2395,11 +2413,12 @@ class ForgeAgent:
final_navigation_payload = self._build_navigation_payload(
task, expire_verification_code=expire_verification_code, step=step, scraped_page=scraped_page
)
navigation_payload_str = json.dumps(final_navigation_payload)
task_type = task.task_type if task.task_type else TaskType.general
template = ""
if task_type == TaskType.general:
template = "extract-action"
template = EXTRACT_ACTION_TEMPLATE
elif task_type == TaskType.validation:
template = "decisive-criterion-validate"
elif task_type == TaskType.action:
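The task-type dispatch can also be read as a lookup table; a self-contained sketch with a stand-in enum (the action-task template is left out because the hunk cuts off before that branch):

```python
from enum import Enum

class TaskType(str, Enum):  # stand-in for Skyvern's TaskType
    general = "general"
    validation = "validation"
    action = "action"

EXTRACT_ACTION_TEMPLATE = "extract-action"

TEMPLATE_BY_TASK_TYPE = {
    TaskType.general: EXTRACT_ACTION_TEMPLATE,
    TaskType.validation: "decisive-criterion-validate",
}

task_type = TaskType.general
template = TEMPLATE_BY_TASK_TYPE.get(task_type, "")
assert template == "extract-action"
```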
@@ -2438,43 +2457,72 @@ class ForgeAgent:
context = skyvern_context.ensure_context()
# Reset cached prompt by default; we will set it below if caching is enabled.
context.cached_static_prompt = None
# Check if prompt caching is enabled for extract-action
use_caching = False
if (
template == "extract-action"
and LLMAPIHandlerFactory._prompt_caching_settings
and LLMAPIHandlerFactory._prompt_caching_settings.get("extract-action", False)
):
prompt_caching_settings = LLMAPIHandlerFactory._prompt_caching_settings or {}
effective_llm_key = task.llm_key
if not effective_llm_key:
handler_for_key = LLMAPIHandlerFactory.get_override_llm_api_handler(
task.llm_key, default=app.LLM_API_HANDLER
)
effective_llm_key = getattr(handler_for_key, "llm_key", None)
cache_enabled = prompt_caching_settings.get(EXTRACT_ACTION_PROMPT_NAME) or prompt_caching_settings.get(
EXTRACT_ACTION_TEMPLATE
)
LOG.info(
"Extract-action prompt caching evaluation",
template=template,
cache_enabled=cache_enabled,
prompt_caching_settings=prompt_caching_settings,
task_llm_key=task.llm_key,
effective_llm_key=effective_llm_key,
)
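A sketch of the enablement check introduced here: prefer the task's own llm_key, otherwise fall back to the override handler's llm_key attribute, and accept the caching flag under either the prompt name or the template name. The settings dict and handler object below are stand-ins:

```python
from types import SimpleNamespace

EXTRACT_ACTION_TEMPLATE = "extract-action"
EXTRACT_ACTION_PROMPT_NAME = "extract-actions"

def resolve_effective_llm_key(task_llm_key: str | None, override_handler: object) -> str | None:
    # Prefer the task's own llm_key; otherwise fall back to whatever key the
    # override handler was built with (it may not expose one at all).
    return task_llm_key or getattr(override_handler, "llm_key", None)

def extract_action_caching_enabled(settings: dict[str, bool] | None) -> bool:
    # Accept the flag under either the prompt name or the template name, so both
    # spellings of the setting enable caching.
    settings = settings or {}
    return bool(settings.get(EXTRACT_ACTION_PROMPT_NAME) or settings.get(EXTRACT_ACTION_TEMPLATE))

handler = SimpleNamespace(llm_key="VERTEX_GEMINI_2.5_FLASH")
assert resolve_effective_llm_key(None, handler) == "VERTEX_GEMINI_2.5_FLASH"
assert extract_action_caching_enabled({"extract-actions": True}) is True
assert extract_action_caching_enabled(None) is False
```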
if template == EXTRACT_ACTION_TEMPLATE and cache_enabled:
try:
# Try to load split templates for caching
static_prompt = prompt_engine.load_prompt(f"{template}-static")
dynamic_prompt = prompt_engine.load_prompt(
f"{template}-dynamic",
navigation_goal=navigation_goal,
navigation_payload_str=json.dumps(final_navigation_payload),
starting_url=starting_url,
current_url=current_url,
data_extraction_goal=task.data_extraction_goal,
action_history=actions_and_results_str,
error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
local_datetime=datetime.now(context.tz_info).isoformat(),
verification_code_check=verification_code_check,
complete_criterion=task.complete_criterion.strip() if task.complete_criterion else None,
terminate_criterion=task.terminate_criterion.strip() if task.terminate_criterion else None,
parse_select_feature_enabled=context.enable_parse_select_in_extract,
has_magic_link_page=context.has_magic_link_page(task.task_id),
)
prompt_kwargs = {
"navigation_goal": navigation_goal,
"navigation_payload_str": navigation_payload_str,
"starting_url": starting_url,
"current_url": current_url,
"data_extraction_goal": task.data_extraction_goal,
"action_history": actions_and_results_str,
"error_code_mapping_str": (
json.dumps(task.error_code_mapping) if task.error_code_mapping else None
),
"local_datetime": datetime.now(context.tz_info).isoformat(),
"verification_code_check": verification_code_check,
"complete_criterion": task.complete_criterion.strip() if task.complete_criterion else None,
"terminate_criterion": task.terminate_criterion.strip() if task.terminate_criterion else None,
"parse_select_feature_enabled": context.enable_parse_select_in_extract,
"has_magic_link_page": context.has_magic_link_page(task.task_id),
}
static_prompt = prompt_engine.load_prompt(f"{template}-static", **prompt_kwargs)
dynamic_prompt = prompt_engine.load_prompt(f"{template}-dynamic", **prompt_kwargs)
# Store static prompt for caching and return dynamic prompt
# Store static prompt for caching and continue sending it alongside the dynamic section.
# Vertex explicit caching expects the static content to still be present in the request so the
# first call succeeds even if the cache is cold. The cached reference simply lets the service
# reuse the static portion internally.
context.cached_static_prompt = static_prompt
context.use_prompt_caching = True
use_caching = True
# Create Vertex AI cache for Gemini models
if task.llm_key and "GEMINI" in task.llm_key:
await self._create_vertex_cache_for_task(task, static_prompt, context)
if effective_llm_key and "GEMINI" in effective_llm_key:
await self._create_vertex_cache_for_task(task, static_prompt, context, effective_llm_key)
LOG.info("Using cached prompt for extract-action", task_id=task.task_id)
return dynamic_prompt, use_caching
combined_prompt = f"{static_prompt.rstrip()}\n\n{dynamic_prompt.lstrip()}"
LOG.info(
"Using cached prompt",
task_id=task.task_id,
prompt_name=EXTRACT_ACTION_PROMPT_NAME,
)
return combined_prompt, use_caching
except Exception as e:
LOG.warning("Failed to load cached prompt templates, falling back to original", error=str(e))
@@ -2486,7 +2534,7 @@ class ForgeAgent:
prompt_engine=prompt_engine,
template_name=template,
navigation_goal=navigation_goal,
navigation_payload_str=json.dumps(final_navigation_payload),
navigation_payload_str=navigation_payload_str,
starting_url=starting_url,
current_url=current_url,
data_extraction_goal=task.data_extraction_goal,