Pedro/fix vertex cache leak (#4135)

This commit is contained in:
pedrohsdb
2025-11-29 05:39:05 -08:00
committed by GitHub
parent 2eeca1c699
commit 3f11d44762
4 changed files with 282 additions and 88 deletions

@@ -1,5 +1,6 @@
import asyncio
import base64
import hashlib
import json
import os
import random
@@ -2467,8 +2468,35 @@ class ForgeAgent:
return scraped_page, extract_action_prompt, use_caching
@staticmethod
def _build_extract_action_cache_variant(
verification_code_check: bool,
has_magic_link_page: bool,
complete_criterion: str | None,
) -> str:
"""
Build a short-but-unique cache variant identifier so extract-action prompts that
differ meaningfully (OTP, magic link flows, complete criteria) do not reuse the
same Vertex cache object.
"""
variant_parts: list[str] = []
if verification_code_check:
variant_parts.append("vc")
if has_magic_link_page:
variant_parts.append("ml")
if complete_criterion:
normalized = " ".join(complete_criterion.split())
digest = hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:6]
variant_parts.append(f"cc{digest}")
return "-".join(variant_parts) if variant_parts else "std"
async def _create_vertex_cache_for_task(
self, task: Task, static_prompt: str, context: SkyvernContext, llm_key_override: str | None
self,
task: Task,
static_prompt: str,
context: SkyvernContext,
llm_key_override: str | None,
prompt_variant: str | None = None,
) -> None:
"""
Create a Vertex AI cache for the task's static prompt.
@@ -2479,9 +2507,9 @@ class ForgeAgent:
task: The task to create the cache for
static_prompt: The static prompt content to cache
context: The Skyvern context to store the cache name in
llm_key_override: Optional override when we explicitly pick an LLM key
prompt_variant: Cache variant identifier (std/vc/ml/etc.)
"""
# Early return if task doesn't have an llm_key
# This should not happen given the guard at the call site, but being defensive
resolved_llm_key = llm_key_override or task.llm_key
if not resolved_llm_key:
@@ -2491,17 +2519,20 @@ class ForgeAgent:
)
return
cache_variant = prompt_variant or "std"
try:
LOG.info(
"Attempting Vertex AI cache creation",
task_id=task.task_id,
llm_key=resolved_llm_key,
cache_variant=cache_variant,
)
cache_manager = get_cache_manager()
# Use llm_key as cache_key so all tasks with the same model share the same cache
# This maximizes cache reuse and reduces cache storage costs
cache_key = f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}-{resolved_llm_key}"
variant_suffix = f"-{cache_variant}" if cache_variant else ""
cache_key = f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}{variant_suffix}-{resolved_llm_key}"
# Get the actual model name from LLM config to ensure correct format
# (e.g., "gemini-2.5-flash" with decimal, not "gemini-2-5-flash")
@@ -2565,8 +2596,10 @@ class ForgeAgent:
ttl_seconds=3600, # 1 hour
)
# Store cache resource name in context
# Store cache metadata in context
context.vertex_cache_name = cache_data["name"]
context.vertex_cache_key = cache_key
context.vertex_cache_variant = cache_variant
LOG.info(
"Created Vertex AI cache for task",
@@ -2574,6 +2607,7 @@ class ForgeAgent:
cache_key=cache_key,
cache_name=cache_data["name"],
model_name=model_name,
cache_variant=cache_variant,
)
except Exception as e:
LOG.warning(
@@ -2653,7 +2687,7 @@ class ForgeAgent:
# Check if prompt caching is enabled for extract-action
use_caching = False
prompt_caching_settings = LLMAPIHandlerFactory._prompt_caching_settings or {}
prompt_caching_settings = await self._get_prompt_caching_settings(context)
effective_llm_key = task.llm_key
if not effective_llm_key:
handler_for_key = LLMAPIHandlerFactory.get_override_llm_api_handler(
@@ -2701,6 +2735,11 @@ class ForgeAgent:
"parse_select_feature_enabled": context.enable_parse_select_in_extract,
"has_magic_link_page": context.has_magic_link_page(task.task_id),
}
cache_variant = self._build_extract_action_cache_variant(
verification_code_check=verification_code_check,
has_magic_link_page=context.has_magic_link_page(task.task_id),
complete_criterion=task.complete_criterion.strip() if task.complete_criterion else None,
)
static_prompt = prompt_engine.load_prompt(f"{template}-static", **prompt_kwargs)
dynamic_prompt = prompt_engine.load_prompt(
f"{template}-dynamic",
@@ -2718,7 +2757,13 @@ class ForgeAgent:
# Create Vertex AI cache for Gemini models
if effective_llm_key and "GEMINI" in effective_llm_key:
await self._create_vertex_cache_for_task(task, static_prompt, context, effective_llm_key)
await self._create_vertex_cache_for_task(
task,
static_prompt,
context,
effective_llm_key,
prompt_variant=cache_variant,
)
combined_prompt = f"{static_prompt.rstrip()}\n\n{dynamic_prompt.lstrip()}"
@@ -2726,6 +2771,7 @@ class ForgeAgent:
"Using cached prompt",
task_id=task.task_id,
prompt_name=EXTRACT_ACTION_PROMPT_NAME,
cache_variant=cache_variant,
)
return combined_prompt, use_caching
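Conceptually (an illustrative sketch with made-up prompt fragments, not part of the commit), the caching path keeps the large, rarely-changing instructions in the static prompt that backs the Vertex cache and regenerates only the dynamic half per step, recombining the two exactly as above:

```python
# Made-up prompt fragments for illustration only.
static_prompt = "You are a browser agent. Choose the next action...\n"       # cached on Vertex AI
dynamic_prompt = "\nCurrent page elements:\n<button id='3'>Submit</button>"  # rebuilt every step

combined_prompt = f"{static_prompt.rstrip()}\n\n{dynamic_prompt.lstrip()}"
print(combined_prompt)  # static half first, dynamic half appended after a blank line
```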
@@ -2755,6 +2801,55 @@ class ForgeAgent:
return full_prompt, use_caching
async def _get_prompt_caching_settings(self, context: SkyvernContext) -> dict[str, bool]:
"""
Resolve prompt caching settings for the current run.
We prefer explicit overrides via LLMAPIHandlerFactory.set_prompt_caching_settings(), which
are mostly used by scripts/tests. When no override exists, evaluate the PostHog experiment
once per context and cache the result on the context to avoid repeated lookups.
"""
if LLMAPIHandlerFactory._prompt_caching_settings is not None:
return LLMAPIHandlerFactory._prompt_caching_settings
if context.prompt_caching_settings is not None:
return context.prompt_caching_settings
distinct_id = context.run_id or context.workflow_run_id or context.task_id
organization_id = context.organization_id
context.prompt_caching_settings = {}
if not distinct_id or not organization_id:
return context.prompt_caching_settings
try:
enabled = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
"PROMPT_CACHING_OPTIMIZATION",
distinct_id,
properties={"organization_id": organization_id},
)
except Exception as exc:
LOG.warning(
"Failed to evaluate prompt caching experiment; defaulting to disabled",
distinct_id=distinct_id,
organization_id=organization_id,
error=str(exc),
)
return context.prompt_caching_settings
if enabled:
context.prompt_caching_settings = {
EXTRACT_ACTION_PROMPT_NAME: True,
EXTRACT_ACTION_TEMPLATE: True,
}
LOG.info(
"Prompt caching optimization enabled",
distinct_id=distinct_id,
organization_id=organization_id,
)
return context.prompt_caching_settings
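As a minimal, self-contained sketch of the resolution order above (explicit override first, then the per-context memo, then the experiment), using stand-in names for the factory override and the context:

```python
# Stand-in names only; not the actual Skyvern implementation.
GLOBAL_OVERRIDE: dict[str, bool] | None = None  # plays the role of LLMAPIHandlerFactory._prompt_caching_settings

class Ctx:
    prompt_caching_settings: dict[str, bool] | None = None

def resolve(ctx: Ctx, experiment_enabled: bool) -> dict[str, bool]:
    if GLOBAL_OVERRIDE is not None:                # 1. explicit override (scripts/tests) wins
        return GLOBAL_OVERRIDE
    if ctx.prompt_caching_settings is not None:    # 2. reuse the value memoized on the context
        return ctx.prompt_caching_settings
    ctx.prompt_caching_settings = {}               # 3. default to disabled
    if experiment_enabled:                         #    enable only when the experiment passes
        ctx.prompt_caching_settings = {"extract-action": True}
    return ctx.prompt_caching_settings

ctx = Ctx()
print(resolve(ctx, experiment_enabled=True))   # {'extract-action': True}
print(resolve(ctx, experiment_enabled=False))  # memoized: still {'extract-action': True}
```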
def _should_process_totp(self, scraped_page: ScrapedPage | None) -> bool:
"""Detect TOTP pages by checking for multiple input fields or verification keywords."""
if not scraped_page: