add gpt4o mini support (#666)
This commit is contained in:
@@ -75,6 +75,7 @@ class Settings(BaseSettings):
|
|||||||
ENABLE_OPENAI: bool = False
|
ENABLE_OPENAI: bool = False
|
||||||
ENABLE_ANTHROPIC: bool = False
|
ENABLE_ANTHROPIC: bool = False
|
||||||
ENABLE_AZURE: bool = False
|
ENABLE_AZURE: bool = False
|
||||||
|
ENABLE_AZURE_GPT4O_MINI: bool = False
|
||||||
ENABLE_BEDROCK: bool = False
|
ENABLE_BEDROCK: bool = False
|
||||||
# OPENAI
|
# OPENAI
|
||||||
OPENAI_API_KEY: str | None = None
|
OPENAI_API_KEY: str | None = None
|
||||||
@@ -86,6 +87,12 @@ class Settings(BaseSettings):
|
|||||||
AZURE_API_BASE: str | None = None
|
AZURE_API_BASE: str | None = None
|
||||||
AZURE_API_VERSION: str | None = None
|
AZURE_API_VERSION: str | None = None
|
||||||
|
|
||||||
|
# AZURE GPT-4o mini
|
||||||
|
AZURE_GPT4O_MINI_DEPLOYMENT: str | None = None
|
||||||
|
AZURE_GPT4O_MINI_API_KEY: str | None = None
|
||||||
|
AZURE_GPT4O_MINI_API_BASE: str | None = None
|
||||||
|
AZURE_GPT4O_MINI_API_VERSION: str | None = None
|
||||||
|
|
||||||
def is_cloud_environment(self) -> bool:
|
def is_cloud_environment(self) -> bool:
|
||||||
"""
|
"""
|
||||||
:return: True if env is not local, else False
|
:return: True if env is not local, else False
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
|
|||||||
InvalidLLMConfigError,
|
InvalidLLMConfigError,
|
||||||
LLMProviderError,
|
LLMProviderError,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMRouterConfig
|
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouterConfig
|
||||||
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
|
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
|
||||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||||
from skyvern.forge.sdk.models import Step
|
from skyvern.forge.sdk.models import Step
|
||||||
@@ -147,6 +147,8 @@ class LLMAPIHandlerFactory:
|
|||||||
if LLMConfigRegistry.is_router_config(llm_key):
|
if LLMConfigRegistry.is_router_config(llm_key):
|
||||||
return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
|
return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
|
||||||
|
|
||||||
|
assert isinstance(llm_config, LLMConfig)
|
||||||
|
|
||||||
async def llm_api_handler(
|
async def llm_api_handler(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
step: Step | None = None,
|
step: Step | None = None,
|
||||||
@@ -158,6 +160,8 @@ class LLMAPIHandlerFactory:
|
|||||||
parameters = LLMAPIHandlerFactory.get_api_parameters()
|
parameters = LLMAPIHandlerFactory.get_api_parameters()
|
||||||
|
|
||||||
active_parameters.update(parameters)
|
active_parameters.update(parameters)
|
||||||
|
if llm_config.litellm_params: # type: ignore
|
||||||
|
active_parameters.update(llm_config.litellm_params) # type: ignore
|
||||||
|
|
||||||
if step:
|
if step:
|
||||||
await app.ARTIFACT_MANAGER.create_artifact(
|
await app.ARTIFACT_MANAGER.create_artifact(
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
|
|||||||
MissingLLMProviderEnvVarsError,
|
MissingLLMProviderEnvVarsError,
|
||||||
NoProviderEnabledError,
|
NoProviderEnabledError,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.api.llm.models import LLMConfig, LLMRouterConfig
|
from skyvern.forge.sdk.api.llm.models import LiteLLMParams, LLMConfig, LLMRouterConfig
|
||||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
@@ -49,6 +49,7 @@ if not any(
|
|||||||
SettingsManager.get_settings().ENABLE_OPENAI,
|
SettingsManager.get_settings().ENABLE_OPENAI,
|
||||||
SettingsManager.get_settings().ENABLE_ANTHROPIC,
|
SettingsManager.get_settings().ENABLE_ANTHROPIC,
|
||||||
SettingsManager.get_settings().ENABLE_AZURE,
|
SettingsManager.get_settings().ENABLE_AZURE,
|
||||||
|
SettingsManager.get_settings().ENABLE_AZURE_GPT4O_MINI,
|
||||||
SettingsManager.get_settings().ENABLE_BEDROCK,
|
SettingsManager.get_settings().ENABLE_BEDROCK,
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
@@ -189,3 +190,24 @@ if SettingsManager.get_settings().ENABLE_AZURE:
|
|||||||
add_assistant_prefix=False,
|
add_assistant_prefix=False,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if SettingsManager.get_settings().ENABLE_AZURE_GPT4O_MINI:
|
||||||
|
LLMConfigRegistry.register_config(
|
||||||
|
"AZURE_OPENAI_GPT4O_MINI",
|
||||||
|
LLMConfig(
|
||||||
|
f"azure/{SettingsManager.get_settings().AZURE_GPT4O_MINI_DEPLOYMENT}",
|
||||||
|
[
|
||||||
|
"AZURE_GPT4O_MINI_DEPLOYMENT",
|
||||||
|
"AZURE_GPT4O_MINI_API_KEY",
|
||||||
|
"AZURE_GPT4O_MINI_API_BASE",
|
||||||
|
"AZURE_GPT4O_MINI_API_VERSION",
|
||||||
|
],
|
||||||
|
litellm_params=LiteLLMParams(
|
||||||
|
api_base=SettingsManager.get_settings().AZURE_GPT4O_MINI_API_BASE,
|
||||||
|
api_key=SettingsManager.get_settings().AZURE_GPT4O_MINI_API_KEY,
|
||||||
|
api_version=SettingsManager.get_settings().AZURE_GPT4O_MINI_API_VERSION,
|
||||||
|
),
|
||||||
|
supports_vision=True,
|
||||||
|
add_assistant_prefix=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,12 +1,18 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Awaitable, Literal, Protocol
|
from typing import Any, Awaitable, Literal, Optional, Protocol, TypedDict
|
||||||
|
|
||||||
from skyvern.forge.sdk.models import Step
|
from skyvern.forge.sdk.models import Step
|
||||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLMParams(TypedDict):
|
||||||
|
api_key: str | None
|
||||||
|
api_version: str | None
|
||||||
|
api_base: str | None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class LLMConfig:
|
class LLMConfigBase:
|
||||||
model_name: str
|
model_name: str
|
||||||
required_env_vars: list[str]
|
required_env_vars: list[str]
|
||||||
supports_vision: bool
|
supports_vision: bool
|
||||||
@@ -22,6 +28,11 @@ class LLMConfig:
|
|||||||
return missing_env_vars
|
return missing_env_vars
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LLMConfig(LLMConfigBase):
|
||||||
|
litellm_params: Optional[LiteLLMParams] = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class LLMRouterModelConfig:
|
class LLMRouterModelConfig:
|
||||||
model_name: str
|
model_name: str
|
||||||
@@ -33,7 +44,7 @@ class LLMRouterModelConfig:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class LLMRouterConfig(LLMConfig):
|
class LLMRouterConfig(LLMConfigBase):
|
||||||
model_list: list[LLMRouterModelConfig]
|
model_list: list[LLMRouterModelConfig]
|
||||||
# All three redis parameters are required. Even if there isn't a password, it should be an empty string.
|
# All three redis parameters are required. Even if there isn't a password, it should be an empty string.
|
||||||
main_model_group: str
|
main_model_group: str
|
||||||
|
|||||||
Reference in New Issue
Block a user