Pedro/prompt caching (#3531)

This commit is contained in:
pedrohsdb
2025-09-25 15:04:54 -07:00
committed by GitHub
parent a1c94ec4b4
commit dd9d4fb3a9
5 changed files with 326 additions and 32 deletions

View File

@@ -929,6 +929,7 @@ class ForgeAgent:
(
scraped_page,
extract_action_prompt,
use_caching,
) = await self.build_and_record_step_prompt(
task,
step,
@@ -990,6 +991,12 @@ class ForgeAgent:
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
llm_key_override, default=app.LLM_API_HANDLER
)
# Add caching flag to context for monitoring
if use_caching:
context = skyvern_context.current()
if context:
context.use_prompt_caching = True
json_response = await llm_api_handler(
prompt=extract_action_prompt,
prompt_name="extract-actions",
@@ -1884,7 +1891,7 @@ class ForgeAgent:
browser_state: BrowserState,
scrape_type: ScrapeType,
engine: RunEngine,
) -> ScrapedPage:
) -> tuple[ScrapedPage, str, bool]:
if scrape_type == ScrapeType.NORMAL:
pass
@@ -1926,7 +1933,7 @@ class ForgeAgent:
step: Step,
browser_state: BrowserState,
engine: RunEngine,
) -> tuple[ScrapedPage, str]:
) -> tuple[ScrapedPage, str, bool]:
# start the async tasks while running scrape_website
if engine not in CUA_ENGINES:
self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
@@ -1937,9 +1944,11 @@ class ForgeAgent:
# second time: try again the normal scrape, (stopping window loading before scraping barely helps, but causing problem)
# third time: reload the page before scraping
scraped_page: ScrapedPage | None = None
extract_action_prompt = ""
use_caching = False
for idx, scrape_type in enumerate(SCRAPE_TYPE_ORDER):
try:
scraped_page = await self._scrape_with_type(
scraped_page, extract_action_prompt, use_caching = await self._scrape_with_type(
task=task,
step=step,
browser_state=browser_state,
@@ -1980,7 +1989,7 @@ class ForgeAgent:
element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format)
extract_action_prompt = ""
if engine not in CUA_ENGINES:
extract_action_prompt = await self._build_extract_action_prompt(
extract_action_prompt, use_caching = await self._build_extract_action_prompt(
task,
step,
browser_state,
@@ -2015,7 +2024,7 @@ class ForgeAgent:
data=element_tree_in_prompt.encode(),
)
return scraped_page, extract_action_prompt
return scraped_page, extract_action_prompt, use_caching
async def _build_extract_action_prompt(
self,
@@ -2025,7 +2034,7 @@ class ForgeAgent:
scraped_page: ScrapedPage,
verification_code_check: bool = False,
expire_verification_code: bool = False,
) -> str:
) -> tuple[str, bool]:
actions_and_results_str = await self._get_action_results(task)
# Generate the extract action prompt
@@ -2081,7 +2090,44 @@ class ForgeAgent:
context = skyvern_context.ensure_context()
return load_prompt_with_elements(
# Check if prompt caching is enabled for extract-action
use_caching = False
if (
template == "extract-action"
and LLMAPIHandlerFactory._prompt_caching_settings
and LLMAPIHandlerFactory._prompt_caching_settings.get("extract-action", False)
):
try:
# Try to load split templates for caching
static_prompt = prompt_engine.load_prompt(f"{template}-static")
dynamic_prompt = prompt_engine.load_prompt(
f"{template}-dynamic",
navigation_goal=navigation_goal,
navigation_payload_str=json.dumps(final_navigation_payload),
starting_url=starting_url,
current_url=current_url,
data_extraction_goal=task.data_extraction_goal,
action_history=actions_and_results_str,
error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
local_datetime=datetime.now(context.tz_info).isoformat(),
verification_code_check=verification_code_check,
complete_criterion=task.complete_criterion.strip() if task.complete_criterion else None,
terminate_criterion=task.terminate_criterion.strip() if task.terminate_criterion else None,
parse_select_feature_enabled=context.enable_parse_select_in_extract,
)
# Store static prompt for caching and return dynamic prompt
context.cached_static_prompt = static_prompt
use_caching = True
LOG.info("Using cached prompt for extract-action", task_id=task.task_id)
return dynamic_prompt, use_caching
except Exception as e:
LOG.warning("Failed to load cached prompt templates, falling back to original", error=str(e))
# Fall through to original behavior
# Original behavior - load full prompt
full_prompt = load_prompt_with_elements(
element_tree_builder=scraped_page,
prompt_engine=prompt_engine,
template_name=template,
@@ -2099,6 +2145,8 @@ class ForgeAgent:
parse_select_feature_enabled=context.enable_parse_select_in_extract,
)
return full_prompt, use_caching
def _build_navigation_payload(
self,
task: Task,
@@ -2987,7 +3035,7 @@ class ForgeAgent:
current_context = skyvern_context.ensure_context()
current_context.totp_codes[task.task_id] = verification_code
extract_action_prompt = await self._build_extract_action_prompt(
extract_action_prompt, use_caching = await self._build_extract_action_prompt(
task,
step,
browser_state,
@@ -3000,6 +3048,12 @@ class ForgeAgent:
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
llm_key_override, default=app.LLM_API_HANDLER
)
# Add caching flag to context for monitoring
if use_caching:
context = skyvern_context.current()
if context:
context.use_prompt_caching = True
return await llm_api_handler(
prompt=extract_action_prompt,
step=step,