fix: prevent Vertex cache contamination across different prompt templates (#4183)

This commit is contained in:
pedrohsdb
2025-12-03 11:13:27 -08:00
committed by GitHub
parent 0ed512e4b8
commit ce01f2cb35
2 changed files with 35 additions and 18 deletions

View File

@@ -143,6 +143,7 @@ class SpeculativePlan:
use_caching: bool use_caching: bool
llm_json_response: dict[str, Any] | None llm_json_response: dict[str, Any] | None
llm_metadata: SpeculativeLLMMetadata | None = None llm_metadata: SpeculativeLLMMetadata | None = None
prompt_name: str = "extract-actions"
class ActionLinkedNode: class ActionLinkedNode:
@@ -945,11 +946,13 @@ class ForgeAgent:
json_response = speculative_plan.llm_json_response json_response = speculative_plan.llm_json_response
reuse_speculative_llm_response = json_response is not None reuse_speculative_llm_response = json_response is not None
speculative_llm_metadata = speculative_plan.llm_metadata speculative_llm_metadata = speculative_plan.llm_metadata
prompt_name = speculative_plan.prompt_name
else: else:
( (
scraped_page, scraped_page,
extract_action_prompt, extract_action_prompt,
use_caching, use_caching,
prompt_name,
) = await self.build_and_record_step_prompt( ) = await self.build_and_record_step_prompt(
task, task,
step, step,
@@ -1014,7 +1017,7 @@ class ForgeAgent:
if not reuse_speculative_llm_response: if not reuse_speculative_llm_response:
json_response = await llm_api_handler( json_response = await llm_api_handler(
prompt=extract_action_prompt, prompt=extract_action_prompt,
prompt_name="extract-actions", prompt_name=prompt_name,
step=step, step=step,
screenshots=scraped_page.screenshots, screenshots=scraped_page.screenshots,
) )
@@ -1783,7 +1786,7 @@ class ForgeAgent:
try: try:
next_step.is_speculative = True next_step.is_speculative = True
scraped_page, extract_action_prompt, use_caching = await self.build_and_record_step_prompt( scraped_page, extract_action_prompt, use_caching, prompt_name = await self.build_and_record_step_prompt(
task, task,
next_step, next_step,
browser_state, browser_state,
@@ -1798,7 +1801,7 @@ class ForgeAgent:
llm_json_response = await llm_api_handler( llm_json_response = await llm_api_handler(
prompt=extract_action_prompt, prompt=extract_action_prompt,
prompt_name="extract-actions", prompt_name=prompt_name,
step=next_step, step=next_step,
screenshots=scraped_page.screenshots, screenshots=scraped_page.screenshots,
) )
@@ -1821,6 +1824,7 @@ class ForgeAgent:
use_caching=use_caching, use_caching=use_caching,
llm_json_response=llm_json_response, llm_json_response=llm_json_response,
llm_metadata=metadata_copy, llm_metadata=metadata_copy,
prompt_name=prompt_name,
) )
except Exception: except Exception:
LOG.warning( LOG.warning(
@@ -2288,7 +2292,7 @@ class ForgeAgent:
engine: RunEngine, engine: RunEngine,
*, *,
persist_artifacts: bool = True, persist_artifacts: bool = True,
) -> tuple[ScrapedPage, str, bool]: ) -> tuple[ScrapedPage, str, bool, str]:
# Check if we have pre-scraped data from parallel verification optimization # Check if we have pre-scraped data from parallel verification optimization
context = skyvern_context.current() context = skyvern_context.current()
scraped_page: ScrapedPage | None = None scraped_page: ScrapedPage | None = None
@@ -2429,8 +2433,9 @@ class ForgeAgent:
workflow_run_id=task.workflow_run_id, workflow_run_id=task.workflow_run_id,
) )
extract_action_prompt = "" extract_action_prompt = ""
prompt_name = EXTRACT_ACTION_PROMPT_NAME # Default; overwritten below for non-CUA engines
if engine not in CUA_ENGINES: if engine not in CUA_ENGINES:
extract_action_prompt, use_caching = await self._build_extract_action_prompt( extract_action_prompt, use_caching, prompt_name = await self._build_extract_action_prompt(
task, task,
step, step,
browser_state, browser_state,
@@ -2466,7 +2471,7 @@ class ForgeAgent:
data=element_tree_in_prompt.encode(), data=element_tree_in_prompt.encode(),
) )
return scraped_page, extract_action_prompt, use_caching return scraped_page, extract_action_prompt, use_caching, prompt_name
@staticmethod @staticmethod
def _build_extract_action_cache_variant( def _build_extract_action_cache_variant(
@@ -2625,7 +2630,7 @@ class ForgeAgent:
scraped_page: ScrapedPage, scraped_page: ScrapedPage,
verification_code_check: bool = False, verification_code_check: bool = False,
expire_verification_code: bool = False, expire_verification_code: bool = False,
) -> tuple[str, bool]: ) -> tuple[str, bool, str]:
actions_and_results_str = await self._get_action_results(task) actions_and_results_str = await self._get_action_results(task)
# Generate the extract action prompt # Generate the extract action prompt
@@ -2682,8 +2687,10 @@ class ForgeAgent:
context = skyvern_context.ensure_context() context = skyvern_context.ensure_context()
# Reset cached prompt by default; we will set it below if caching is enabled. # Reset cached prompt and cache reference by default; we will set them below if caching is enabled.
# This prevents extract-action cache from being attached to other prompts like decisive-criterion-validate.
context.cached_static_prompt = None context.cached_static_prompt = None
context.vertex_cache_name = None
# Check if prompt caching is enabled for extract-action # Check if prompt caching is enabled for extract-action
use_caching = False use_caching = False
@@ -2773,7 +2780,9 @@ class ForgeAgent:
prompt_name=EXTRACT_ACTION_PROMPT_NAME, prompt_name=EXTRACT_ACTION_PROMPT_NAME,
cache_variant=cache_variant, cache_variant=cache_variant,
) )
return combined_prompt, use_caching # Map template to prompt_name for logging/caching guards
prompt_name = EXTRACT_ACTION_PROMPT_NAME if template == EXTRACT_ACTION_TEMPLATE else template
return combined_prompt, use_caching, prompt_name
except Exception as e: except Exception as e:
LOG.warning("Failed to load cached prompt templates, falling back to original", error=str(e)) LOG.warning("Failed to load cached prompt templates, falling back to original", error=str(e))
@@ -2799,7 +2808,9 @@ class ForgeAgent:
has_magic_link_page=context.has_magic_link_page(task.task_id), has_magic_link_page=context.has_magic_link_page(task.task_id),
) )
return full_prompt, use_caching # Map template to prompt_name for logging/caching guards
prompt_name = EXTRACT_ACTION_PROMPT_NAME if template == EXTRACT_ACTION_TEMPLATE else template
return full_prompt, use_caching, prompt_name
async def _get_prompt_caching_settings(self, context: SkyvernContext) -> dict[str, bool]: async def _get_prompt_caching_settings(self, context: SkyvernContext) -> dict[str, bool]:
""" """
@@ -4347,7 +4358,7 @@ class ForgeAgent:
current_context = skyvern_context.ensure_context() current_context = skyvern_context.ensure_context()
current_context.totp_codes[task.task_id] = otp_value.value current_context.totp_codes[task.task_id] = otp_value.value
extract_action_prompt, use_caching = await self._build_extract_action_prompt( extract_action_prompt, use_caching, prompt_name = await self._build_extract_action_prompt(
task, task,
step, step,
browser_state, browser_state,
@@ -4370,7 +4381,7 @@ class ForgeAgent:
prompt=extract_action_prompt, prompt=extract_action_prompt,
step=step, step=step,
screenshots=scraped_page.screenshots, screenshots=scraped_page.screenshots,
prompt_name="extract-actions", prompt_name=prompt_name,
) )
return json_response return json_response

