fix: prevent Vertex cache contamination across different prompt templates (#4183)
This commit is contained in:
@@ -143,6 +143,7 @@ class SpeculativePlan:
|
|||||||
use_caching: bool
|
use_caching: bool
|
||||||
llm_json_response: dict[str, Any] | None
|
llm_json_response: dict[str, Any] | None
|
||||||
llm_metadata: SpeculativeLLMMetadata | None = None
|
llm_metadata: SpeculativeLLMMetadata | None = None
|
||||||
|
prompt_name: str = "extract-actions"
|
||||||
|
|
||||||
|
|
||||||
class ActionLinkedNode:
|
class ActionLinkedNode:
|
||||||
@@ -945,11 +946,13 @@ class ForgeAgent:
|
|||||||
json_response = speculative_plan.llm_json_response
|
json_response = speculative_plan.llm_json_response
|
||||||
reuse_speculative_llm_response = json_response is not None
|
reuse_speculative_llm_response = json_response is not None
|
||||||
speculative_llm_metadata = speculative_plan.llm_metadata
|
speculative_llm_metadata = speculative_plan.llm_metadata
|
||||||
|
prompt_name = speculative_plan.prompt_name
|
||||||
else:
|
else:
|
||||||
(
|
(
|
||||||
scraped_page,
|
scraped_page,
|
||||||
extract_action_prompt,
|
extract_action_prompt,
|
||||||
use_caching,
|
use_caching,
|
||||||
|
prompt_name,
|
||||||
) = await self.build_and_record_step_prompt(
|
) = await self.build_and_record_step_prompt(
|
||||||
task,
|
task,
|
||||||
step,
|
step,
|
||||||
@@ -1014,7 +1017,7 @@ class ForgeAgent:
|
|||||||
if not reuse_speculative_llm_response:
|
if not reuse_speculative_llm_response:
|
||||||
json_response = await llm_api_handler(
|
json_response = await llm_api_handler(
|
||||||
prompt=extract_action_prompt,
|
prompt=extract_action_prompt,
|
||||||
prompt_name="extract-actions",
|
prompt_name=prompt_name,
|
||||||
step=step,
|
step=step,
|
||||||
screenshots=scraped_page.screenshots,
|
screenshots=scraped_page.screenshots,
|
||||||
)
|
)
|
||||||
@@ -1783,7 +1786,7 @@ class ForgeAgent:
|
|||||||
try:
|
try:
|
||||||
next_step.is_speculative = True
|
next_step.is_speculative = True
|
||||||
|
|
||||||
scraped_page, extract_action_prompt, use_caching = await self.build_and_record_step_prompt(
|
scraped_page, extract_action_prompt, use_caching, prompt_name = await self.build_and_record_step_prompt(
|
||||||
task,
|
task,
|
||||||
next_step,
|
next_step,
|
||||||
browser_state,
|
browser_state,
|
||||||
@@ -1798,7 +1801,7 @@ class ForgeAgent:
|
|||||||
|
|
||||||
llm_json_response = await llm_api_handler(
|
llm_json_response = await llm_api_handler(
|
||||||
prompt=extract_action_prompt,
|
prompt=extract_action_prompt,
|
||||||
prompt_name="extract-actions",
|
prompt_name=prompt_name,
|
||||||
step=next_step,
|
step=next_step,
|
||||||
screenshots=scraped_page.screenshots,
|
screenshots=scraped_page.screenshots,
|
||||||
)
|
)
|
||||||
@@ -1821,6 +1824,7 @@ class ForgeAgent:
|
|||||||
use_caching=use_caching,
|
use_caching=use_caching,
|
||||||
llm_json_response=llm_json_response,
|
llm_json_response=llm_json_response,
|
||||||
llm_metadata=metadata_copy,
|
llm_metadata=metadata_copy,
|
||||||
|
prompt_name=prompt_name,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
@@ -2288,7 +2292,7 @@ class ForgeAgent:
|
|||||||
engine: RunEngine,
|
engine: RunEngine,
|
||||||
*,
|
*,
|
||||||
persist_artifacts: bool = True,
|
persist_artifacts: bool = True,
|
||||||
) -> tuple[ScrapedPage, str, bool]:
|
) -> tuple[ScrapedPage, str, bool, str]:
|
||||||
# Check if we have pre-scraped data from parallel verification optimization
|
# Check if we have pre-scraped data from parallel verification optimization
|
||||||
context = skyvern_context.current()
|
context = skyvern_context.current()
|
||||||
scraped_page: ScrapedPage | None = None
|
scraped_page: ScrapedPage | None = None
|
||||||
@@ -2429,8 +2433,9 @@ class ForgeAgent:
|
|||||||
workflow_run_id=task.workflow_run_id,
|
workflow_run_id=task.workflow_run_id,
|
||||||
)
|
)
|
||||||
extract_action_prompt = ""
|
extract_action_prompt = ""
|
||||||
|
prompt_name = EXTRACT_ACTION_PROMPT_NAME # Default; overwritten below for non-CUA engines
|
||||||
if engine not in CUA_ENGINES:
|
if engine not in CUA_ENGINES:
|
||||||
extract_action_prompt, use_caching = await self._build_extract_action_prompt(
|
extract_action_prompt, use_caching, prompt_name = await self._build_extract_action_prompt(
|
||||||
task,
|
task,
|
||||||
step,
|
step,
|
||||||
browser_state,
|
browser_state,
|
||||||
@@ -2466,7 +2471,7 @@ class ForgeAgent:
|
|||||||
data=element_tree_in_prompt.encode(),
|
data=element_tree_in_prompt.encode(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return scraped_page, extract_action_prompt, use_caching
|
return scraped_page, extract_action_prompt, use_caching, prompt_name
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_extract_action_cache_variant(
|
def _build_extract_action_cache_variant(
|
||||||
@@ -2625,7 +2630,7 @@ class ForgeAgent:
|
|||||||
scraped_page: ScrapedPage,
|
scraped_page: ScrapedPage,
|
||||||
verification_code_check: bool = False,
|
verification_code_check: bool = False,
|
||||||
expire_verification_code: bool = False,
|
expire_verification_code: bool = False,
|
||||||
) -> tuple[str, bool]:
|
) -> tuple[str, bool, str]:
|
||||||
actions_and_results_str = await self._get_action_results(task)
|
actions_and_results_str = await self._get_action_results(task)
|
||||||
|
|
||||||
# Generate the extract action prompt
|
# Generate the extract action prompt
|
||||||
@@ -2682,8 +2687,10 @@ class ForgeAgent:
|
|||||||
|
|
||||||
context = skyvern_context.ensure_context()
|
context = skyvern_context.ensure_context()
|
||||||
|
|
||||||
# Reset cached prompt by default; we will set it below if caching is enabled.
|
# Reset cached prompt and cache reference by default; we will set them below if caching is enabled.
|
||||||
|
# This prevents extract-action cache from being attached to other prompts like decisive-criterion-validate.
|
||||||
context.cached_static_prompt = None
|
context.cached_static_prompt = None
|
||||||
|
context.vertex_cache_name = None
|
||||||
|
|
||||||
# Check if prompt caching is enabled for extract-action
|
# Check if prompt caching is enabled for extract-action
|
||||||
use_caching = False
|
use_caching = False
|
||||||
@@ -2773,7 +2780,9 @@ class ForgeAgent:
|
|||||||
prompt_name=EXTRACT_ACTION_PROMPT_NAME,
|
prompt_name=EXTRACT_ACTION_PROMPT_NAME,
|
||||||
cache_variant=cache_variant,
|
cache_variant=cache_variant,
|
||||||
)
|
)
|
||||||
return combined_prompt, use_caching
|
# Map template to prompt_name for logging/caching guards
|
||||||
|
prompt_name = EXTRACT_ACTION_PROMPT_NAME if template == EXTRACT_ACTION_TEMPLATE else template
|
||||||
|
return combined_prompt, use_caching, prompt_name
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.warning("Failed to load cached prompt templates, falling back to original", error=str(e))
|
LOG.warning("Failed to load cached prompt templates, falling back to original", error=str(e))
|
||||||
@@ -2799,7 +2808,9 @@ class ForgeAgent:
|
|||||||
has_magic_link_page=context.has_magic_link_page(task.task_id),
|
has_magic_link_page=context.has_magic_link_page(task.task_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
return full_prompt, use_caching
|
# Map template to prompt_name for logging/caching guards
|
||||||
|
prompt_name = EXTRACT_ACTION_PROMPT_NAME if template == EXTRACT_ACTION_TEMPLATE else template
|
||||||
|
return full_prompt, use_caching, prompt_name
|
||||||
|
|
||||||
async def _get_prompt_caching_settings(self, context: SkyvernContext) -> dict[str, bool]:
|
async def _get_prompt_caching_settings(self, context: SkyvernContext) -> dict[str, bool]:
|
||||||
"""
|
"""
|
||||||
@@ -4347,7 +4358,7 @@ class ForgeAgent:
|
|||||||
current_context = skyvern_context.ensure_context()
|
current_context = skyvern_context.ensure_context()
|
||||||
current_context.totp_codes[task.task_id] = otp_value.value
|
current_context.totp_codes[task.task_id] = otp_value.value
|
||||||
|
|
||||||
extract_action_prompt, use_caching = await self._build_extract_action_prompt(
|
extract_action_prompt, use_caching, prompt_name = await self._build_extract_action_prompt(
|
||||||
task,
|
task,
|
||||||
step,
|
step,
|
||||||
browser_state,
|
browser_state,
|
||||||
@@ -4370,7 +4381,7 @@ class ForgeAgent:
|
|||||||
prompt=extract_action_prompt,
|
prompt=extract_action_prompt,
|
||||||
step=step,
|
step=step,
|
||||||
screenshots=scraped_page.screenshots,
|
screenshots=scraped_page.screenshots,
|
||||||
prompt_name="extract-actions",
|
prompt_name=prompt_name,
|
||||||
)
|
)
|
||||||
return json_response
|
return json_response
|
||||||
|
|
||||||
|
|||||||
@@ -406,10 +406,13 @@ class LLMAPIHandlerFactory:
|
|||||||
return llm_request_json
|
return llm_request_json
|
||||||
|
|
||||||
# Inject context caching system message when available
|
# Inject context caching system message when available
|
||||||
|
# IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts
|
||||||
|
# (e.g., check-user-goal) with the extract-action schema
|
||||||
try:
|
try:
|
||||||
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
|
|
||||||
if (
|
if (
|
||||||
context_cached_static_prompt
|
context
|
||||||
|
and context.cached_static_prompt
|
||||||
|
and prompt_name == EXTRACT_ACTION_PROMPT_NAME # Only inject for extract-actions
|
||||||
and isinstance(llm_config, LLMConfig)
|
and isinstance(llm_config, LLMConfig)
|
||||||
and isinstance(llm_config.model_name, str)
|
and isinstance(llm_config.model_name, str)
|
||||||
):
|
):
|
||||||
@@ -426,7 +429,7 @@ class LLMAPIHandlerFactory:
|
|||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": context_cached_static_prompt,
|
"text": context.cached_static_prompt,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@@ -789,10 +792,13 @@ class LLMAPIHandlerFactory:
|
|||||||
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
||||||
|
|
||||||
# Inject context caching system message when available
|
# Inject context caching system message when available
|
||||||
|
# IMPORTANT: Only inject for extract-actions prompt to avoid contaminating other prompts
|
||||||
|
# (e.g., check-user-goal) with the extract-action schema
|
||||||
try:
|
try:
|
||||||
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
|
|
||||||
if (
|
if (
|
||||||
context_cached_static_prompt
|
context
|
||||||
|
and context.cached_static_prompt
|
||||||
|
and prompt_name == EXTRACT_ACTION_PROMPT_NAME # Only inject for extract-actions
|
||||||
and isinstance(llm_config, LLMConfig)
|
and isinstance(llm_config, LLMConfig)
|
||||||
and isinstance(llm_config.model_name, str)
|
and isinstance(llm_config.model_name, str)
|
||||||
):
|
):
|
||||||
@@ -809,7 +815,7 @@ class LLMAPIHandlerFactory:
|
|||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": context_cached_static_prompt,
|
"text": context.cached_static_prompt,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user