From d1c7c675cfb77a85be8d506004b6daeb494df647 Mon Sep 17 00:00:00 2001
From: pedrohsdb
Date: Mon, 17 Nov 2025 12:08:19 -0800
Subject: [PATCH] cleaned up fallback router (#4010)

---
 .../forge/sdk/api/llm/api_handler_factory.py | 35 +++++++++++++++++--
 1 file changed, 33 insertions(+), 2 deletions(-)

diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py
index d29b3c67..547ccfe2 100644
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -53,6 +53,26 @@ class LLMAPIHandlerFactory:
     _thinking_budget_settings: dict[str, int] | None = None
     _prompt_caching_settings: dict[str, bool] | None = None
 
+    @staticmethod
+    def _models_equivalent(left: str | None, right: str | None) -> bool:
+        """Used only by `llm_api_handler_with_router_and_fallback`. Router model
+        groups carry the `vertex-` prefix while LiteLLM responses return the
+        underlying provider label (e.g. `gemini-2.5-pro`). Stripping the prefix
+        lets us detect whether the configured primary (the router's
+        `main_model_group`) actually served the request without replumbing every
+        config/registry reference.
+        """
+        if left == right:
+            return True
+        if left is None or right is None:
+            return False
+
+        def _normalize(label: str) -> str:
+            normalized = label.lower()
+            return normalized[len("vertex-") :] if normalized.startswith("vertex-") else normalized
+
+        return _normalize(left) == _normalize(right)
+
     @staticmethod
     def _apply_thinking_budget_optimization(
         parameters: dict[str, Any], new_budget: int, llm_config: LLMConfig | LLMRouterConfig, prompt_name: str
@@ -349,10 +369,21 @@ class LLMAPIHandlerFactory:
                 thought=thought,
                 ai_suggestion=ai_suggestion,
             )
+            model_used = main_model_group
             try:
                 response = await router.acompletion(
                     model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters
                 )
+                response_model = response.model or main_model_group
+                model_used = response_model
+                if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group):
+                    LOG.info(
+                        "LLM router fallback succeeded",
+                        llm_key=llm_key,
+                        prompt_name=prompt_name,
+                        primary_model=main_model_group,
+                        fallback_model=response_model,
+                    )
             except litellm.exceptions.APIError as e:
                 raise LLMProviderErrorRetryableTask(llm_key) from e
             except litellm.exceptions.ContextWindowExceededError as e:
@@ -487,7 +518,7 @@
             LOG.info(
                 "LLM API handler duration metrics",
                 llm_key=llm_key,
-                model=main_model_group,
+                model=model_used,
                 prompt_name=prompt_name,
                 duration_seconds=duration_seconds,
                 step_id=step.step_id if step else None,
@@ -508,7 +539,7 @@
                 parsed_response_json=parsed_response_json,
                 rendered_response_json=rendered_response_json,
                 llm_key=llm_key,
-                model=main_model_group,
+                model=model_used,
                 duration_seconds=duration_seconds,
                 input_tokens=prompt_tokens if prompt_tokens > 0 else None,
                 output_tokens=completion_tokens if completion_tokens > 0 else None,
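
Reviewer note: the sketch below is a standalone restatement of the comparison this patch adds, so the intended fallback-detection behavior can be checked without wiring up the router. The free function models_equivalent mirrors the new LLMAPIHandlerFactory._models_equivalent static method; the model labels in the assertions are illustrative examples, not values taken from the Skyvern config.

# Standalone sketch of the prefix-insensitive model comparison added in this patch.
# `models_equivalent` mirrors `LLMAPIHandlerFactory._models_equivalent`; the model
# labels used below are examples only.


def models_equivalent(left: str | None, right: str | None) -> bool:
    if left == right:
        return True
    if left is None or right is None:
        return False

    def normalize(label: str) -> str:
        lowered = label.lower()
        # Router model groups carry a "vertex-" prefix; LiteLLM reports the
        # underlying provider label, so strip the prefix before comparing.
        return lowered[len("vertex-"):] if lowered.startswith("vertex-") else lowered

    return normalize(left) == normalize(right)


# Primary group vs. the label LiteLLM reports for the same model: equivalent,
# so no "LLM router fallback succeeded" log line is emitted.
assert models_equivalent("vertex-gemini-2.5-pro", "gemini-2.5-pro")

# A genuinely different model means the router fell back, which the patch logs
# and records as `model_used` in the duration/usage metrics.
assert not models_equivalent("vertex-gemini-2.5-pro", "gpt-4o")

# Missing model metadata never compares equal; the patch also guards this case
# upstream via `response.model or main_model_group`.
assert not models_equivalent("vertex-gemini-2.5-pro", None)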