fix(llm): strip static prompt from cached Vertex AI requests to preve… (#4321)

pedrohsdb
2025-12-17 17:25:36 -08:00
committed by GitHub
parent eed17a6b9d
commit f594474b9e


@@ -149,6 +149,40 @@ class LLMAPIHandlerFactory:
    _thinking_budget_settings: dict[str, int] | None = None
    _prompt_caching_settings: dict[str, bool] | None = None

    @staticmethod
    def _strip_static_prompt_from_messages(messages: list[dict[str, Any]], static_prompt: str) -> bool:
        """
        Strips the static prompt from the first matching user message in the list.
        Returns True if the prompt was found and stripped, False otherwise.
        This handles both string content and list-based content (e.g. for vision models).
        The static prompt is right-stripped to handle trailing newlines from templates.
        The remaining dynamic content is left-stripped to handle connector whitespace.
        """
        static_text = static_prompt.rstrip()
        prompt_stripped = False
        for msg in messages:
            if msg.get("role") == "user":
                content = msg.get("content")
                if isinstance(content, str):
                    if content.startswith(static_text):
                        msg["content"] = content[len(static_text) :].lstrip()
                        prompt_stripped = True
                        break
                elif isinstance(content, list):
                    for block in content:
                        if block.get("type") == "text":
                            text = block.get("text", "")
                            if text.startswith(static_text):
                                block["text"] = text[len(static_text) :].lstrip()
                                prompt_stripped = True
                                break
                    if prompt_stripped:
                        break
        return prompt_stripped
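
    # Illustrative sketch (not part of this change): expected behavior of the helper above,
    # assuming the usual role/content message dicts passed to litellm.acompletion:
    #
    #   msgs = [{"role": "user", "content": "STATIC PROMPT\n\nDynamic task details"}]
    #   stripped = LLMAPIHandlerFactory._strip_static_prompt_from_messages(msgs, "STATIC PROMPT\n")
    #   assert stripped is True and msgs[0]["content"] == "Dynamic task details"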

    @staticmethod
    def _models_equivalent(left: str | None, right: str | None) -> bool:
        """Used only by `llm_api_handler_with_router_and_fallback`. Router model
@@ -513,6 +547,22 @@ class LLMAPIHandlerFactory:
        active_params.update(parameters)
        active_params["cached_content"] = cache_name
        request_model = active_params.pop("model", primary_model_dict.get("model_name", main_model_group))

        # Clone the messages to avoid modifying the original list, which is still needed for fallback
        active_messages = copy.deepcopy(messages)
        # Strip the static prompt from the request messages because it is already in the cache;
        # sending it again causes double-billing (once as cached tokens, once as regular input)
        if context and context.cached_static_prompt:
            prompt_stripped = LLMAPIHandlerFactory._strip_static_prompt_from_messages(
                active_messages, context.cached_static_prompt
            )
            if prompt_stripped:
                LOG.info("Stripped static prompt from cached request to avoid double-billing")
            else:
                LOG.warning("Could not find static prompt to strip from cached request")
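
        # Note (not part of this change): at this point `active_messages` differs from the
        # original `messages` only in that the leading static prompt has been removed from the
        # first user message; `messages` itself stays intact for any non-cached fallback
        # request that still needs the full prompt.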

        LOG.info(
            "Adding Vertex AI cache reference to primary Gemini request",
            prompt_name=prompt_name,
@@ -525,7 +575,7 @@ class LLMAPIHandlerFactory:
        request_payload_json = await _log_llm_request_artifact(request_model, True)
        response = await litellm.acompletion(
            model=request_model,
            messages=messages,
            messages=active_messages,
            timeout=settings.LLM_CONFIG_TIMEOUT,
            drop_params=True,
            **active_params,
@@ -913,13 +963,27 @@ class LLMAPIHandlerFactory:
            **artifact_targets,
        )

        # Strip the static prompt from the request messages because it is already in the cache;
        # sending it again causes double-billing (once as cached tokens, once as regular input)
        active_messages = messages
        if vertex_cache_attached and context and context.cached_static_prompt:
            active_messages = copy.deepcopy(messages)
            prompt_stripped = LLMAPIHandlerFactory._strip_static_prompt_from_messages(
                active_messages, context.cached_static_prompt
            )
            if prompt_stripped:
                LOG.info("Stripped static prompt from cached request to avoid double-billing")
            else:
                LOG.warning("Could not find static prompt to strip from cached request")

        t_llm_request = time.perf_counter()
        try:
            # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
            # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
            response = await litellm.acompletion(
                model=model_name,
                messages=messages,
                messages=active_messages,
                drop_params=True,  # Drop unsupported parameters gracefully
                **active_parameters,
            )
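
            # Sketch for the retry TODO above (not part of this change): litellm's
            # acompletion_with_retries() wrapper, which the TODO names, could replace the
            # bare call. The exact keyword arguments (e.g. num_retries) are an assumption
            # to verify against the installed litellm version:
            #
            #   response = await litellm.acompletion_with_retries(
            #       model=model_name,
            #       messages=active_messages,
            #       num_retries=3,
            #       drop_params=True,
            #       **active_parameters,
            #   )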