fix: prevent Vertex cache contamination across different prompt templates (#4183)

This commit is contained in:
pedrohsdb
2025-12-03 11:13:27 -08:00
committed by GitHub
parent 0ed512e4b8
commit ce01f2cb35
2 changed files with 35 additions and 18 deletions

View File

@@ -143,6 +143,7 @@ class SpeculativePlan:
use_caching: bool
llm_json_response: dict[str, Any] | None
llm_metadata: SpeculativeLLMMetadata | None = None
prompt_name: str = "extract-actions"
class ActionLinkedNode:
@@ -945,11 +946,13 @@ class ForgeAgent:
json_response = speculative_plan.llm_json_response
reuse_speculative_llm_response = json_response is not None
speculative_llm_metadata = speculative_plan.llm_metadata
prompt_name = speculative_plan.prompt_name
else:
(
scraped_page,
extract_action_prompt,
use_caching,
prompt_name,
) = await self.build_and_record_step_prompt(
task,
step,
@@ -1014,7 +1017,7 @@ class ForgeAgent:
if not reuse_speculative_llm_response:
json_response = await llm_api_handler(
prompt=extract_action_prompt,
prompt_name="extract-actions",
prompt_name=prompt_name,
step=step,
screenshots=scraped_page.screenshots,
)
@@ -1783,7 +1786,7 @@ class ForgeAgent:
try:
next_step.is_speculative = True
scraped_page, extract_action_prompt, use_caching = await self.build_and_record_step_prompt(
scraped_page, extract_action_prompt, use_caching, prompt_name = await self.build_and_record_step_prompt(
task,
next_step,
browser_state,
@@ -1798,7 +1801,7 @@ class ForgeAgent:
llm_json_response = await llm_api_handler(
prompt=extract_action_prompt,
prompt_name="extract-actions",
prompt_name=prompt_name,
step=next_step,
screenshots=scraped_page.screenshots,
)
@@ -1821,6 +1824,7 @@ class ForgeAgent:
use_caching=use_caching,
llm_json_response=llm_json_response,
llm_metadata=metadata_copy,
prompt_name=prompt_name,
)
except Exception:
LOG.warning(
@@ -2288,7 +2292,7 @@ class ForgeAgent:
engine: RunEngine,
*,
persist_artifacts: bool = True,
) -> tuple[ScrapedPage, str, bool]:
) -> tuple[ScrapedPage, str, bool, str]:
# Check if we have pre-scraped data from parallel verification optimization
context = skyvern_context.current()
scraped_page: ScrapedPage | None = None
@@ -2429,8 +2433,9 @@ class ForgeAgent:
workflow_run_id=task.workflow_run_id,
)
extract_action_prompt = ""
prompt_name = EXTRACT_ACTION_PROMPT_NAME # Default; overwritten below for non-CUA engines
if engine not in CUA_ENGINES:
extract_action_prompt, use_caching = await self._build_extract_action_prompt(
extract_action_prompt, use_caching, prompt_name = await self._build_extract_action_prompt(
task,
step,
browser_state,
@@ -2466,7 +2471,7 @@ class ForgeAgent:
data=element_tree_in_prompt.encode(),
)
return scraped_page, extract_action_prompt, use_caching
return scraped_page, extract_action_prompt, use_caching, prompt_name
@staticmethod
def _build_extract_action_cache_variant(
@@ -2625,7 +2630,7 @@ class ForgeAgent:
scraped_page: ScrapedPage,
verification_code_check: bool = False,
expire_verification_code: bool = False,
) -> tuple[str, bool]:
) -> tuple[str, bool, str]:
actions_and_results_str = await self._get_action_results(task)
# Generate the extract action prompt
@@ -2682,8 +2687,10 @@ class ForgeAgent:
context = skyvern_context.ensure_context()
# Reset cached prompt by default; we will set it below if caching is enabled.
# Reset cached prompt and cache reference by default; we will set them below if caching is enabled.
# This prevents extract-action cache from being attached to other prompts like decisive-criterion-validate.
context.cached_static_prompt = None
context.vertex_cache_name = None
# Check if prompt caching is enabled for extract-action
use_caching = False
@@ -2773,7 +2780,9 @@ class ForgeAgent:
prompt_name=EXTRACT_ACTION_PROMPT_NAME,
cache_variant=cache_variant,
)
return combined_prompt, use_caching
# Map template to prompt_name for logging/caching guards
prompt_name = EXTRACT_ACTION_PROMPT_NAME if template == EXTRACT_ACTION_TEMPLATE else template
return combined_prompt, use_caching, prompt_name
except Exception as e:
LOG.warning("Failed to load cached prompt templates, falling back to original", error=str(e))
@@ -2799,7 +2808,9 @@ class ForgeAgent:
has_magic_link_page=context.has_magic_link_page(task.task_id),
)
return full_prompt, use_caching
# Map template to prompt_name for logging/caching guards
prompt_name = EXTRACT_ACTION_PROMPT_NAME if template == EXTRACT_ACTION_TEMPLATE else template
return full_prompt, use_caching, prompt_name
async def _get_prompt_caching_settings(self, context: SkyvernContext) -> dict[str, bool]:
"""
@@ -4347,7 +4358,7 @@ class ForgeAgent:
current_context = skyvern_context.ensure_context()
current_context.totp_codes[task.task_id] = otp_value.value
extract_action_prompt, use_caching = await self._build_extract_action_prompt(
extract_action_prompt, use_caching, prompt_name = await self._build_extract_action_prompt(
task,
step,
browser_state,
@@ -4370,7 +4381,7 @@ class ForgeAgent:
prompt=extract_action_prompt,
step=step,
screenshots=scraped_page.screenshots,
prompt_name="extract-actions",
prompt_name=prompt_name,
)
return json_response

View File

@@ -406,10 +406,13 @@ class LLMAPIHandlerFactory:
return llm_request_json
# Inject context caching system message when available
# IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts
# (e.g., check-user-goal) with the extract-action schema
try:
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
if (
context_cached_static_prompt
context
and context.cached_static_prompt
and prompt_name == EXTRACT_ACTION_PROMPT_NAME # Only inject for extract-actions
and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str)
):
@@ -426,7 +429,7 @@ class LLMAPIHandlerFactory:
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
"text": context.cached_static_prompt,
}
],
}
@@ -789,10 +792,13 @@ class LLMAPIHandlerFactory:
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
# Inject context caching system message when available
# IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts
# (e.g., check-user-goal) with the extract-action schema
try:
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
if (
context_cached_static_prompt
context
and context.cached_static_prompt
and prompt_name == EXTRACT_ACTION_PROMPT_NAME # Only inject for extract-actions
and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str)
):
@@ -809,7 +815,7 @@ class LLMAPIHandlerFactory:
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
"text": context.cached_static_prompt,
}
],
}