Pedro/prompt caching (#3531)

This commit is contained in:
pedrohsdb
2025-09-25 15:04:54 -07:00
committed by GitHub
parent a1c94ec4b4
commit dd9d4fb3a9
5 changed files with 326 additions and 32 deletions

View File

@@ -929,6 +929,7 @@ class ForgeAgent:
(
scraped_page,
extract_action_prompt,
use_caching,
) = await self.build_and_record_step_prompt(
task,
step,
@@ -990,6 +991,12 @@ class ForgeAgent:
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
llm_key_override, default=app.LLM_API_HANDLER
)
# Add caching flag to context for monitoring
if use_caching:
context = skyvern_context.current()
if context:
context.use_prompt_caching = True
json_response = await llm_api_handler(
prompt=extract_action_prompt,
prompt_name="extract-actions",
@@ -1884,7 +1891,7 @@ class ForgeAgent:
browser_state: BrowserState,
scrape_type: ScrapeType,
engine: RunEngine,
) -> ScrapedPage:
) -> tuple[ScrapedPage, str, bool]:
if scrape_type == ScrapeType.NORMAL:
pass
@@ -1926,7 +1933,7 @@ class ForgeAgent:
step: Step,
browser_state: BrowserState,
engine: RunEngine,
) -> tuple[ScrapedPage, str]:
) -> tuple[ScrapedPage, str, bool]:
# start the async tasks while running scrape_website
if engine not in CUA_ENGINES:
self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
@@ -1937,9 +1944,11 @@ class ForgeAgent:
# second time: try again the normal scrape, (stopping window loading before scraping barely helps, but causing problem)
# third time: reload the page before scraping
scraped_page: ScrapedPage | None = None
extract_action_prompt = ""
use_caching = False
for idx, scrape_type in enumerate(SCRAPE_TYPE_ORDER):
try:
scraped_page = await self._scrape_with_type(
scraped_page, extract_action_prompt, use_caching = await self._scrape_with_type(
task=task,
step=step,
browser_state=browser_state,
@@ -1980,7 +1989,7 @@ class ForgeAgent:
element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format)
extract_action_prompt = ""
if engine not in CUA_ENGINES:
extract_action_prompt = await self._build_extract_action_prompt(
extract_action_prompt, use_caching = await self._build_extract_action_prompt(
task,
step,
browser_state,
@@ -2015,7 +2024,7 @@ class ForgeAgent:
data=element_tree_in_prompt.encode(),
)
return scraped_page, extract_action_prompt
return scraped_page, extract_action_prompt, use_caching
async def _build_extract_action_prompt(
self,
@@ -2025,7 +2034,7 @@ class ForgeAgent:
scraped_page: ScrapedPage,
verification_code_check: bool = False,
expire_verification_code: bool = False,
) -> str:
) -> tuple[str, bool]:
actions_and_results_str = await self._get_action_results(task)
# Generate the extract action prompt
@@ -2081,7 +2090,44 @@ class ForgeAgent:
context = skyvern_context.ensure_context()
return load_prompt_with_elements(
# Check if prompt caching is enabled for extract-action
use_caching = False
if (
template == "extract-action"
and LLMAPIHandlerFactory._prompt_caching_settings
and LLMAPIHandlerFactory._prompt_caching_settings.get("extract-action", False)
):
try:
# Try to load split templates for caching
static_prompt = prompt_engine.load_prompt(f"{template}-static")
dynamic_prompt = prompt_engine.load_prompt(
f"{template}-dynamic",
navigation_goal=navigation_goal,
navigation_payload_str=json.dumps(final_navigation_payload),
starting_url=starting_url,
current_url=current_url,
data_extraction_goal=task.data_extraction_goal,
action_history=actions_and_results_str,
error_code_mapping_str=(json.dumps(task.error_code_mapping) if task.error_code_mapping else None),
local_datetime=datetime.now(context.tz_info).isoformat(),
verification_code_check=verification_code_check,
complete_criterion=task.complete_criterion.strip() if task.complete_criterion else None,
terminate_criterion=task.terminate_criterion.strip() if task.terminate_criterion else None,
parse_select_feature_enabled=context.enable_parse_select_in_extract,
)
# Store static prompt for caching and return dynamic prompt
context.cached_static_prompt = static_prompt
use_caching = True
LOG.info("Using cached prompt for extract-action", task_id=task.task_id)
return dynamic_prompt, use_caching
except Exception as e:
LOG.warning("Failed to load cached prompt templates, falling back to original", error=str(e))
# Fall through to original behavior
# Original behavior - load full prompt
full_prompt = load_prompt_with_elements(
element_tree_builder=scraped_page,
prompt_engine=prompt_engine,
template_name=template,
@@ -2099,6 +2145,8 @@ class ForgeAgent:
parse_select_feature_enabled=context.enable_parse_select_in_extract,
)
return full_prompt, use_caching
def _build_navigation_payload(
self,
task: Task,
@@ -2987,7 +3035,7 @@ class ForgeAgent:
current_context = skyvern_context.ensure_context()
current_context.totp_codes[task.task_id] = verification_code
extract_action_prompt = await self._build_extract_action_prompt(
extract_action_prompt, use_caching = await self._build_extract_action_prompt(
task,
step,
browser_state,
@@ -3000,6 +3048,12 @@ class ForgeAgent:
llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(
llm_key_override, default=app.LLM_API_HANDLER
)
# Add caching flag to context for monitoring
if use_caching:
context = skyvern_context.current()
if context:
context.use_prompt_caching = True
return await llm_api_handler(
prompt=extract_action_prompt,
step=step,

View File

@@ -0,0 +1,41 @@
```
{{ action_history }}
```
{% if complete_criterion %}
Complete criterion:
```
{{ complete_criterion }}
```{% endif %}
User goal:
```
{{ navigation_goal }}
```
{% if error_code_mapping_str %}
Use the error codes and their descriptions to surface user-defined errors. Do not return any error that's not defined by the user. User defined errors:
```
{{ error_code_mapping_str }}
```{% endif %}
{% if data_extraction_goal %}
User Data Extraction Goal:
```
{{ data_extraction_goal }}
```
{% endif %}
User details:
```
{{ navigation_payload_str }}
```
Clickable elements from `{{ current_url }}`:
```
{{ elements }}
```
The URL of the page you're on right now is `{{ current_url }}`.
Current datetime, ISO format:
```
{{ local_datetime }}
```

View File

@@ -0,0 +1,46 @@
Identify actions to help user progress towards the user goal using the DOM elements given in the list and the screenshot of the website.
Include only the elements that are relevant to the user goal, without altering or imagining new elements.
Accurately interpret and understand the functional significance of SVG elements based on their shapes and context within the webpage.
Use the user details to fill in necessary values. Always satisfy required fields if the field isn't already filled in. Don't return any action for the same field, if this field is already filled in and the value is the same as the one you would have filled in.
MAKE SURE YOU OUTPUT VALID JSON. No text before or after JSON, no trailing commas, no comments (//), no unnecessary quotes, etc.
Each interactable element is tagged with an ID. Avoid taking action on a disabled element when there is an alternative action available.
If you see any information in red in the page screenshot, this means a condition wasn't satisfied. Prioritize actions with the red information.
If you see a popup in the page screenshot, prioritize actions on the popup.
Reply in JSON format with the following keys:
{
"user_goal_stage": str, // A string to describe the reasoning whether user goal has been achieved or not.
"user_goal_achieved": bool, // True if the user goal has been completed, otherwise False.
"action_plan": str, // A string that describes the plan of actions you're going to take. Be specific and to the point. Use this as a quick summary of the actions you're going to take, and what order you're going to take them in, and how that moves you towards your overall goal. Output "COMPLETE" action in the "actions" if user_goal_achieved is True. Output "TERMINATE" action in the "actions" if your plan is to terminate the process.
"actions": array // An array of actions. Here's the format of each action:
[{
"reasoning": str, // The reasoning behind the action. This reasoning must be user information agnostic. Mention why you chose the action type, and why you chose the element id. Keep the reasoning short and to the point.
"user_detail_query": str, // Think of this value as a Jeopardy question and the intention behind the action. Ask the user for the details you need for executing this action. Ask the question even if the details are disclosed in user goal or user details. If it's a text field, ask for the text. If it's a file upload, ask for the file. If it's a dropdown, ask for the relevant information. If you are clicking on something specific, ask about what the intention is behind the click and what to click on. If you're downloading a file and you have multiple options, ask the user which one to download. Examples are: "What product ID should I input into the search bar?", "What file should I upload?", "What is the previous insurance provider of the user?", "Which invoice should I download?", "Does the user have any pets?". If the action doesn't require any user details, describe the intention behind the action.
"user_detail_answer": str, // The answer to the `user_detail_query`. The source of this answer can be user goal or user details.
"confidence_float": float, // The confidence of the action. Pick a number between 0.0 and 1.0. 0.0 means no confidence, 1.0 means full confidence
"action_type": str, // It's a string enum: "CLICK", "INPUT_TEXT", "UPLOAD_FILE", "SELECT_OPTION", "WAIT", "SOLVE_CAPTCHA", "COMPLETE", "TERMINATE". "CLICK" is an element you'd like to click. "INPUT_TEXT" is an element you'd like to input text into. "UPLOAD_FILE" is an element you'd like to upload a file into. "SELECT_OPTION" is an element you'd like to select an option from. "WAIT" action should be used if there are no actions to take and there is some indication on screen that waiting could yield more actions. "WAIT" should not be used if there are actions to take. "SOLVE_CAPTCHA" should be used if there's a captcha to solve on the screen. "COMPLETE" is used when the user goal has been achieved AND if there's any data extraction goal, you should be able to get data from the page. Never return a COMPLETE action unless the user goal is achieved. "TERMINATE" is used to terminate the whole task with a failure when it doesn't seem like the user goal can be achieved. Do not use "TERMINATE" if waiting could lead the user towards the goal. Only return "TERMINATE" if you are on a page where the user goal cannot be achieved. All other actions are ignored when "TERMINATE" is returned.
"id": str, // The id of the element to take action on. The id has to be one from the elements list
"text": str, // Text for INPUT_TEXT action only
"file_url": str, // The url of the file to upload if applicable. This field must be present for UPLOAD_FILE but can also be present for CLICK only if the click is to upload the file. It should be null otherwise.
"download": bool, // Can only be true for CLICK actions. If true, the browser will trigger a download by clicking the element. If false, the browser will click the element without triggering a download.
"option": { // The option to select for SELECT_OPTION action only. null if not SELECT_OPTION action
"label": str, // the label of the option if any. MAKE SURE YOU USE THIS LABEL TO SELECT THE OPTION. DO NOT PUT ANYTHING OTHER THAN A VALID OPTION LABEL HERE
"index": int, // the index corresponding to the option index under the select element.
"value": str // the value of the option. MAKE SURE YOU USE THIS VALUE TO SELECT THE OPTION. DO NOT PUT ANYTHING OTHER THAN A VALID OPTION VALUE HERE
}{% if parse_select_feature_enabled %},
"context": { // The context for INPUT_TEXT or SELECT_OPTION action only. null if not INPUT_TEXT or SELECT_OPTION action. Extract the following detailed information from the "reasoning", and double-check the information by analysing the HTML elements.
"thought": str, // A string to describe how you double-check the context information to ensure the accuracy.
"field": str, // Which field is this action intended to fill out?
"is_required": bool, // True if this is a required field, otherwise false.
"is_search_bar": bool, // True if the element to take the action is a search bar, otherwise false.
"is_location_input": bool, // True if the element is asking user to input where he lives, otherwise false. For example, it is asking for location, or address, or other similar information. Output False if it only requires ZIP code or postal code.
"is_date_related": bool, // True if the field is related to date input or select, otherwise false.
}{% endif %}
}],{% if verification_code_check %}
"verification_code_reasoning": str, // Let's think step by step. Describe what you see and think if there is somewhere on the current page where you must enter the verification code now for login or any verification step. Explain why you believe a verification code needs to be entered somewhere or not. Do not imagine any place to enter the code if the code has not been sent yet.
"place_to_enter_verification_code": bool, // Whether there is a place on the current page to enter the verification code now.
"should_enter_verification_code": bool // Whether the user should proceed to enter the verification code {% endif %}
}
Consider the action history from the last step and the screenshot together, if actions from the last step don't yield positive impact, try other actions or other action combinations.
Action history from previous steps: (note: even if the action history suggests goal is achieved, check the screenshot and the DOM elements to make sure the goal is achieved)

View File

@@ -48,6 +48,7 @@ class LLMCallStats(BaseModel):
class LLMAPIHandlerFactory:
_custom_handlers: dict[str, LLMAPIHandler] = {}
_thinking_budget_settings: dict[str, int] | None = None
_prompt_caching_settings: dict[str, bool] | None = None
@staticmethod
def _apply_thinking_budget_optimization(
@@ -270,8 +271,62 @@ class LLMAPIHandlerFactory:
task_v2=task_v2,
thought=thought,
)
# Build messages and apply caching in one step
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
# Inject context caching system message when available
try:
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
if (
context_cached_static_prompt
and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str)
):
# Check if this is a Vertex AI model
if "vertex_ai/" in llm_config.model_name:
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied Vertex context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
ttl_seconds=3600,
)
# Check if this is an OpenAI model
elif (
llm_config.model_name.startswith("gpt-")
or llm_config.model_name.startswith("o1-")
or llm_config.model_name.startswith("o3-")
):
# For OpenAI models, we need to add the cached content as a system message
# and mark it for caching using the cache_control parameter
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied OpenAI context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
)
except Exception as e:
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
@@ -343,16 +398,28 @@ class LLMAPIHandlerFactory:
except Exception as e:
LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True)
llm_cost = 0
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
reasoning_tokens = 0
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
if completion_token_detail:
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
cached_tokens = 0
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
if cached_token_detail:
cached_tokens = cached_token_detail.cached_tokens or 0
if hasattr(response, "usage") and response.usage:
prompt_tokens = getattr(response.usage, "prompt_tokens", 0)
completion_tokens = getattr(response.usage, "completion_tokens", 0)
# Extract reasoning tokens from completion_tokens_details
completion_token_detail = getattr(response.usage, "completion_tokens_details", None)
if completion_token_detail:
reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0
# Extract cached tokens from prompt_tokens_details
cached_token_detail = getattr(response.usage, "prompt_tokens_details", None)
if cached_token_detail:
cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0
# Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
if cached_tokens == 0:
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
if step:
await app.DATABASE.update_step(
task_id=step.task_id,
@@ -492,6 +559,59 @@ class LLMAPIHandlerFactory:
model_name = llm_config.model_name
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
# Inject context caching system message when available
try:
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
if (
context_cached_static_prompt
and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str)
):
# Check if this is a Vertex AI model
if "vertex_ai/" in llm_config.model_name:
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied Vertex context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
ttl_seconds=3600,
)
# Check if this is an OpenAI model
elif (
llm_config.model_name.startswith("gpt-")
or llm_config.model_name.startswith("o1-")
or llm_config.model_name.startswith("o3-")
):
# For OpenAI models, we need to add the cached content as a system message
# and mark it for caching using the cache_control parameter
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied OpenAI context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
)
except Exception as e:
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
@@ -573,16 +693,28 @@ class LLMAPIHandlerFactory:
except Exception as e:
LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True)
llm_cost = 0
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
reasoning_tokens = 0
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
if completion_token_detail:
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
cached_tokens = 0
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
if cached_token_detail:
cached_tokens = cached_token_detail.cached_tokens or 0
if hasattr(response, "usage") and response.usage:
prompt_tokens = getattr(response.usage, "prompt_tokens", 0)
completion_tokens = getattr(response.usage, "completion_tokens", 0)
# Extract reasoning tokens from completion_tokens_details
completion_token_detail = getattr(response.usage, "completion_tokens_details", None)
if completion_token_detail:
reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0
# Extract cached tokens from prompt_tokens_details
cached_token_detail = getattr(response.usage, "prompt_tokens_details", None)
if cached_token_detail:
cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0
# Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
if cached_tokens == 0:
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
if step:
await app.DATABASE.update_step(
@@ -684,6 +816,13 @@ class LLMAPIHandlerFactory:
if settings:
LOG.info("Thinking budget optimization settings applied", settings=settings)
@classmethod
def set_prompt_caching_settings(cls, settings: dict[str, bool] | None) -> None:
    """Set prompt caching optimization settings for the current task/workflow.

    Stores the per-template enable flags on the class. Passing ``None`` (or an
    empty mapping) clears them; a truthy mapping is logged for observability.
    """
    cls._prompt_caching_settings = settings
    if not settings:
        return
    LOG.info("Prompt caching optimization settings applied", settings=settings)
class LLMCaller:
"""
@@ -1085,16 +1224,28 @@ class LLMCaller:
except Exception as e:
LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True)
llm_cost = 0
input_tokens = response.get("usage", {}).get("prompt_tokens", 0)
output_tokens = response.get("usage", {}).get("completion_tokens", 0)
input_tokens = 0
output_tokens = 0
reasoning_tokens = 0
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
if completion_token_detail:
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
cached_tokens = 0
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
if cached_token_detail:
cached_tokens = cached_token_detail.cached_tokens or 0
if hasattr(response, "usage") and response.usage:
input_tokens = getattr(response.usage, "prompt_tokens", 0)
output_tokens = getattr(response.usage, "completion_tokens", 0)
# Extract reasoning tokens from completion_tokens_details
completion_token_detail = getattr(response.usage, "completion_tokens_details", None)
if completion_token_detail:
reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0
# Extract cached tokens from prompt_tokens_details
cached_token_detail = getattr(response.usage, "prompt_tokens_details", None)
if cached_token_detail:
cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0
# Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
if cached_tokens == 0:
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
return LLMCallStats(
llm_cost=llm_cost,
input_tokens=input_tokens,

View File

@@ -33,6 +33,8 @@ class SkyvernContext:
action_order: int = 0
prompt: str | None = None
enable_parse_select_in_extract: bool = False
use_prompt_caching: bool = False
cached_static_prompt: str | None = None
def __repr__(self) -> str:
    """Return a compact one-line summary of the key run identifiers."""
    # Field order matters: it defines the rendered layout of the repr.
    rendered = ", ".join(
        f"{field}={getattr(self, field)}"
        for field in (
            "request_id",
            "organization_id",
            "task_id",
            "step_id",
            "workflow_id",
            "workflow_run_id",
            "task_v2_id",
            "max_steps_override",
            "run_id",
        )
    )
    return f"SkyvernContext({rendered})"