diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index b2ae5544..6b4a1fae 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -165,6 +165,7 @@ def _convert_allowed_fails_policy(policy: LLMAllowedFailsPolicy | None) -> Allow class LLMAPIHandlerFactory: _custom_handlers: dict[str, LLMAPIHandler] = {} + _router_handler_cache: dict[str, LLMAPIHandler] = {} _thinking_budget_settings: dict[str, int] | None = None _prompt_caching_settings: dict[str, bool] | None = None @@ -389,6 +390,9 @@ class LLMAPIHandlerFactory: @staticmethod def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler: + if llm_key in LLMAPIHandlerFactory._router_handler_cache: + return LLMAPIHandlerFactory._router_handler_cache[llm_key] + llm_config = LLMConfigRegistry.get_config(llm_key) if not isinstance(llm_config, LLMRouterConfig): raise InvalidLLMConfigError(llm_key) @@ -831,6 +835,7 @@ class LLMAPIHandlerFactory: LOG.error("Failed to persist artifacts", exc_info=True) llm_api_handler_with_router_and_fallback.llm_key = llm_key # type: ignore[attr-defined] + LLMAPIHandlerFactory._router_handler_cache[llm_key] = llm_api_handler_with_router_and_fallback return llm_api_handler_with_router_and_fallback @staticmethod