View File

@@ -406,10 +406,13 @@ class LLMAPIHandlerFactory:
return llm_request_json return llm_request_json
# Inject context caching system message when available # Inject context caching system message when available
# IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts
# (e.g., check-user-goal) with the extract-action schema
try: try:
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
if ( if (
context_cached_static_prompt context
and context.cached_static_prompt
and prompt_name == EXTRACT_ACTION_PROMPT_NAME # Only inject for extract-actions
and isinstance(llm_config, LLMConfig) and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str) and isinstance(llm_config.model_name, str)
): ):
@@ -426,7 +429,7 @@ class LLMAPIHandlerFactory:
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": context_cached_static_prompt, "text": context.cached_static_prompt,
} }
], ],
} }
@@ -789,10 +792,13 @@ class LLMAPIHandlerFactory:
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix) messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
# Inject context caching system message when available # Inject context caching system message when available
# IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts
# (e.g., check-user-goal) with the extract-action schema
try: try:
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
if ( if (
context_cached_static_prompt context
and context.cached_static_prompt
and prompt_name == EXTRACT_ACTION_PROMPT_NAME # Only inject for extract-actions
and isinstance(llm_config, LLMConfig) and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str) and isinstance(llm_config.model_name, str)
): ):
@@ -809,7 +815,7 @@ class LLMAPIHandlerFactory:
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": context_cached_static_prompt, "text": context.cached_static_prompt,
} }
], ],
} }