cleaned up fallback router (#4010)

This commit is contained in:
pedrohsdb
2025-11-17 12:08:19 -08:00
committed by GitHub
parent abcdf6a033
commit d1c7c675cf

View File

@@ -53,6 +53,26 @@ class LLMAPIHandlerFactory:
_thinking_budget_settings: dict[str, int] | None = None _thinking_budget_settings: dict[str, int] | None = None
_prompt_caching_settings: dict[str, bool] | None = None _prompt_caching_settings: dict[str, bool] | None = None
@staticmethod
def _models_equivalent(left: str | None, right: str | None) -> bool:
"""Used only by `llm_api_handler_with_router_and_fallback`. Router model
groups carry the `vertex-` prefix while LiteLLM responses return the
underlying provider label (e.g. `gemini-2.5-pro`). Stripping the prefix
lets us detect whether the configured primary (the router's
`main_model_group`) actually served the request without replumbing every
config/registry reference.
"""
if left == right:
return True
if left is None or right is None:
return False
def _normalize(label: str) -> str:
normalized = label.lower()
return normalized[len("vertex-") :] if normalized.startswith("vertex-") else normalized
return _normalize(left) == _normalize(right)
@staticmethod @staticmethod
def _apply_thinking_budget_optimization( def _apply_thinking_budget_optimization(
parameters: dict[str, Any], new_budget: int, llm_config: LLMConfig | LLMRouterConfig, prompt_name: str parameters: dict[str, Any], new_budget: int, llm_config: LLMConfig | LLMRouterConfig, prompt_name: str
@@ -349,10 +369,21 @@ class LLMAPIHandlerFactory:
thought=thought, thought=thought,
ai_suggestion=ai_suggestion, ai_suggestion=ai_suggestion,
) )
model_used = main_model_group
try: try:
response = await router.acompletion( response = await router.acompletion(
model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters
) )
response_model = response.model or main_model_group
model_used = response_model
if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group):
LOG.info(
"LLM router fallback succeeded",
llm_key=llm_key,
prompt_name=prompt_name,
primary_model=main_model_group,
fallback_model=response_model,
)
except litellm.exceptions.APIError as e: except litellm.exceptions.APIError as e:
raise LLMProviderErrorRetryableTask(llm_key) from e raise LLMProviderErrorRetryableTask(llm_key) from e
except litellm.exceptions.ContextWindowExceededError as e: except litellm.exceptions.ContextWindowExceededError as e:
@@ -487,7 +518,7 @@ class LLMAPIHandlerFactory:
LOG.info( LOG.info(
"LLM API handler duration metrics", "LLM API handler duration metrics",
llm_key=llm_key, llm_key=llm_key,
model=main_model_group, model=model_used,
prompt_name=prompt_name, prompt_name=prompt_name,
duration_seconds=duration_seconds, duration_seconds=duration_seconds,
step_id=step.step_id if step else None, step_id=step.step_id if step else None,
@@ -508,7 +539,7 @@ class LLMAPIHandlerFactory:
parsed_response_json=parsed_response_json, parsed_response_json=parsed_response_json,
rendered_response_json=rendered_response_json, rendered_response_json=rendered_response_json,
llm_key=llm_key, llm_key=llm_key,
model=main_model_group, model=model_used,
duration_seconds=duration_seconds, duration_seconds=duration_seconds,
input_tokens=prompt_tokens if prompt_tokens > 0 else None, input_tokens=prompt_tokens if prompt_tokens > 0 else None,
output_tokens=completion_tokens if completion_tokens > 0 else None, output_tokens=completion_tokens if completion_tokens > 0 else None,