cleaned up fallback router (#4010)
@@ -53,6 +53,26 @@ class LLMAPIHandlerFactory:
     _thinking_budget_settings: dict[str, int] | None = None
     _prompt_caching_settings: dict[str, bool] | None = None
 
+    @staticmethod
+    def _models_equivalent(left: str | None, right: str | None) -> bool:
+        """Used only by `llm_api_handler_with_router_and_fallback`. Router model
+        groups carry the `vertex-` prefix while LiteLLM responses return the
+        underlying provider label (e.g. `gemini-2.5-pro`). Stripping the prefix
+        lets us detect whether the configured primary (the router's
+        `main_model_group`) actually served the request without replumbing every
+        config/registry reference.
+        """
+        if left == right:
+            return True
+        if left is None or right is None:
+            return False
+
+        def _normalize(label: str) -> str:
+            normalized = label.lower()
+            return normalized[len("vertex-") :] if normalized.startswith("vertex-") else normalized
+
+        return _normalize(left) == _normalize(right)
+
     @staticmethod
     def _apply_thinking_budget_optimization(
         parameters: dict[str, Any], new_budget: int, llm_config: LLMConfig | LLMRouterConfig, prompt_name: str
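For illustration, a minimal standalone sketch of the comparison `_models_equivalent` performs; this is not the committed code, the model names are hypothetical examples, and it assumes Python 3.9+ for `str.removeprefix`:

# Illustrative sketch only; mirrors the prefix-stripping logic above.
def models_equivalent(left: str | None, right: str | None) -> bool:
    if left == right:
        return True
    if left is None or right is None:
        return False

    def normalize(label: str) -> str:
        # "vertex-gemini-2.5-pro" and "gemini-2.5-pro" normalize identically.
        return label.lower().removeprefix("vertex-")

    return normalize(left) == normalize(right)

# The router group and the provider label it resolves to compare equal:
assert models_equivalent("vertex-gemini-2.5-pro", "gemini-2.5-pro")
# A genuinely different model does not, which is what flags a fallback:
assert not models_equivalent("vertex-gemini-2.5-pro", "gpt-4o")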
@@ -349,10 +369,21 @@ class LLMAPIHandlerFactory:
                 thought=thought,
                 ai_suggestion=ai_suggestion,
             )
+        model_used = main_model_group
         try:
             response = await router.acompletion(
                 model=main_model_group, messages=messages, timeout=settings.LLM_CONFIG_TIMEOUT, **parameters
             )
+            response_model = response.model or main_model_group
+            model_used = response_model
+            if not LLMAPIHandlerFactory._models_equivalent(response_model, main_model_group):
+                LOG.info(
+                    "LLM router fallback succeeded",
+                    llm_key=llm_key,
+                    prompt_name=prompt_name,
+                    primary_model=main_model_group,
+                    fallback_model=response_model,
+                )
         except litellm.exceptions.APIError as e:
             raise LLMProviderErrorRetryableTask(llm_key) from e
         except litellm.exceptions.ContextWindowExceededError as e:
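The shape of this hunk — seed `model_used` with the primary group, overwrite it with whatever model the response reports, and log only on a real mismatch — can be sketched in isolation. `fake_acompletion` below is a hypothetical stand-in for `router.acompletion`, and `models_equivalent` is the sketch from the note above:

import asyncio

async def fake_acompletion(model: str) -> dict:
    # Hypothetical stand-in: the provider reports the unprefixed label.
    return {"model": "gemini-2.5-pro"}

async def call_with_fallback_tracking(primary: str) -> str:
    model_used = primary  # assume the primary until the response says otherwise
    response = await fake_acompletion(model=primary)
    response_model = response.get("model") or primary  # guard a missing model field
    model_used = response_model
    if not models_equivalent(response_model, primary):
        print(f"LLM router fallback succeeded: {primary} -> {response_model}")
    return model_used

# Prefix-insensitive comparison: no fallback is logged for this pair,
# but model_used still records the label that actually served the call.
print(asyncio.run(call_with_fallback_tracking("vertex-gemini-2.5-pro")))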
@@ -487,7 +518,7 @@ class LLMAPIHandlerFactory:
             LOG.info(
                 "LLM API handler duration metrics",
                 llm_key=llm_key,
-                model=main_model_group,
+                model=model_used,
                 prompt_name=prompt_name,
                 duration_seconds=duration_seconds,
                 step_id=step.step_id if step else None,
@@ -508,7 +539,7 @@ class LLMAPIHandlerFactory:
                 parsed_response_json=parsed_response_json,
                 rendered_response_json=rendered_response_json,
                 llm_key=llm_key,
-                model=main_model_group,
+                model=model_used,
                 duration_seconds=duration_seconds,
                 input_tokens=prompt_tokens if prompt_tokens > 0 else None,
                 output_tokens=completion_tokens if completion_tokens > 0 else None,