cleaned up fallback router (#4010)

This commit is contained in:
pedrohsdb
2025-11-17 12:08:19 -08:00
committed by GitHub
parent abcdf6a033
commit d1c7c675cf

View File

@@ -53,6 +53,26 @@ class LLMAPIHandlerFactory:
_thinking_budget_settings: dict[str, int] | None = None
_prompt_caching_settings: dict[str, bool] | None = None
@staticmethod
def _models_equivalent(left: str | None, right: str | None) -> bool:
"""Used only by `llm_api_handler_with_router_and_fallback`. Router model
groups carry the `vertex-` prefix while LiteLLM responses return the
underlying provider label (e.g. `gemini-2.5-pro`). Stripping the prefix
lets us detect whether the configured primary (the router's
`main_model_group`) actually served the request without replumbing every
config/registry reference.
"""
if left == right:
return True
if left is None or right is None:
return False
def _normalize(label: str) -> str:
normalized = label.lower()
return normalized[len("vertex-") :] if normalized.startswith("vertex-") else normalized
return _normalize(left) == _normalize(right)
@staticmethod
def _apply_thinking_budget_optimization(
parameters: dict[str, Any], new_budget: int, llm_config: LLMConfig | LLMRouterConfig, prompt_name: str
@@ -349,10 +369,21 @@ class LLMAPIHandlerFactory:
thought=thought,
ai_suggestion=ai_suggestion,
)
model_used = main_model_group
try:
response = await router.acompletion(
model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters
)
response_model = response.model or main_model_group
model_used = response_model
if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group):
LOG.info(
"LLM router fallback succeeded",
llm_key=llm_key,
prompt_name=prompt_name,
primary_model=main_model_group,
fallback_model=response_model,
)
except litellm.exceptions.APIError as e:
raise LLMProviderErrorRetryableTask(llm_key) from e
except litellm.exceptions.ContextWindowExceededError as e:
@@ -487,7 +518,7 @@ class LLMAPIHandlerFactory:
LOG.info(
"LLM API handler duration metrics",
llm_key=llm_key,
model=main_model_group,
model=model_used,
prompt_name=prompt_name,
duration_seconds=duration_seconds,
step_id=step.step_id if step else None,
@@ -508,7 +539,7 @@ class LLMAPIHandlerFactory:
parsed_response_json=parsed_response_json,
rendered_response_json=rendered_response_json,
llm_key=llm_key,
model=main_model_group,
model=model_used,
duration_seconds=duration_seconds,
input_tokens=prompt_tokens if prompt_tokens > 0 else None,
output_tokens=completion_tokens if completion_tokens > 0 else None,