diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py
index b2ae5544..64c4b59b 100644
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -56,6 +56,55 @@ EXTRACT_ACTION_DEFAULT_THINKING_BUDGET = settings.EXTRACT_ACTION_THINKING_BUDGET
 DEFAULT_THINKING_BUDGET = settings.DEFAULT_THINKING_BUDGET
 
 
+def _is_stale_client_error(exc: BaseException) -> bool:
+    """Check if an exception chain contains the httpx 'client has been closed' RuntimeError.
+
+    litellm caches AsyncOpenAI/AsyncAzureOpenAI clients keyed by event loop ID.
+    In long-lived Temporal worker pods the event loop can change (task completion,
+    activity recycling), leaving stale clients whose underlying httpx.AsyncClient
+    is closed. Subsequent requests through those clients raise:
+        RuntimeError: Cannot send a request, as the client has been closed
+    which litellm wraps as APIConnectionError or InternalServerError.
+    """
+    cur: BaseException | None = exc
+    while cur is not None:
+        if isinstance(cur, RuntimeError) and "client has been closed" in str(cur):
+            return True
+        next_exc = cur.__cause__ or cur.__context__
+        cur = next_exc if next_exc is not cur else None
+    return False
+
+
+async def _acompletion_with_stale_client_retry(
+    acompletion_callable: Any,
+    **kwargs: Any,
+) -> Any:
+    """Call *acompletion_callable* and retry once if the failure is a stale httpx client.
+
+    On first failure the entire ``litellm.in_memory_llm_clients_cache`` is flushed
+    so that the retry creates a fresh client bound to the current event loop.
+    """
+    try:
+        return await acompletion_callable(**kwargs)
+    except Exception as first_err:
+        if not _is_stale_client_error(first_err):
+            raise
+
+        model = kwargs.get("model", "unknown")
+        LOG.warning(
+            "Stale httpx client detected – flushing litellm client cache and retrying",
+            error_type=type(first_err).__name__,
+            model=model,
+        )
+        try:
+            litellm.in_memory_llm_clients_cache.flush_cache()
+        except Exception:
+            LOG.warning("Failed to flush litellm client cache", exc_info=True)
+
+        # Retry once with a fresh client
+        return await acompletion_callable(**kwargs)
+
+
 def _safe_model_dump_json(response: ModelResponse, indent: int = 2) -> str:
     """
     Call model_dump_json() while suppressing Pydantic serialization warnings.
@@ -609,7 +658,8 @@ class LLMAPIHandlerFactory:
                 cache_variant=cache_variant_name,
             )
             request_payload_json = await _log_llm_request_artifact(request_model, True)
-            response = await litellm.acompletion(
+            response = await _acompletion_with_stale_client_retry(
+                litellm.acompletion,
                 model=request_model,
                 messages=active_messages,
                 timeout=settings.LLM_CONFIG_TIMEOUT,
@@ -620,7 +670,8 @@
 
         async def _call_router_without_cache() -> tuple[ModelResponse, str]:
             request_payload_json = await _log_llm_request_artifact(llm_key, False)
-            response = await router.acompletion(
+            response = await _acompletion_with_stale_client_retry(
+                router.acompletion,
                 model=main_model_group,
                 messages=messages,
                 timeout=settings.LLM_CONFIG_TIMEOUT,
@@ -1041,7 +1092,8 @@ class LLMAPIHandlerFactory:
            try:
                # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
                # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
-               response = await litellm.acompletion(
+               response = await _acompletion_with_stale_client_retry(
+                   litellm.acompletion,
                    model=model_name,
                    messages=active_messages,
                    drop_params=True,  # Drop unsupported parameters gracefully
@@ -1627,7 +1679,8 @@ class LLMCaller:
        if self.llm_key and "UI_TARS" in self.llm_key:
            return await self._call_ui_tars(messages, tools, timeout, **active_parameters)
 
-       return await litellm.acompletion(
+       return await _acompletion_with_stale_client_retry(
+           litellm.acompletion,
            model=self.llm_config.model_name,
            messages=messages,
            tools=tools,
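
Reviewer note, not part of the patch: a minimal sketch of the chain-walking behaviour that _is_stale_client_error relies on, assuming litellm keeps the original httpx RuntimeError reachable through __cause__ or __context__ when it wraps it, as the docstring above describes. FakeAPIConnectionError and _raise_wrapped_stale_client_error are hypothetical names used only for illustration; the helper is assumed to be imported from the patched module.

# Sketch only. FakeAPIConnectionError stands in for the litellm wrapper exception;
# the detection logic never depends on the wrapper type, only on finding the
# RuntimeError somewhere on the __cause__ / __context__ chain.

class FakeAPIConnectionError(Exception):
    """Hypothetical stand-in for litellm.APIConnectionError (illustration only)."""


def _raise_wrapped_stale_client_error() -> None:
    try:
        raise RuntimeError("Cannot send a request, as the client has been closed")
    except RuntimeError as err:
        # Explicit chaining sets __cause__; an implicit re-raise would set __context__.
        raise FakeAPIConnectionError("litellm-style wrapper") from err


try:
    _raise_wrapped_stale_client_error()
except FakeAPIConnectionError as exc:
    assert _is_stale_client_error(exc)  # RuntimeError found by walking the chain

assert not _is_stale_client_error(ValueError("unrelated failure"))  # no cache flush / retry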