Pedro/prompt caching (#3531)
@@ -48,6 +48,7 @@ class LLMCallStats(BaseModel):
 class LLMAPIHandlerFactory:
     _custom_handlers: dict[str, LLMAPIHandler] = {}
     _thinking_budget_settings: dict[str, int] | None = None
+    _prompt_caching_settings: dict[str, bool] | None = None

     @staticmethod
     def _apply_thinking_budget_optimization(
@@ -270,8 +271,62 @@ class LLMAPIHandlerFactory:
             task_v2=task_v2,
             thought=thought,
         )
+        # Build messages and apply caching in one step
         messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)

+        # Inject context caching system message when available
+        try:
+            context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
+            if (
+                context_cached_static_prompt
+                and isinstance(llm_config, LLMConfig)
+                and isinstance(llm_config.model_name, str)
+            ):
+                # Check if this is a Vertex AI model
+                if "vertex_ai/" in llm_config.model_name:
+                    caching_system_message = {
+                        "role": "system",
+                        "content": [
+                            {
+                                "type": "text",
+                                "text": context_cached_static_prompt,
+                                "cache_control": {"type": "ephemeral", "ttl": "3600s"},
+                            }
+                        ],
+                    }
+                    messages = [caching_system_message] + messages
+                    LOG.info(
+                        "Applied Vertex context caching",
+                        prompt_name=prompt_name,
+                        model=llm_config.model_name,
+                        ttl_seconds=3600,
+                    )
+                # Check if this is an OpenAI model
+                elif (
+                    llm_config.model_name.startswith("gpt-")
+                    or llm_config.model_name.startswith("o1-")
+                    or llm_config.model_name.startswith("o3-")
+                ):
+                    # For OpenAI models, prepend the cached content as a plain system message;
+                    # OpenAI caches stable prompt prefixes automatically, so no cache_control marker is needed
+                    caching_system_message = {
+                        "role": "system",
+                        "content": [
+                            {
+                                "type": "text",
+                                "text": context_cached_static_prompt,
+                            }
+                        ],
+                    }
+                    messages = [caching_system_message] + messages
+                    LOG.info(
+                        "Applied OpenAI context caching",
+                        prompt_name=prompt_name,
+                        model=llm_config.model_name,
+                    )
+        except Exception as e:
+            LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
+
         await app.ARTIFACT_MANAGER.create_llm_artifact(
             data=json.dumps(
                 {
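
The two branches above differ only in whether a cache_control marker is attached: Vertex/Gemini wants an explicit ephemeral cache entry with a TTL, while OpenAI caches stable prompt prefixes on its own. A minimal sketch of the message shape being prepended; the helper name build_caching_system_message is hypothetical, but the payloads mirror the diff:

from typing import Any

def build_caching_system_message(static_prompt: str, model_name: str) -> dict[str, Any]:
    """Build the system message that marks the static prompt as cacheable."""
    block: dict[str, Any] = {"type": "text", "text": static_prompt}
    if "vertex_ai/" in model_name:
        # Vertex/Gemini: explicit ephemeral cache entry with a one-hour TTL
        block["cache_control"] = {"type": "ephemeral", "ttl": "3600s"}
    # OpenAI models get no marker: stable prompt prefixes are cached automatically
    return {"role": "system", "content": [block]}

print(build_caching_system_message("...static instructions...", "vertex_ai/gemini-2.5-pro"))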
@@ -343,16 +398,28 @@ class LLMAPIHandlerFactory:
         except Exception as e:
             LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True)
             llm_cost = 0
-        prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
-        completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
+        prompt_tokens = 0
+        completion_tokens = 0
         reasoning_tokens = 0
-        completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
-        if completion_token_detail:
-            reasoning_tokens = completion_token_detail.reasoning_tokens or 0
         cached_tokens = 0
-        cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
-        if cached_token_detail:
-            cached_tokens = cached_token_detail.cached_tokens or 0
+
+        if hasattr(response, "usage") and response.usage:
+            prompt_tokens = getattr(response.usage, "prompt_tokens", 0)
+            completion_tokens = getattr(response.usage, "completion_tokens", 0)
+
+            # Extract reasoning tokens from completion_tokens_details
+            completion_token_detail = getattr(response.usage, "completion_tokens_details", None)
+            if completion_token_detail:
+                reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0
+
+            # Extract cached tokens from prompt_tokens_details
+            cached_token_detail = getattr(response.usage, "prompt_tokens_details", None)
+            if cached_token_detail:
+                cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0
+
+            # Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
+            if cached_tokens == 0:
+                cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
         if step:
             await app.DATABASE.update_step(
                 task_id=step.task_id,
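
The getattr chain above tolerates usage payloads that omit the detail objects. A self-contained sketch of the same cached-token lookup, with SimpleNamespace standing in for LiteLLM's usage types:

from types import SimpleNamespace

def extract_cached_tokens(response: object) -> int:
    """Read cached prompt tokens, falling back to the Vertex/Gemini field name."""
    usage = getattr(response, "usage", None)
    if not usage:
        return 0
    detail = getattr(usage, "prompt_tokens_details", None)
    cached = (getattr(detail, "cached_tokens", 0) or 0) if detail else 0
    if cached == 0:
        # LiteLLM surfaces Vertex/Gemini cache hits under cache_read_input_tokens
        cached = getattr(usage, "cache_read_input_tokens", 0) or 0
    return cached

# OpenAI-style usage: cached tokens live under prompt_tokens_details
openai_usage = SimpleNamespace(prompt_tokens_details=SimpleNamespace(cached_tokens=512))
print(extract_cached_tokens(SimpleNamespace(usage=openai_usage)))  # 512

# Vertex/Gemini-style usage: only the fallback field is populated
vertex_usage = SimpleNamespace(prompt_tokens_details=None, cache_read_input_tokens=1024)
print(extract_cached_tokens(SimpleNamespace(usage=vertex_usage)))  # 1024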
@@ -492,6 +559,59 @@ class LLMAPIHandlerFactory:
         model_name = llm_config.model_name

         messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
+
+        # Inject context caching system message when available
+        try:
+            context_cached_static_prompt = getattr(context, "cached_static_prompt", None)
+            if (
+                context_cached_static_prompt
+                and isinstance(llm_config, LLMConfig)
+                and isinstance(llm_config.model_name, str)
+            ):
+                # Check if this is a Vertex AI model
+                if "vertex_ai/" in llm_config.model_name:
+                    caching_system_message = {
+                        "role": "system",
+                        "content": [
+                            {
+                                "type": "text",
+                                "text": context_cached_static_prompt,
+                                "cache_control": {"type": "ephemeral", "ttl": "3600s"},
+                            }
+                        ],
+                    }
+                    messages = [caching_system_message] + messages
+                    LOG.info(
+                        "Applied Vertex context caching",
+                        prompt_name=prompt_name,
+                        model=llm_config.model_name,
+                        ttl_seconds=3600,
+                    )
+                # Check if this is an OpenAI model
+                elif (
+                    llm_config.model_name.startswith("gpt-")
+                    or llm_config.model_name.startswith("o1-")
+                    or llm_config.model_name.startswith("o3-")
+                ):
+                    # For OpenAI models, prepend the cached content as a plain system message;
+                    # OpenAI caches stable prompt prefixes automatically, so no cache_control marker is needed
+                    caching_system_message = {
+                        "role": "system",
+                        "content": [
+                            {
+                                "type": "text",
+                                "text": context_cached_static_prompt,
+                            }
+                        ],
+                    }
+                    messages = [caching_system_message] + messages
+                    LOG.info(
+                        "Applied OpenAI context caching",
+                        prompt_name=prompt_name,
+                        model=llm_config.model_name,
+                    )
+        except Exception as e:
+            LOG.warning("Failed to apply context caching system message", error=str(e), exc_info=True)
         await app.ARTIFACT_MANAGER.create_llm_artifact(
             data=json.dumps(
                 {
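
For context, a sketch of how such a message list reaches the provider through LiteLLM. It assumes LiteLLM translates content-block cache_control markers for Vertex models (the behavior this diff relies on); the model name is illustrative:

import asyncio

import litellm

async def main() -> None:
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "...large static instructions...",
                    "cache_control": {"type": "ephemeral", "ttl": "3600s"},
                }
            ],
        },
        {"role": "user", "content": "Extract the actions for this page."},
    ]
    # acompletion is LiteLLM's async completion entry point; usage carries the token counts
    response = await litellm.acompletion(model="vertex_ai/gemini-2.5-pro", messages=messages)
    print(response.usage)

asyncio.run(main())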
@@ -573,16 +693,28 @@ class LLMAPIHandlerFactory:
         except Exception as e:
             LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True)
             llm_cost = 0
-        prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
-        completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
+        prompt_tokens = 0
+        completion_tokens = 0
         reasoning_tokens = 0
-        completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
-        if completion_token_detail:
-            reasoning_tokens = completion_token_detail.reasoning_tokens or 0
         cached_tokens = 0
-        cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
-        if cached_token_detail:
-            cached_tokens = cached_token_detail.cached_tokens or 0
+
+        if hasattr(response, "usage") and response.usage:
+            prompt_tokens = getattr(response.usage, "prompt_tokens", 0)
+            completion_tokens = getattr(response.usage, "completion_tokens", 0)
+
+            # Extract reasoning tokens from completion_tokens_details
+            completion_token_detail = getattr(response.usage, "completion_tokens_details", None)
+            if completion_token_detail:
+                reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0
+
+            # Extract cached tokens from prompt_tokens_details
+            cached_token_detail = getattr(response.usage, "prompt_tokens_details", None)
+            if cached_token_detail:
+                cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0
+
+            # Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
+            if cached_tokens == 0:
+                cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0

         if step:
             await app.DATABASE.update_step(
@@ -684,6 +816,13 @@ class LLMAPIHandlerFactory:
         if settings:
             LOG.info("Thinking budget optimization settings applied", settings=settings)

+    @classmethod
+    def set_prompt_caching_settings(cls, settings: dict[str, bool] | None) -> None:
+        """Set prompt caching optimization settings for the current task/workflow."""
+        cls._prompt_caching_settings = settings
+        if settings:
+            LOG.info("Prompt caching optimization settings applied", settings=settings)
+

 class LLMCaller:
     """
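
A usage sketch for the new setter, assuming the caller scopes settings to a single run; the flag name enable_prompt_caching is hypothetical, only the setter itself comes from this diff:

LLMAPIHandlerFactory.set_prompt_caching_settings({"enable_prompt_caching": True})
try:
    ...  # run the task/workflow; handlers consult the class-level settings
finally:
    # Class-level state is shared across runs, so reset it when done
    LLMAPIHandlerFactory.set_prompt_caching_settings(None)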
@@ -1085,16 +1224,28 @@ class LLMCaller:
         except Exception as e:
             LOG.info("Failed to calculate LLM cost", error=str(e), exc_info=True)
             llm_cost = 0
-        input_tokens = response.get("usage", {}).get("prompt_tokens", 0)
-        output_tokens = response.get("usage", {}).get("completion_tokens", 0)
+        input_tokens = 0
+        output_tokens = 0
         reasoning_tokens = 0
-        completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
-        if completion_token_detail:
-            reasoning_tokens = completion_token_detail.reasoning_tokens or 0
         cached_tokens = 0
-        cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
-        if cached_token_detail:
-            cached_tokens = cached_token_detail.cached_tokens or 0
+
+        if hasattr(response, "usage") and response.usage:
+            input_tokens = getattr(response.usage, "prompt_tokens", 0)
+            output_tokens = getattr(response.usage, "completion_tokens", 0)
+
+            # Extract reasoning tokens from completion_tokens_details
+            completion_token_detail = getattr(response.usage, "completion_tokens_details", None)
+            if completion_token_detail:
+                reasoning_tokens = getattr(completion_token_detail, "reasoning_tokens", 0) or 0
+
+            # Extract cached tokens from prompt_tokens_details
+            cached_token_detail = getattr(response.usage, "prompt_tokens_details", None)
+            if cached_token_detail:
+                cached_tokens = getattr(cached_token_detail, "cached_tokens", 0) or 0
+
+            # Fallback for Vertex/Gemini: LiteLLM exposes cache_read_input_tokens on usage
+            if cached_tokens == 0:
+                cached_tokens = getattr(response.usage, "cache_read_input_tokens", 0) or 0
         return LLMCallStats(
             llm_cost=llm_cost,
             input_tokens=input_tokens,
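
Tracking cached_tokens separately matters because providers bill cache reads at a discount. Illustrative arithmetic only; the rate and discount are hypothetical placeholders, not real pricing:

def input_cost(prompt_tokens: int, cached_tokens: int, rate_per_mtok: float, cached_discount: float) -> float:
    """Input cost of one call, billing cached tokens at a discounted rate."""
    uncached = prompt_tokens - cached_tokens
    effective_tokens = uncached + cached_tokens * (1 - cached_discount)
    return effective_tokens * rate_per_mtok / 1_000_000

# 40k-token prompt, 30k from cache, at an assumed $1.25/Mtok and a 75% cache discount:
# (10_000 + 7_500) * 1.25 / 1e6 = $0.021875, versus $0.05 with no cache hits
print(input_cost(40_000, 30_000, 1.25, 0.75))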