Pedro/prompt caching (#3531)

Author: pedrohsdb
Date: 2025-09-25 15:04:54 -07:00
Committed by: GitHub
Parent: a1c94ec4b4
Commit: dd9d4fb3a9
5 changed files with 326 additions and 32 deletions

@@ -48,6 +48,7 @@ class LLMCallStats(BaseModel):
class LLMAPIHandlerFactory:
_custom_handlers: dict[str, LLMAPIHandler] = {}
_thinking_budget_settings: dict[str, int] | None = None
_prompt_caching_settings: dict[str, bool] | None = None
@staticmethod
def _apply_thinking_budget_optimization(
@@ -270,8 +271,62 @@ class LLMAPIHandlerFactory:
task_v2=task_v2,
thought=thought,
)
# Build the base messages for the LLM call
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
# Inject a context caching system message when available
try:
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
if (
context_cached_static_prompt
and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str)
):
# Check if this is a Vertex AI model
if "vertex_ai/" in llm_config.model_name:
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied Vertex context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
ttl_seconds=3600,
)
# Check if this is an OpenAI model
elif (
llm_config.model_name.startswith("gpt-")
or llm_config.model_name.startswith("o1-")
or llm_config.model_name.startswith("o3-")
):
# For OpenAI models, prepend the cached content as a plain system message;
# OpenAI applies prompt caching automatically, so no explicit cache_control marker is added
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied OpenAI context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
)
except Exception as e:
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
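
For reference, a standalone sketch (not part of the diff) of the message shape the Vertex branch above prepends, sent through a LiteLLM completion call. The model string and prompt text are placeholders, and it assumes LiteLLM forwards the cache_control marker to the provider.

# Standalone sketch: cache-marked system message as used by the Vertex branch above.
import litellm

CACHED_STATIC_PROMPT = "...long static instructions shared across calls..."  # placeholder

caching_system_message = {
    "role": "system",
    "content": [
        {
            "type": "text",
            "text": CACHED_STATIC_PROMPT,
            # Ephemeral cache entry with a one-hour TTL, matching the diff above
            "cache_control": {"type": "ephemeral", "ttl": "3600s"},
        }
    ],
}

messages = [caching_system_message, {"role": "user", "content": "What is on this page?"}]

# Synchronous call shown to keep the sketch self-contained; the handler above is async.
response = litellm.completion(model="vertex_ai/gemini-1.5-pro", messages=messages)
print(getattr(response, "usage", None))
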
@@ -343,16 +398,28 @@ class LLMAPIHandlerFactory:
except Exception as e:
LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True)
llm_cost = 0
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
reasoning_tokens = 0
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
if completion_token_detail:
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
cached_tokens = 0
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
if cached_token_detail:
cached_tokens = cached_token_detail.cached_tokens or 0
if hasattr(response, "usage") and response.usage:
prompt_tokens = getattr(response.usage, "prompt_tokens", 0)
completion_tokens = getattr(response.usage, "completion_tokens", 0)
# Extract reasoning tokens from completion_tokens_details
completion_token_detail = getattr(response.usage, "completion_tokens_details", None)
if completion_token_detail:
reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0
# Extract cached tokens from prompt_tokens_details
cached_token_detail = getattr(response.usage, "prompt_tokens_details", None)
if cached_token_detail:
cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0
# Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
if cached_tokens == 0:
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
if step:
await app.DATABASE.update_step(
task_id=step.task_id,
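
As a standalone illustration of the defensive usage parsing introduced above, the helper below reads prompt, completion, reasoning, and cached token counts from a LiteLLM-style response; the SimpleNamespace objects stand in for the real response types.

# Standalone sketch mirroring the extraction logic above.
from types import SimpleNamespace


def extract_usage(response) -> dict[str, int]:
    prompt_tokens = completion_tokens = reasoning_tokens = cached_tokens = 0
    if hasattr(response, "usage") and response.usage:
        usage = response.usage
        prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
        completion_tokens = getattr(usage, "completion_tokens", 0) or 0
        completion_detail = getattr(usage, "completion_tokens_details", None)
        if completion_detail:
            reasoning_tokens = getattr(completion_detail, "reasoning_tokens", 0) or 0
        prompt_detail = getattr(usage, "prompt_tokens_details", None)
        if prompt_detail:
            cached_tokens = getattr(prompt_detail, "cached_tokens", 0) or 0
        if cached_tokens == 0:
            # Vertex/Gemini fallback: LiteLLM reports cache reads on the usage object itself
            cached_tokens = getattr(usage, "cache_read_input_tokens", 0) or 0
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "reasoning_tokens": reasoning_tokens,
        "cached_tokens": cached_tokens,
    }


# Example with a fabricated usage payload
fake = SimpleNamespace(
    usage=SimpleNamespace(
        prompt_tokens=1200,
        completion_tokens=300,
        completion_tokens_details=SimpleNamespace(reasoning_tokens=50),
        prompt_tokens_details=SimpleNamespace(cached_tokens=1024),
        cache_read_input_tokens=0,
    )
)
print(extract_usage(fake))  # cached_tokens comes from prompt_tokens_details here
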
@@ -492,6 +559,59 @@ class LLMAPIHandlerFactory:
model_name = llm_config.model_name
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
# Inject context caching system message when available
try:
context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
if (
context_cached_static_prompt
and isinstance(llm_config, LLMConfig)
and isinstance(llm_config.model_name, str)
):
# Check if this is a Vertex AI model
if "vertex_ai/" in llm_config.model_name:
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
"cache_control": {"type": "ephemeral", "ttl": "3600s"},
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied Vertex context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
ttl_seconds=3600,
)
# Check if this is an OpenAI model
elif (
llm_config.model_name.startswith("gpt-")
or llm_config.model_name.startswith("o1-")
or llm_config.model_name.startswith("o3-")
):
# For OpenAI models, prepend the cached content as a plain system message;
# OpenAI applies prompt caching automatically, so no explicit cache_control marker is added
caching_system_message = {
"role": "system",
"content": [
{
"type": "text",
"text": context_cached_static_prompt,
}
],
}
messages = [caching_system_message] + messages
LOG.info(
"Applied OpenAI context caching",
prompt_name=prompt_name,
model=llm_config.model_name,
)
except Exception as e:
LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(
{
@@ -573,16 +693,28 @@ class LLMAPIHandlerFactory:
except Exception as e:
LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True)
llm_cost = 0
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
reasoning_tokens = 0
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
if completion_token_detail:
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
cached_tokens = 0
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
if cached_token_detail:
cached_tokens = cached_token_detail.cached_tokens or 0
if hasattr(response, "usage") and response.usage:
prompt_tokens = getattr(response.usage, "prompt_tokens", 0)
completion_tokens = getattr(response.usage, "completion_tokens", 0)
# Extract reasoning tokens from completion_tokens_details
completion_token_detail = getattr(response.usage, "completion_tokens_details", None)
if completion_token_detail:
reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0
# Extract cached tokens from prompt_tokens_details
cached_token_detail = getattr(response.usage, "prompt_tokens_details", None)
if cached_token_detail:
cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0
# Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
if cached_tokens == 0:
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
if step:
await app.DATABASE.update_step(
@@ -684,6 +816,13 @@ class LLMAPIHandlerFactory:
if settings:
LOG.info("Thinking budget optimization settings applied", settings=settings)
@classmethod
def set_prompt_caching_settings(cls, settings: dict[str, bool] | None) -> None:
"""Set prompt caching optimization settings for the current task/workflow."""
cls._prompt_caching_settings = settings
if settings:
LOG.info("Prompt caching optimization settings applied", settings=settings)
class LLMCaller:
"""
@@ -1085,16 +1224,28 @@ class LLMCaller:
except Exception as e:
LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True)
llm_cost = 0
input_tokens = response.get("usage", {}).get("prompt_tokens", 0)
output_tokens = response.get("usage", {}).get("completion_tokens", 0)
input_tokens = 0
output_tokens = 0
reasoning_tokens = 0
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
if completion_token_detail:
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
cached_tokens = 0
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
if cached_token_detail:
cached_tokens = cached_token_detail.cached_tokens or 0
if hasattr(response, "usage") and response.usage:
input_tokens = getattr(response.usage, "prompt_tokens", 0)
output_tokens = getattr(response.usage, "completion_tokens", 0)
# Extract reasoning tokens from completion_tokens_details
completion_token_detail = getattr(response.usage, "completion_tokens_details", None)
if completion_token_detail:
reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0
# Extract cached tokens from prompt_tokens_details
cached_token_detail = getattr(response.usage, "prompt_tokens_details", None)
if cached_token_detail:
cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0
# Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
if cached_tokens == 0:
cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
return LLMCallStats(
llm_cost=llm_cost,
input_tokens=input_tokens,
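
To see why cached_tokens is worth surfacing in LLMCallStats, a back-of-the-envelope sketch with invented prices: if cached prompt tokens are billed at a fraction of the normal input rate, the effective prompt cost falls in proportion to the cache hit rate.

# Illustrative arithmetic only; the price and discount below are made up for the example.
input_price_per_1k = 0.0025   # hypothetical $ per 1K uncached prompt tokens
cached_discount = 0.5         # hypothetical: cached tokens billed at 50% of the input rate

prompt_tokens = 10_000
cached_tokens = 8_000

uncached_cost = prompt_tokens * input_price_per_1k / 1000
effective_cost = (
    (prompt_tokens - cached_tokens) * input_price_per_1k
    + cached_tokens * input_price_per_1k * cached_discount
) / 1000
print(f"uncached ${uncached_cost:.4f} vs with caching ${effective_cost:.4f}")
# uncached $0.0250 vs with caching $0.0150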