fix: prevent Vertex cache contamination across different prompt templates (#4183)
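
The diff below threads the prompt name through the planning path and resets the cached Vertex prompt state by default, so the cached extract-action system message is only attached to extract-action requests. A minimal sketch of that guard pattern, assuming a simplified context and handler: only the names cached_static_prompt, vertex_cache_name, EXTRACT_ACTION_PROMPT_NAME and the "extract-actions" string come from the diff; the constant value of EXTRACT_ACTION_TEMPLATE and all surrounding scaffolding are illustrative.

    # Minimal sketch of the cache-contamination guard (illustrative, not the real Skyvern code).
    from dataclasses import dataclass

    EXTRACT_ACTION_PROMPT_NAME = "extract-actions"
    EXTRACT_ACTION_TEMPLATE = "extract-action"  # assumed value; the diff only references the constant

    @dataclass
    class SkyvernContext:
        cached_static_prompt: str | None = None
        vertex_cache_name: str | None = None

    def build_extract_action_prompt(
        context: SkyvernContext, template: str, use_caching: bool
    ) -> tuple[str, bool, str]:
        # Reset cache state up front so a previous extract-action cache can never
        # leak into an unrelated prompt (e.g. decisive-criterion-validate).
        context.cached_static_prompt = None
        context.vertex_cache_name = None
        if use_caching:
            context.cached_static_prompt = "...static extract-action instructions..."
        # Map the template to a prompt name so downstream guards know what was built.
        prompt_name = EXTRACT_ACTION_PROMPT_NAME if template == EXTRACT_ACTION_TEMPLATE else template
        return "...dynamic prompt...", use_caching, prompt_name

    def inject_cached_system_message(
        messages: list[dict], context: SkyvernContext, prompt_name: str
    ) -> list[dict]:
        # Only extract-action requests may reuse the cached static prompt.
        if context.cached_static_prompt and prompt_name == EXTRACT_ACTION_PROMPT_NAME:
            cached = {"role": "system", "content": [{"type": "text", "text": context.cached_static_prompt}]}
            return [cached] + messages
        return messages
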
@@ -143,6 +143,7 @@ class SpeculativePlan:
     use_caching: bool
     llm_json_response: dict[str, Any] | None
     llm_metadata: SpeculativeLLMMetadata | None = None
+    prompt_name: str = "extract-actions"


 class ActionLinkedNode:
@@ -945,11 +946,13 @@ class ForgeAgent:
             json_response = speculative_plan.llm_json_response
             reuse_speculative_llm_response = json_response is not None
             speculative_llm_metadata = speculative_plan.llm_metadata
+            prompt_name = speculative_plan.prompt_name
         else:
             (
                 scraped_page,
                 extract_action_prompt,
                 use_caching,
+                prompt_name,
             ) = await self.build_and_record_step_prompt(
                 task,
                 step,
@@ -1014,7 +1017,7 @@ class ForgeAgent:
         if not reuse_speculative_llm_response:
             json_response = await llm_api_handler(
                 prompt=extract_action_prompt,
-                prompt_name="extract-actions",
+                prompt_name=prompt_name,
                 step=step,
                 screenshots=scraped_page.screenshots,
             )
@@ -1783,7 +1786,7 @@ class ForgeAgent:
         try:
             next_step.is_speculative = True

-            scraped_page, extract_action_prompt, use_caching = await self.build_and_record_step_prompt(
+            scraped_page, extract_action_prompt, use_caching, prompt_name = await self.build_and_record_step_prompt(
                 task,
                 next_step,
                 browser_state,
@@ -1798,7 +1801,7 @@ class ForgeAgent:

             llm_json_response = await llm_api_handler(
                 prompt=extract_action_prompt,
-                prompt_name="extract-actions",
+                prompt_name=prompt_name,
                 step=next_step,
                 screenshots=scraped_page.screenshots,
             )
@@ -1821,6 +1824,7 @@ class ForgeAgent:
                 use_caching=use_caching,
                 llm_json_response=llm_json_response,
                 llm_metadata=metadata_copy,
+                prompt_name=prompt_name,
             )
         except Exception:
             LOG.warning(
@@ -2288,7 +2292,7 @@ class ForgeAgent:
         engine: RunEngine,
         *,
         persist_artifacts: bool = True,
-    ) -> tuple[ScrapedPage, str, bool]:
+    ) -> tuple[ScrapedPage, str, bool, str]:
         # Check if we have pre-scraped data from parallel verification optimization
         context = skyvern_context.current()
         scraped_page: ScrapedPage | None = None
@@ -2429,8 +2433,9 @@ class ForgeAgent:
             workflow_run_id=task.workflow_run_id,
         )
         extract_action_prompt = ""
+        prompt_name = EXTRACT_ACTION_PROMPT_NAME  # Default; overwritten below for non-CUA engines
         if engine not in CUA_ENGINES:
-            extract_action_prompt, use_caching = await self._build_extract_action_prompt(
+            extract_action_prompt, use_caching, prompt_name = await self._build_extract_action_prompt(
                 task,
                 step,
                 browser_state,
@@ -2466,7 +2471,7 @@ class ForgeAgent:
                 data=element_tree_in_prompt.encode(),
             )

-        return scraped_page, extract_action_prompt, use_caching
+        return scraped_page, extract_action_prompt, use_caching, prompt_name

     @staticmethod
     def _build_extract_action_cache_variant(
@@ -2625,7 +2630,7 @@ class ForgeAgent:
         scraped_page: ScrapedPage,
         verification_code_check: bool = False,
         expire_verification_code: bool = False,
-    ) -> tuple[str, bool]:
+    ) -> tuple[str, bool, str]:
         actions_and_results_str = await self._get_action_results(task)

         # Generate the extract action prompt
@@ -2682,8 +2687,10 @@ class ForgeAgent:

         context = skyvern_context.ensure_context()

-        # Reset cached prompt by default; we will set it below if caching is enabled.
+        # Reset cached prompt and cache reference by default; we will set them below if caching is enabled.
+        # This prevents extract-action cache from being attached to other prompts like decisive-criterion-validate.
         context.cached_static_prompt = None
+        context.vertex_cache_name = None

         # Check if prompt caching is enabled for extract-action
         use_caching = False
@@ -2773,7 +2780,9 @@ class ForgeAgent:
                    prompt_name=EXTRACT_ACTION_PROMPT_NAME,
                    cache_variant=cache_variant,
                )
-                return combined_prompt, use_caching
+                # Map template to prompt_name for logging/caching guards
+                prompt_name = EXTRACT_ACTION_PROMPT_NAME if template == EXTRACT_ACTION_TEMPLATE else template
+                return combined_prompt, use_caching, prompt_name

            except Exception as e:
                LOG.warning("Failed to load cached prompt templates, falling back to original", error=str(e))
@@ -2799,7 +2808,9 @@ class ForgeAgent:
            has_magic_link_page=context.has_magic_link_page(task.task_id),
        )

-        return full_prompt, use_caching
+        # Map template to prompt_name for logging/caching guards
+        prompt_name = EXTRACT_ACTION_PROMPT_NAME if template == EXTRACT_ACTION_TEMPLATE else template
+        return full_prompt, use_caching, prompt_name

    async def _get_prompt_caching_settings(self, context: SkyvernContext) -> dict[str, bool]:
        """
@@ -4347,7 +4358,7 @@ class ForgeAgent:
        current_context = skyvern_context.ensure_context()
        current_context.totp_codes[task.task_id] = otp_value.value

-        extract_action_prompt, use_caching = await self._build_extract_action_prompt(
+        extract_action_prompt, use_caching, prompt_name = await self._build_extract_action_prompt(
            task,
            step,
            browser_state,
@@ -4370,7 +4381,7 @@ class ForgeAgent:
            prompt=extract_action_prompt,
            step=step,
            screenshots=scraped_page.screenshots,
-            prompt_name="extract-actions",
+            prompt_name=prompt_name,
        )
        return json_response

@@ -406,10 +406,13 @@ class LLMAPIHandlerFactory:
                return llm_request_json

            # Inject context caching system message when available
+            # IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts
+            # (e.g., check-user-goal) with the extract-action schema
            try:
-                context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
                if (
-                    context_cached_static_prompt
+                    context
+                    and context.cached_static_prompt
+                    and prompt_name == EXTRACT_ACTION_PROMPT_NAME  # Only inject for extract-actions
                    and isinstance(llm_config, LLMConfig)
                    and isinstance(llm_config.model_name, str)
                ):
@@ -426,7 +429,7 @@ class LLMAPIHandlerFactory:
                        "content": [
                            {
                                "type": "text",
-                                "text": context_cached_static_prompt,
+                                "text": context.cached_static_prompt,
                            }
                        ],
                    }
@@ -789,10 +792,13 @@ class LLMAPIHandlerFactory:
            messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)

            # Inject context caching system message when available
+            # IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts
+            # (e.g., check-user-goal) with the extract-action schema
            try:
-                context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
                if (
-                    context_cached_static_prompt
+                    context
+                    and context.cached_static_prompt
+                    and prompt_name == EXTRACT_ACTION_PROMPT_NAME  # Only inject for extract-actions
                    and isinstance(llm_config, LLMConfig)
                    and isinstance(llm_config.model_name, str)
                ):
@@ -809,7 +815,7 @@ class LLMAPIHandlerFactory:
                        "content": [
                            {
                                "type": "text",
-                                "text": context_cached_static_prompt,
+                                "text": context.cached_static_prompt,
                            }
                        ],
                    }
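
With the change in place, callers receive the prompt name from build_and_record_step_prompt and forward it to the LLM handler instead of hard-coding "extract-actions". A rough illustration of that caller-side threading, assuming heavily trimmed signatures; only the names build_and_record_step_prompt, llm_api_handler, and prompt_name mirror the diff, the argument lists are simplified:

    # Rough illustration of the caller-side threading (simplified signatures, not the real call sites).
    async def plan_step(agent, task, step, llm_api_handler) -> dict:
        scraped_page, extract_action_prompt, use_caching, prompt_name = await agent.build_and_record_step_prompt(
            task,
            step,
        )
        # The handler receives whichever prompt name was actually built, so its
        # cache-injection guard only fires for genuine extract-action prompts.
        return await llm_api_handler(
            prompt=extract_action_prompt,
            prompt_name=prompt_name,
            step=step,
            screenshots=scraped_page.screenshots,
        )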