Pedro/vertex cache minimal fix (#3981)
@@ -132,6 +132,10 @@ from skyvern.webeye.utils.page import SkyvernFrame
 
 LOG = structlog.get_logger()
 
+EXTRACT_ACTION_TEMPLATE = "extract-action"
+EXTRACT_ACTION_PROMPT_NAME = "extract-actions"
+EXTRACT_ACTION_CACHE_KEY_PREFIX = f"{EXTRACT_ACTION_TEMPLATE}-static"
+
 
 class ActionLinkedNode:
     def __init__(self, action: Action) -> None:
@@ -2272,7 +2276,9 @@ class ForgeAgent:
 
         return scraped_page, extract_action_prompt, use_caching
 
-    async def _create_vertex_cache_for_task(self, task: Task, static_prompt: str, context: SkyvernContext) -> None:
+    async def _create_vertex_cache_for_task(
+        self, task: Task, static_prompt: str, context: SkyvernContext, llm_key_override: str | None
+    ) -> None:
         """
         Create a Vertex AI cache for the task's static prompt.
 
@@ -2285,7 +2291,9 @@ class ForgeAgent:
         """
         # Early return if task doesn't have an llm_key
         # This should not happen given the guard at the call site, but being defensive
-        if not task.llm_key:
+        resolved_llm_key = llm_key_override or task.llm_key
+
+        if not resolved_llm_key:
             LOG.warning(
                 "Cannot create Vertex AI cache without llm_key, skipping cache creation",
                 task_id=task.task_id,
@@ -2293,18 +2301,23 @@ class ForgeAgent:
             return
 
         try:
+            LOG.info(
+                "Attempting Vertex AI cache creation",
+                task_id=task.task_id,
+                llm_key=resolved_llm_key,
+            )
             cache_manager = get_cache_manager()
 
             # Use llm_key as cache_key so all tasks with the same model share the same cache
             # This maximizes cache reuse and reduces cache storage costs
-            cache_key = f"extract-action-static-{task.llm_key}"
+            cache_key = f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}-{resolved_llm_key}"
 
             # Get the actual model name from LLM config to ensure correct format
             # (e.g., "gemini-2.5-flash" with decimal, not "gemini-2-5-flash")
             model_name = "gemini-2.5-flash"  # Default
 
             try:
-                llm_config = LLMConfigRegistry.get_config(task.llm_key)
+                llm_config = LLMConfigRegistry.get_config(resolved_llm_key)
                 extracted_name = None
 
                 # Try to extract from model_name if it contains "vertex_ai/" or starts with "gemini-"
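For reference, the cache-key scheme above reduces to a one-liner. An illustrative sketch (the `build_cache_key` helper is hypothetical; only the prefix constant and the key format come from the diff):

```python
EXTRACT_ACTION_TEMPLATE = "extract-action"
EXTRACT_ACTION_CACHE_KEY_PREFIX = f"{EXTRACT_ACTION_TEMPLATE}-static"


def build_cache_key(resolved_llm_key: str) -> str:
    # Every task resolving to the same llm_key shares one cache entry,
    # which maximizes reuse and keeps cache storage costs down.
    return f"{EXTRACT_ACTION_CACHE_KEY_PREFIX}-{resolved_llm_key}"


# build_cache_key("VERTEX_GEMINI_2.5_FLASH")
# -> "extract-action-static-VERTEX_GEMINI_2.5_FLASH"
```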
@@ -2328,13 +2341,13 @@ class ForgeAgent:
                 if not extracted_name:
                     # Extract version from llm_key (e.g., VERTEX_GEMINI_1_5_FLASH -> "1_5" or VERTEX_GEMINI_2.5_FLASH -> "2.5")
                     # Pattern: GEMINI_{version}_{flavor} where version can use dots, underscores, or dashes
-                    version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", task.llm_key, re.IGNORECASE)
+                    version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", resolved_llm_key, re.IGNORECASE)
                     version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"
 
                     # Determine flavor
-                    if "_PRO_" in task.llm_key or task.llm_key.endswith("_PRO"):
+                    if "_PRO_" in resolved_llm_key or resolved_llm_key.endswith("_PRO"):
                         extracted_name = f"gemini-{version}-pro"
-                    elif "_FLASH_LITE_" in task.llm_key or task.llm_key.endswith("_FLASH_LITE"):
+                    elif "_FLASH_LITE_" in resolved_llm_key or resolved_llm_key.endswith("_FLASH_LITE"):
                         extracted_name = f"gemini-{version}-flash-lite"
                     else:
                         # Default to flash flavor
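The fallback derivation above can be sanity-checked in isolation. A standalone sketch (the keys are examples, and the final `-flash` return mirrors the `# Default to flash flavor` branch that this hunk truncates):

```python
import re


def derive_gemini_model_name(resolved_llm_key: str) -> str:
    # Version: GEMINI_{version}_{flavor}, where the version may use dots, underscores, or dashes.
    version_match = re.search(r"GEMINI[_-](\d+[._-]\d+)", resolved_llm_key, re.IGNORECASE)
    version = version_match.group(1).replace("_", ".").replace("-", ".") if version_match else "2.5"
    # Flavor: pro, flash-lite, or the flash default.
    if "_PRO_" in resolved_llm_key or resolved_llm_key.endswith("_PRO"):
        return f"gemini-{version}-pro"
    if "_FLASH_LITE_" in resolved_llm_key or resolved_llm_key.endswith("_FLASH_LITE"):
        return f"gemini-{version}-flash-lite"
    return f"gemini-{version}-flash"


assert derive_gemini_model_name("VERTEX_GEMINI_2.5_FLASH") == "gemini-2.5-flash"
assert derive_gemini_model_name("VERTEX_GEMINI_1_5_PRO") == "gemini-1.5-pro"
assert derive_gemini_model_name("VERTEX_GEMINI_2.5_FLASH_LITE") == "gemini-2.5-flash-lite"
```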
@@ -2345,6 +2358,11 @@ class ForgeAgent:
             except Exception as e:
                 LOG.debug("Failed to extract model name from config, using default", error=str(e))
 
+            # Normalize model name to the canonical Vertex identifier (e.g., gemini-2.5-pro)
+            match = re.search(r"(gemini-\d+(?:\.\d+)?-(?:flash-lite|flash|pro))", model_name, re.IGNORECASE)
+            if match:
+                model_name = match.group(1).lower()
+
             # Create cache for this task
             # Use asyncio.to_thread to offload blocking HTTP request (requests.post)
             # This prevents freezing the event loop during cache creation
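The `asyncio.to_thread` comment above describes offloading the blocking `requests.post` call; the pattern is roughly the following (the `create_cache` method name and its arguments are placeholders, not the cache manager's real API):

```python
import asyncio


async def create_cache_off_loop(cache_manager, cache_key: str, model_name: str, static_prompt: str) -> None:
    # asyncio.to_thread runs the synchronous call in a worker thread, so the
    # event loop keeps serving other coroutines while the HTTP request is in
    # flight. `create_cache` stands in for whatever blocking method the cache
    # manager actually exposes.
    await asyncio.to_thread(cache_manager.create_cache, cache_key, model_name, static_prompt)
```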
@@ -2395,11 +2413,12 @@ class ForgeAgent:
         final_navigation_payload = self._build_navigation_payload(
             task, expire_verification_code=expire_verification_code, step=step, scraped_page=scraped_page
         )
+        navigation_payload_str = json.dumps(final_navigation_payload)
 
         task_type = task.task_type if task.task_type else TaskType.general
         template = ""
         if task_type == TaskType.general:
-            template = "extract-action"
+            template = EXTRACT_ACTION_TEMPLATE
         elif task_type == TaskType.validation:
             template = "decisive-criterion-validate"
         elif task_type == TaskType.action:
@@ -2438,43 +2457,72 @@ class ForgeAgent:
 
         context = skyvern_context.ensure_context()
 
+        # Reset cached prompt by default; we will set it below if caching is enabled.
+        context.cached_static_prompt = None
+
         # Check if prompt caching is enabled for extract-action
         use_caching = False
-        if (
-            template == "extract-action"
-            and LLMAPIHandlerFactory._prompt_caching_settings
-            and LLMAPIHandlerFactory._prompt_caching_settings.get("extract-action", False)
-        ):
+        prompt_caching_settings = LLMAPIHandlerFactory._prompt_caching_settings or {}
+        effective_llm_key = task.llm_key
+        if not effective_llm_key:
+            handler_for_key = LLMAPIHandlerFactory.get_override_llm_api_handler(
+                task.llm_key, default=app.LLM_API_HANDLER
+            )
+            effective_llm_key = getattr(handler_for_key, "llm_key", None)
+        cache_enabled = prompt_caching_settings.get(EXTRACT_ACTION_PROMPT_NAME) or prompt_caching_settings.get(
+            EXTRACT_ACTION_TEMPLATE
+        )
+        LOG.info(
+            "Extract-action prompt caching evaluation",
+            template=template,
+            cache_enabled=cache_enabled,
+            prompt_caching_settings=prompt_caching_settings,
+            task_llm_key=task.llm_key,
+            effective_llm_key=effective_llm_key,
+        )
+        if template == EXTRACT_ACTION_TEMPLATE and cache_enabled:
             try:
                 # Try to load split templates for caching
-                static_prompt = prompt_engine.load_prompt(f"{template}-static")
-                dynamic_prompt = prompt_engine.load_prompt(
-                    f"{template}-dynamic",
-                    navigation_goal=navigation_goal,
-                    navigation_payload_str=json.dumps(final_navigation_payload),
-                    starting_url=starting_url,
-                    current_url=current_url,
-                    data_extraction_goal=task.data_extraction_goal,
-                    action_history=actions_and_results_str,
-                    error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
-                    local_datetime=datetime.now(context.tz_info).isoformat(),
-                    verification_code_check=verification_code_check,
-                    complete_criterion=task.complete_criterion.strip() if task.complete_criterion else None,
-                    terminate_criterion=task.terminate_criterion.strip() if task.terminate_criterion else None,
-                    parse_select_feature_enabled=context.enable_parse_select_in_extract,
-                    has_magic_link_page=context.has_magic_link_page(task.task_id),
-                )
+                prompt_kwargs = {
+                    "navigation_goal": navigation_goal,
+                    "navigation_payload_str": navigation_payload_str,
+                    "starting_url": starting_url,
+                    "current_url": current_url,
+                    "data_extraction_goal": task.data_extraction_goal,
+                    "action_history": actions_and_results_str,
+                    "error_code_mapping_str": (
+                        json.dumps(task.error_code_mapping) if task.error_code_mapping else None
+                    ),
+                    "local_datetime": datetime.now(context.tz_info).isoformat(),
+                    "verification_code_check": verification_code_check,
+                    "complete_criterion": task.complete_criterion.strip() if task.complete_criterion else None,
+                    "terminate_criterion": task.terminate_criterion.strip() if task.terminate_criterion else None,
+                    "parse_select_feature_enabled": context.enable_parse_select_in_extract,
+                    "has_magic_link_page": context.has_magic_link_page(task.task_id),
+                }
+                static_prompt = prompt_engine.load_prompt(f"{template}-static", **prompt_kwargs)
+                dynamic_prompt = prompt_engine.load_prompt(f"{template}-dynamic", **prompt_kwargs)
 
-                # Store static prompt for caching and return dynamic prompt
+                # Store static prompt for caching and continue sending it alongside the dynamic section.
+                # Vertex explicit caching expects the static content to still be present in the request so the
+                # first call succeeds even if the cache is cold. The cached reference simply lets the service
+                # reuse the static portion internally.
                 context.cached_static_prompt = static_prompt
                 context.use_prompt_caching = True
                 use_caching = True
 
                 # Create Vertex AI cache for Gemini models
-                if task.llm_key and "GEMINI" in task.llm_key:
-                    await self._create_vertex_cache_for_task(task, static_prompt, context)
+                if effective_llm_key and "GEMINI" in effective_llm_key:
+                    await self._create_vertex_cache_for_task(task, static_prompt, context, effective_llm_key)
 
-                LOG.info("Using cached prompt for extract-action", task_id=task.task_id)
-                return dynamic_prompt, use_caching
+                combined_prompt = f"{static_prompt.rstrip()}\n\n{dynamic_prompt.lstrip()}"
+
+                LOG.info(
+                    "Using cached prompt",
+                    task_id=task.task_id,
+                    prompt_name=EXTRACT_ACTION_PROMPT_NAME,
+                )
+                return combined_prompt, use_caching
 
             except Exception as e:
                 LOG.warning("Failed to load cached prompt templates, falling back to original", error=str(e))
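The caching branch above boils down to: split the prompt into a static and a dynamic half, register the static half with Vertex, but still send both halves together. A minimal sketch of that contract (only the combining expression and the cold-cache rationale come from the diff; the function itself is illustrative):

```python
def combine_prompt_halves(static_prompt: str, dynamic_prompt: str) -> str:
    # The static half is still sent with every request: Vertex explicit caching
    # only lets the service reuse the static portion internally, and the first
    # (cold-cache) call must succeed even before the cache entry exists.
    return f"{static_prompt.rstrip()}\n\n{dynamic_prompt.lstrip()}"
```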
@@ -2486,7 +2534,7 @@ class ForgeAgent:
             prompt_engine=prompt_engine,
             template_name=template,
             navigation_goal=navigation_goal,
-            navigation_payload_str=json.dumps(final_navigation_payload),
+            navigation_payload_str=navigation_payload_str,
             starting_url=starting_url,
             current_url=current_url,
             data_extraction_goal=task.data_extraction_goal,