fix(llm): strip static prompt from cached Vertex AI requests to preve… (#4321)
@@ -149,6 +149,40 @@ class LLMAPIHandlerFactory:
     _thinking_budget_settings: dict[str, int] | None = None
     _prompt_caching_settings: dict[str, bool] | None = None
 
+    @staticmethod
+    def _strip_static_prompt_from_messages(messages: list[dict[str, Any]], static_prompt: str) -> bool:
+        """
+        Strips the static prompt from the first matching user message in the list.
+        Returns True if the prompt was found and stripped, False otherwise.
+
+        This handles both string content and list-based content (e.g. for vision models).
+        The static prompt is right-stripped to handle trailing newlines from templates.
+        The remaining dynamic content is left-stripped to handle connector whitespace.
+        """
+        static_text = static_prompt.rstrip()
+        prompt_stripped = False
+
+        for msg in messages:
+            if msg.get("role") == "user":
+                content = msg.get("content")
+                if isinstance(content, str):
+                    if content.startswith(static_text):
+                        msg["content"] = content[len(static_text) :].lstrip()
+                        prompt_stripped = True
+                        break
+                elif isinstance(content, list):
+                    for block in content:
+                        if block.get("type") == "text":
+                            text = block.get("text", "")
+                            if text.startswith(static_text):
+                                block["text"] = text[len(static_text) :].lstrip()
+                                prompt_stripped = True
+                                break
+                    if prompt_stripped:
+                        break
+
+        return prompt_stripped
+
     @staticmethod
     def _models_equivalent(left: str | None, right: str | None) -> bool:
         """Used only by `llm_api_handler_with_router_and_fallback`. Router model
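
For reference, a minimal standalone sketch of the new helper's behavior. The free function below is a simplified equivalent of the method added above, not the committed code itself, and the message contents are made up for illustration:

from typing import Any

def strip_static_prompt(messages: list[dict[str, Any]], static_prompt: str) -> bool:
    """Simplified equivalent of LLMAPIHandlerFactory._strip_static_prompt_from_messages."""
    static_text = static_prompt.rstrip()
    for msg in messages:
        if msg.get("role") != "user":
            continue
        content = msg.get("content")
        if isinstance(content, str) and content.startswith(static_text):
            # Drop the static prefix, then the connector whitespace after it
            msg["content"] = content[len(static_text) :].lstrip()
            return True
        if isinstance(content, list):
            for block in content:
                text = block.get("text", "")
                if block.get("type") == "text" and text.startswith(static_text):
                    block["text"] = text[len(static_text) :].lstrip()
                    return True
    return False

# String content: the static prefix and trailing newlines are removed.
msgs = [{"role": "user", "content": "You are a helpful agent.\n\nExtract the price."}]
assert strip_static_prompt(msgs, "You are a helpful agent.\n")
assert msgs[0]["content"] == "Extract the price."

# List-based (vision-style) content: only the matching text block is rewritten.
msgs = [{"role": "user", "content": [{"type": "text", "text": "You are a helpful agent. Describe the image."}]}]
assert strip_static_prompt(msgs, "You are a helpful agent.")
assert msgs[0]["content"][0]["text"] == "Describe the image."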
@@ -513,6 +547,22 @@ class LLMAPIHandlerFactory:
         active_params.update(parameters)
         active_params["cached_content"] = cache_name
         request_model = active_params.pop("model", primary_model_dict.get("model_name", main_model_group))
+
+        # Clone messages to avoid modifying the original list, which is needed for fallback
+        active_messages = copy.deepcopy(messages)
+
+        # Strip the static prompt from the request messages because it's already in the cache.
+        # Sending it again causes double-billing (once cached, once uncached).
+        if context and context.cached_static_prompt:
+            prompt_stripped = LLMAPIHandlerFactory._strip_static_prompt_from_messages(
+                active_messages, context.cached_static_prompt
+            )
+
+            if prompt_stripped:
+                LOG.info("Stripped static prompt from cached request to avoid double-billing")
+            else:
+                LOG.warning("Could not find static prompt to strip from cached request")
+
         LOG.info(
             "Adding Vertex AI cache reference to primary Gemini request",
             prompt_name=prompt_name,
@@ -525,7 +575,7 @@ class LLMAPIHandlerFactory:
         request_payload_json = await _log_llm_request_artifact(request_model, True)
         response = await litellm.acompletion(
             model=request_model,
-            messages=messages,
+            messages=active_messages,
             timeout=settings.LLM_CONFIG_TIMEOUT,
             drop_params=True,
             **active_params,
@@ -913,13 +963,27 @@ class LLMAPIHandlerFactory:
             **artifact_targets,
         )
 
+        # Strip the static prompt from the request messages because it's already in the cache.
+        # Sending it again causes double-billing (once cached, once uncached).
+        active_messages = messages
+        if vertex_cache_attached and context and context.cached_static_prompt:
+            active_messages = copy.deepcopy(messages)
+            prompt_stripped = LLMAPIHandlerFactory._strip_static_prompt_from_messages(
+                active_messages, context.cached_static_prompt
+            )
+
+            if prompt_stripped:
+                LOG.info("Stripped static prompt from cached request to avoid double-billing")
+            else:
+                LOG.warning("Could not find static prompt to strip from cached request")
+
         t_llm_request = time.perf_counter()
         try:
             # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
             # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
             response = await litellm.acompletion(
                 model=model_name,
-                messages=messages,
+                messages=active_messages,
                 drop_params=True,  # Drop unsupported parameters gracefully
                 **active_parameters,
             )
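
A note on the deepcopy at both call sites: the original messages list may be reused verbatim for a fallback request, so the stripping must mutate a clone rather than the shared dicts. A minimal sketch of the aliasing hazard being avoided (contents made up for illustration):

import copy

messages = [{"role": "user", "content": "STATIC PROMPT dynamic question"}]

cloned = copy.deepcopy(messages)  # independent dicts, safe to mutate
cloned[0]["content"] = cloned[0]["content"].removeprefix("STATIC PROMPT ")

assert messages[0]["content"] == "STATIC PROMPT dynamic question"  # fallback input intact
assert cloned[0]["content"] == "dynamic question"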