From f594474b9e711d55d80e2f8901d9a1ec5361ccb9 Mon Sep 17 00:00:00 2001
From: pedrohsdb
Date: Wed, 17 Dec 2025 17:25:36 -0800
Subject: [PATCH] fix(llm): strip static prompt from cached Vertex AI requests
 to preve… (#4321)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../forge/sdk/api/llm/api_handler_factory.py | 68 ++++++++++++++++++-
 1 file changed, 66 insertions(+), 2 deletions(-)

diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py
index 5dbb7741..d658ee0f 100644
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -149,6 +149,40 @@ class LLMAPIHandlerFactory:
     _thinking_budget_settings: dict[str, int] | None = None
     _prompt_caching_settings: dict[str, bool] | None = None
 
+    @staticmethod
+    def _strip_static_prompt_from_messages(messages: list[dict[str, Any]], static_prompt: str) -> bool:
+        """
+        Strips the static prompt from the first matching user message in the list.
+        Returns True if the prompt was found and stripped, False otherwise.
+
+        This handles both string content and list-based content (e.g. for vision models).
+        The static prompt is right-stripped to handle trailing newlines from templates.
+        The remaining dynamic content is left-stripped to handle connector whitespace.
+        """
+        static_text = static_prompt.rstrip()
+        prompt_stripped = False
+
+        for msg in messages:
+            if msg.get("role") == "user":
+                content = msg.get("content")
+                if isinstance(content, str):
+                    if content.startswith(static_text):
+                        msg["content"] = content[len(static_text) :].lstrip()
+                        prompt_stripped = True
+                        break
+                elif isinstance(content, list):
+                    for block in content:
+                        if block.get("type") == "text":
+                            text = block.get("text", "")
+                            if text.startswith(static_text):
+                                block["text"] = text[len(static_text) :].lstrip()
+                                prompt_stripped = True
+                                break
+            if prompt_stripped:
+                break
+
+        return prompt_stripped
+
     @staticmethod
     def _models_equivalent(left: str | None, right: str | None) -> bool:
         """Used only by `llm_api_handler_with_router_and_fallback`. Router model
@@ -513,6 +547,22 @@ class LLMAPIHandlerFactory:
                 active_params.update(parameters)
                 active_params["cached_content"] = cache_name
                 request_model = active_params.pop("model", primary_model_dict.get("model_name", main_model_group))
+
+                # Clone messages to avoid modifying original list which is needed for fallback
+                active_messages = copy.deepcopy(messages)
+
+                # Strip static prompt from the request messages because it's already in the cache
+                # Sending it again causes double-billing (once cached, once uncached)
+                if context and context.cached_static_prompt:
+                    prompt_stripped = LLMAPIHandlerFactory._strip_static_prompt_from_messages(
+                        active_messages, context.cached_static_prompt
+                    )
+
+                    if prompt_stripped:
+                        LOG.info("Stripped static prompt from cached request to avoid double-billing")
+                    else:
+                        LOG.warning("Could not find static prompt to strip from cached request")
+
                 LOG.info(
                     "Adding Vertex AI cache reference to primary Gemini request",
                     prompt_name=prompt_name,
@@ -525,7 +575,7 @@ class LLMAPIHandlerFactory:
                 request_payload_json = await _log_llm_request_artifact(request_model, True)
                 response = await litellm.acompletion(
                     model=request_model,
-                    messages=messages,
+                    messages=active_messages,
                     timeout=settings.LLM_CONFIG_TIMEOUT,
                     drop_params=True,
                     **active_params,
@@ -913,13 +963,27 @@ class LLMAPIHandlerFactory:
                 **artifact_targets,
             )
 
+            # Strip static prompt from the request messages because it's already in the cache
+            # Sending it again causes double-billing (once cached, once uncached)
+            active_messages = messages
+            if vertex_cache_attached and context and context.cached_static_prompt:
+                active_messages = copy.deepcopy(messages)
+                prompt_stripped = LLMAPIHandlerFactory._strip_static_prompt_from_messages(
+                    active_messages, context.cached_static_prompt
+                )
+
+                if prompt_stripped:
+                    LOG.info("Stripped static prompt from cached request to avoid double-billing")
+                else:
+                    LOG.warning("Could not find static prompt to strip from cached request")
+
             t_llm_request = time.perf_counter()
             try:
                 # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
                 # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
                 response = await litellm.acompletion(
                     model=model_name,
-                    messages=messages,
+                    messages=active_messages,
                     drop_params=True,  # Drop unsupported parameters gracefully
                     **active_parameters,
                 )
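
A minimal usage sketch of the new helper, assuming the import path mirrors the
patched file; the static prompt and message payloads below are illustrative and
not taken from the codebase.

    from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory

    # Illustrative placeholder; the real static prompt comes from the prompt templates.
    static_prompt = "STATIC INSTRUCTIONS SHARED BY EVERY REQUEST\n"

    # String content: the cached static prefix is removed, the dynamic tail is kept.
    messages = [{"role": "user", "content": static_prompt + "Task: extract the page title"}]
    assert LLMAPIHandlerFactory._strip_static_prompt_from_messages(messages, static_prompt)
    assert messages[0]["content"] == "Task: extract the page title"

    # Vision-style list content: only the first matching text block is rewritten.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": static_prompt + "Describe the screenshot"},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
            ],
        }
    ]
    assert LLMAPIHandlerFactory._strip_static_prompt_from_messages(messages, static_prompt)
    assert messages[0]["content"][0]["text"] == "Describe the screenshot"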