Pedro/prompt caching (#3531)
@@ -929,6 +929,7 @@ class ForgeAgent:
         (
             scraped_page,
             extract_action_prompt,
+            use_caching,
         ) = await self.build_and_record_step_prompt(
             task,
             step,
@@ -990,6 +991,12 @@ class ForgeAgent:
         llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
             llm_key_override, default=app.LLM_API_HANDLER
         )
+        # Add caching flag to context for monitoring
+        if use_caching:
+            context = skyvern_context.current()
+            if context:
+                context.use_prompt_caching = True
+
         json_response = await llm_api_handler(
             prompt=extract_action_prompt,
             prompt_name="extract-actions",
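
Note on the caching flag above: the hunk relies on a per-run context object exposed by skyvern_context, whose implementation is not part of this change. A minimal sketch of the pattern, assuming a contextvars-backed holder with the two fields this commit touches (names and defaults are illustrative, not the actual Skyvern module):

import contextvars
from dataclasses import dataclass

@dataclass
class SkyvernContext:
    use_prompt_caching: bool = False         # read by monitoring / the LLM handler
    cached_static_prompt: str | None = None  # static prompt half, set in a later hunk

_current: contextvars.ContextVar[SkyvernContext | None] = contextvars.ContextVar(
    "skyvern_context", default=None
)

def current() -> SkyvernContext | None:
    # Mirrors skyvern_context.current(): may return None outside a run
    return _current.get()

def ensure_context() -> SkyvernContext:
    # Mirrors skyvern_context.ensure_context(): creates the context on first use
    ctx = _current.get()
    if ctx is None:
        ctx = SkyvernContext()
        _current.set(ctx)
    return ctx

This would also explain why the hunk guards with "if context:" rather than calling ensure_context(): current() may legitimately return None when no run is active, and setting the monitoring flag is best-effort.
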
@@ -1884,7 +1891,7 @@ class ForgeAgent:
         browser_state: BrowserState,
         scrape_type: ScrapeType,
         engine: RunEngine,
-    ) -> ScrapedPage:
+    ) -> tuple[ScrapedPage, str, bool]:
         if scrape_type == ScrapeType.NORMAL:
             pass
 
@@ -1926,7 +1933,7 @@ class ForgeAgent:
         step: Step,
         browser_state: BrowserState,
         engine: RunEngine,
-    ) -> tuple[ScrapedPage, str]:
+    ) -> tuple[ScrapedPage, str, bool]:
         # start the async tasks while running scrape_website
         if engine not in CUA_ENGINES:
             self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
@@ -1937,9 +1944,11 @@ class ForgeAgent:
         # second time: try again the normal scrape, (stopping window loading before scraping barely helps, but causing problem)
         # third time: reload the page before scraping
         scraped_page: ScrapedPage | None = None
+        extract_action_prompt = ""
+        use_caching = False
         for idx, scrape_type in enumerate(SCRAPE_TYPE_ORDER):
             try:
-                scraped_page = await self._scrape_with_type(
+                scraped_page, extract_action_prompt, use_caching = await self._scrape_with_type(
                     task=task,
                     step=step,
                     browser_state=browser_state,
@@ -1980,7 +1989,7 @@ class ForgeAgent:
         element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format)
         extract_action_prompt = ""
         if engine not in CUA_ENGINES:
-            extract_action_prompt = await self._build_extract_action_prompt(
+            extract_action_prompt, use_caching = await self._build_extract_action_prompt(
                 task,
                 step,
                 browser_state,
@@ -2015,7 +2024,7 @@ class ForgeAgent:
             data=element_tree_in_prompt.encode(),
         )
 
-        return scraped_page, extract_action_prompt
+        return scraped_page, extract_action_prompt, use_caching
 
     async def _build_extract_action_prompt(
         self,
@@ -2025,7 +2034,7 @@ class ForgeAgent:
         scraped_page: ScrapedPage,
         verification_code_check: bool = False,
         expire_verification_code: bool = False,
-    ) -> str:
+    ) -> tuple[str, bool]:
         actions_and_results_str = await self._get_action_results(task)
 
         # Generate the extract action prompt
@@ -2081,7 +2090,44 @@ class ForgeAgent:
 
         context = skyvern_context.ensure_context()
 
-        return load_prompt_with_elements(
+        # Check if prompt caching is enabled for extract-action
+        use_caching = False
+        if (
+            template == "extract-action"
+            and LLMAPIHandlerFactory._prompt_caching_settings
+            and LLMAPIHandlerFactory._prompt_caching_settings.get("extract-action", False)
+        ):
+            try:
+                # Try to load split templates for caching
+                static_prompt = prompt_engine.load_prompt(f"{template}-static")
+                dynamic_prompt = prompt_engine.load_prompt(
+                    f"{template}-dynamic",
+                    navigation_goal=navigation_goal,
+                    navigation_payload_str=json.dumps(final_navigation_payload),
+                    starting_url=starting_url,
+                    current_url=current_url,
+                    data_extraction_goal=task.data_extraction_goal,
+                    action_history=actions_and_results_str,
+                    error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
+                    local_datetime=datetime.now(context.tz_info).isoformat(),
+                    verification_code_check=verification_code_check,
+                    complete_criterion=task.complete_criterion.strip() if task.complete_criterion else None,
+                    terminate_criterion=task.terminate_criterion.strip() if task.terminate_criterion else None,
+                    parse_select_feature_enabled=context.enable_parse_select_in_extract,
+                )
+
+                # Store static prompt for caching and return dynamic prompt
+                context.cached_static_prompt = static_prompt
+                use_caching = True
+                LOG.info("Using cached prompt for extract-action", task_id=task.task_id)
+                return dynamic_prompt, use_caching
+
+            except Exception as e:
+                LOG.warning("Failed to load cached prompt templates, falling back to original", error=str(e))
+                # Fall through to original behavior
+
+        # Original behavior - load full prompt
+        full_prompt = load_prompt_with_elements(
             element_tree_builder=scraped_page,
             prompt_engine=prompt_engine,
             template_name=template,
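
The hunk above is the core of the change: extract-action is split into an "extract-action-static" template, rendered with no per-step arguments so its output is byte-identical on every step, and an "extract-action-dynamic" template that carries everything step-specific (goal, payload, URLs, action history, timestamps, criteria). The static half is parked on the context as cached_static_prompt; the dynamic half is returned as the prompt. How the handler stitches the two back together is outside this diff; assuming an Anthropic-style cache breakpoint, the reassembly could look roughly like this sketch (message shape follows Anthropic's documented prompt-caching format; the function name is hypothetical):

def build_messages(static_prompt: str, dynamic_prompt: str) -> list[dict]:
    # Static half first, marked as a cache breakpoint: the provider caches
    # everything up to and including this block and reuses it across steps.
    return [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": static_prompt,
                    "cache_control": {"type": "ephemeral"},
                },
                # Dynamic half changes every step, so it stays uncached.
                {"type": "text", "text": dynamic_prompt},
            ],
        }
    ]

Provider-side caching of this kind requires the cached prefix to be byte-identical (and typically above a minimum token count), which is why the dynamic template receives all of the per-step arguments while the static one receives none.
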
@@ -2099,6 +2145,8 @@ class ForgeAgent:
             parse_select_feature_enabled=context.enable_parse_select_in_extract,
         )
 
+        return full_prompt, use_caching
+
     def _build_navigation_payload(
         self,
         task: Task,
@@ -2987,7 +3035,7 @@ class ForgeAgent:
         current_context = skyvern_context.ensure_context()
         current_context.totp_codes[task.task_id] = verification_code
 
-        extract_action_prompt = await self._build_extract_action_prompt(
+        extract_action_prompt, use_caching = await self._build_extract_action_prompt(
             task,
             step,
             browser_state,
@@ -3000,6 +3048,12 @@ class ForgeAgent:
         llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
             llm_key_override, default=app.LLM_API_HANDLER
         )
+        # Add caching flag to context for monitoring
+        if use_caching:
+            context = skyvern_context.current()
+            if context:
+                context.use_prompt_caching = True
+
         return await llm_api_handler(
             prompt=extract_action_prompt,
             step=step,
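
Two behavioral notes on the gate used in both call sites. First, if either split template is missing, prompt_engine.load_prompt raises, the except branch logs a warning, and control falls through to the original full-prompt path with use_caching still False, so callers observe the old behavior unchanged. Second, the gate reads LLMAPIHandlerFactory._prompt_caching_settings, a private mapping this diff consumes but does not define; assuming it is a plain template-name-to-bool dict, opting extract-action in would be a one-liner (hypothetical usage, not shown in this diff):

# Hypothetical opt-in; the commit only reads this mapping
LLMAPIHandlerFactory._prompt_caching_settings = {"extract-action": True}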
||||