diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index 23ea0195..635d63c5 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -10,6 +10,7 @@ import structlog from anthropic import NOT_GIVEN from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage from jinja2 import Template +from litellm.types.router import AllowedFailsPolicy from litellm.utils import CustomStreamWrapper, ModelResponse from openai import AsyncOpenAI from openai.types.chat.chat_completion_chunk import ChatCompletionChunk @@ -26,7 +27,13 @@ from skyvern.forge.sdk.api.llm.exceptions import ( LLMProviderError, LLMProviderErrorRetryableTask, ) -from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouterConfig, dummy_llm_api_handler +from skyvern.forge.sdk.api.llm.models import ( + LLMAllowedFailsPolicy, + LLMAPIHandler, + LLMConfig, + LLMRouterConfig, + dummy_llm_api_handler, +) from skyvern.forge.sdk.api.llm.ui_tars_response import UITarsResponse from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, llm_messages_builder_with_history, parse_api_response from skyvern.forge.sdk.artifact.models import ArtifactType @@ -103,6 +110,20 @@ def _log_vertex_cache_hit_if_needed( ) +def _convert_allowed_fails_policy(policy: LLMAllowedFailsPolicy | None) -> AllowedFailsPolicy | None: + if policy is None: + return None + + return AllowedFailsPolicy( + BadRequestErrorAllowedFails=policy.bad_request_error_allowed_fails, + AuthenticationErrorAllowedFails=policy.authentication_error_allowed_fails, + TimeoutErrorAllowedFails=policy.timeout_error_allowed_fails, + RateLimitErrorAllowedFails=policy.rate_limit_error_allowed_fails, + ContentPolicyViolationErrorAllowedFails=policy.content_policy_violation_error_allowed_fails, + InternalServerErrorAllowedFails=policy.internal_server_error_allowed_fails, + ) + + class LLMAPIHandlerFactory: _custom_handlers: dict[str, LLMAPIHandler] = {} _thinking_budget_settings: dict[str, int] | None = None @@ -310,7 +331,7 @@ class LLMAPIHandlerFactory: retry_after=llm_config.retry_delay_seconds, disable_cooldowns=llm_config.disable_cooldowns, allowed_fails=llm_config.allowed_fails, - allowed_fails_policy=llm_config.allowed_fails_policy, + allowed_fails_policy=_convert_allowed_fails_policy(llm_config.allowed_fails_policy), cooldown_time=llm_config.cooldown_time, set_verbose=(False if settings.is_cloud_environment() else llm_config.set_verbose), enable_pre_call_checks=True, diff --git a/skyvern/forge/sdk/api/llm/models.py b/skyvern/forge/sdk/api/llm/models.py index 45274c57..b60273be 100644 --- a/skyvern/forge/sdk/api/llm/models.py +++ b/skyvern/forge/sdk/api/llm/models.py @@ -1,8 +1,6 @@ from dataclasses import dataclass, field from typing import Any, Awaitable, Literal, Optional, Protocol, TypedDict -from litellm import AllowedFailsPolicy - from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion from skyvern.forge.sdk.schemas.task_v2 import TaskV2, Thought @@ -48,6 +46,16 @@ class LLMConfig(LLMConfigBase): reasoning_effort: str | None = None +@dataclass(frozen=True) +class LLMAllowedFailsPolicy: + bad_request_error_allowed_fails: int | None = None + authentication_error_allowed_fails: int | None = None + timeout_error_allowed_fails: int | None = None + rate_limit_error_allowed_fails: int | None = None + content_policy_violation_error_allowed_fails: int | None = None + internal_server_error_allowed_fails: int | None = None + + @dataclass(frozen=True) class LLMRouterModelConfig: model_name: str @@ -79,7 +87,7 @@ class LLMRouterConfig(LLMConfigBase): set_verbose: bool = False disable_cooldowns: bool | None = None allowed_fails: int | None = None - allowed_fails_policy: AllowedFailsPolicy | None = None + allowed_fails_policy: LLMAllowedFailsPolicy | None = None cooldown_time: float | None = None max_tokens: int | None = SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS max_completion_tokens: int | None = None diff --git a/skyvern/webeye/actions/actions.py b/skyvern/webeye/actions/actions.py index 573fa898..3e525f3a 100644 --- a/skyvern/webeye/actions/actions.py +++ b/skyvern/webeye/actions/actions.py @@ -3,8 +3,7 @@ from enum import StrEnum from typing import Annotated, Any, Literal, Type, TypeVar import structlog -from litellm import ConfigDict -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from skyvern.errors.errors import UserDefinedError from skyvern.webeye.actions.action_types import ActionType