Fix stale litellm httpx client errors in Temporal workers (#SKY-7879) (#4676)

This commit is contained in:
pedrohsdb
2026-02-09 18:31:36 -08:00
committed by GitHub
parent c17418692a
commit ed9e37ee48

View File

@@ -56,6 +56,55 @@ EXTRACT_ACTION_DEFAULT_THINKING_BUDGET = settings.EXTRACT_ACTION_THINKING_BUDGET
DEFAULT_THINKING_BUDGET = settings.DEFAULT_THINKING_BUDGET
def _is_stale_client_error(exc: BaseException) -> bool:
    """Check if an exception chain contains the httpx 'client has been closed' RuntimeError.

    litellm caches AsyncOpenAI/AsyncAzureOpenAI clients keyed by event loop ID.
    In long-lived Temporal worker pods the event loop can change (task completion,
    activity recycling), leaving stale clients whose underlying httpx.AsyncClient
    is closed. Subsequent requests through those clients raise:

        RuntimeError: Cannot send a request, as the client has been closed

    which litellm wraps as APIConnectionError or InternalServerError.

    Both ``__cause__`` and ``__context__`` of every exception in the chain are
    inspected, so the marker is found even when a wrapper sets an unrelated
    explicit cause. Each exception object is visited at most once, which keeps
    the walk finite even if the chain links back onto itself.

    Args:
        exc: The exception raised by the completion call. ``None`` is tolerated
            and reported as "not stale".

    Returns:
        True if any exception reachable from *exc* is a ``RuntimeError`` whose
        message contains ``"client has been closed"``; False otherwise.
    """
    pending: list[BaseException] = [] if exc is None else [exc]
    # Track identities rather than the objects themselves: exceptions are not
    # guaranteed to be hashable, and identity is what defines a cycle here.
    visited: set[int] = set()
    while pending:
        cur = pending.pop()
        if id(cur) in visited:
            continue
        visited.add(id(cur))
        if isinstance(cur, RuntimeError) and "client has been closed" in str(cur):
            return True
        # Follow the explicit cause (``raise ... from``) as well as the implicit
        # context (raised while handling another exception).
        for linked in (cur.__cause__, cur.__context__):
            if linked is not None:
                pending.append(linked)
    return False
async def _acompletion_with_stale_client_retry(
    acompletion_callable: Any,
    **kwargs: Any,
) -> Any:
    """Await *acompletion_callable* with one retry reserved for stale httpx clients.

    The callable is awaited with ``**kwargs``. If it raises and
    ``_is_stale_client_error`` does not recognise the failure, the exception
    propagates unchanged. Otherwise a warning is logged,
    ``litellm.in_memory_llm_clients_cache`` is flushed in full (a failure to
    flush is logged and otherwise ignored), and the callable is awaited one
    more time with the same keyword arguments. Whatever that second attempt
    returns or raises is the final outcome.

    Args:
        acompletion_callable: Async callable invoked as
            ``acompletion_callable(**kwargs)``.
        **kwargs: Keyword arguments forwarded verbatim to both attempts; the
            ``model`` entry, when present, is included in the warning log.

    Returns:
        The result of the first successful attempt.
    """
    try:
        return await acompletion_callable(**kwargs)
    except Exception as initial_failure:
        # Anything other than the closed-client signature is not ours to handle.
        if not _is_stale_client_error(initial_failure):
            raise

        LOG.warning(
            "Stale httpx client detected flushing litellm client cache and retrying",
            error_type=type(initial_failure).__name__,
            model=kwargs.get("model", "unknown"),
        )

        # Best effort: a failed flush must not prevent the second attempt.
        try:
            clients_cache = litellm.in_memory_llm_clients_cache
            clients_cache.flush_cache()
        except Exception:
            LOG.warning("Failed to flush litellm client cache", exc_info=True)

        # Second and final attempt, expected to build a fresh client.
        return await acompletion_callable(**kwargs)
def _safe_model_dump_json(response: ModelResponse, indent: int = 2) -> str:
"""
Call model_dump_json() while suppressing Pydantic serialization warnings.
@@ -609,7 +658,8 @@ class LLMAPIHandlerFactory:
cache_variant=cache_variant_name,
)
request_payload_json = await _log_llm_request_artifact(request_model, True)
response = await litellm.acompletion(
response = await _acompletion_with_stale_client_retry(
litellm.acompletion,
model=request_model,
messages=active_messages,
timeout=settings.LLM_CONFIG_TIMEOUT,
@@ -620,7 +670,8 @@ class LLMAPIHandlerFactory:
async def _call_router_without_cache() -> tuple[ModelResponse, str]:
request_payload_json = await _log_llm_request_artifact(llm_key, False)
response = await router.acompletion(
response = await _acompletion_with_stale_client_retry(
router.acompletion,
model=main_model_group,
messages=messages,
timeout=settings.LLM_CONFIG_TIMEOUT,
@@ -1041,7 +1092,8 @@ class LLMAPIHandlerFactory:
try:
# TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
# TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
response = await litellm.acompletion(
response = await _acompletion_with_stale_client_retry(
litellm.acompletion,
model=model_name,
messages=active_messages,
drop_params=True, # Drop unsupported parameters gracefully
@@ -1627,7 +1679,8 @@ class LLMCaller:
if self.llm_key and "UI_TARS" in self.llm_key:
return await self._call_ui_tars(messages, tools, timeout, **active_parameters)
return await litellm.acompletion(
return await _acompletion_with_stale_client_retry(
litellm.acompletion,
model=self.llm_config.model_name,
messages=messages,
tools=tools,