add gpt4o mini support (#666)

This commit is contained in:
LawyZheng
2024-08-02 19:35:52 +08:00
committed by GitHub
parent a9f52c4dbb
commit 98e2f7f206
4 changed files with 49 additions and 5 deletions

View File

@@ -75,6 +75,7 @@ class Settings(BaseSettings):
ENABLE_OPENAI: bool = False
ENABLE_ANTHROPIC: bool = False
ENABLE_AZURE: bool = False
ENABLE_AZURE_GPT4O_MINI: bool = False
ENABLE_BEDROCK: bool = False
# OPENAI
OPENAI_API_KEY: str | None = None
@@ -86,6 +87,12 @@ class Settings(BaseSettings):
AZURE_API_BASE: str | None = None
AZURE_API_VERSION: str | None = None
# AZURE GPT-4o mini
AZURE_GPT4O_MINI_DEPLOYMENT: str | None = None
AZURE_GPT4O_MINI_API_KEY: str | None = None
AZURE_GPT4O_MINI_API_BASE: str | None = None
AZURE_GPT4O_MINI_API_VERSION: str | None = None
def is_cloud_environment(self) -> bool:
"""
:return: True if env is not local, else False

View File

@@ -15,7 +15,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
InvalidLLMConfigError,
LLMProviderError,
)
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMRouterConfig
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouterConfig
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
@@ -147,6 +147,8 @@ class LLMAPIHandlerFactory:
if LLMConfigRegistry.is_router_config(llm_key):
return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
assert isinstance(llm_config, LLMConfig)
async def llm_api_handler(
prompt: str,
step: Step | None = None,
@@ -158,6 +160,8 @@ class LLMAPIHandlerFactory:
parameters = LLMAPIHandlerFactory.get_api_parameters()
active_parameters.update(parameters)
if llm_config.litellm_params: # type: ignore
active_parameters.update(llm_config.litellm_params) # type: ignore
if step:
await app.ARTIFACT_MANAGER.create_artifact(

View File

@@ -6,7 +6,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
MissingLLMProviderEnvVarsError,
NoProviderEnabledError,
)
from skyvern.forge.sdk.api.llm.models import LLMConfig, LLMRouterConfig
from skyvern.forge.sdk.api.llm.models import LiteLLMParams, LLMConfig, LLMRouterConfig
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
@@ -49,6 +49,7 @@ if not any(
SettingsManager.get_settings().ENABLE_OPENAI,
SettingsManager.get_settings().ENABLE_ANTHROPIC,
SettingsManager.get_settings().ENABLE_AZURE,
SettingsManager.get_settings().ENABLE_AZURE_GPT4O_MINI,
SettingsManager.get_settings().ENABLE_BEDROCK,
]
):
@@ -189,3 +190,24 @@ if SettingsManager.get_settings().ENABLE_AZURE:
add_assistant_prefix=False,
),
)
if SettingsManager.get_settings().ENABLE_AZURE_GPT4O_MINI:
LLMConfigRegistry.register_config(
"AZURE_OPENAI_GPT4O_MINI",
LLMConfig(
f"azure/{SettingsManager.get_settings().AZURE_GPT4O_MINI_DEPLOYMENT}",
[
"AZURE_GPT4O_MINI_DEPLOYMENT",
"AZURE_GPT4O_MINI_API_KEY",
"AZURE_GPT4O_MINI_API_BASE",
"AZURE_GPT4O_MINI_API_VERSION",
],
litellm_params=LiteLLMParams(
api_base=SettingsManager.get_settings().AZURE_GPT4O_MINI_API_BASE,
api_key=SettingsManager.get_settings().AZURE_GPT4O_MINI_API_KEY,
api_version=SettingsManager.get_settings().AZURE_GPT4O_MINI_API_VERSION,
),
supports_vision=True,
add_assistant_prefix=False,
),
)

View File

@@ -1,12 +1,18 @@
from dataclasses import dataclass, field
from typing import Any, Awaitable, Literal, Protocol
from typing import Any, Awaitable, Literal, Optional, Protocol, TypedDict
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
class LiteLLMParams(TypedDict):
api_key: str | None
api_version: str | None
api_base: str | None
@dataclass(frozen=True)
class LLMConfig:
class LLMConfigBase:
model_name: str
required_env_vars: list[str]
supports_vision: bool
@@ -22,6 +28,11 @@ class LLMConfig:
return missing_env_vars
@dataclass(frozen=True)
class LLMConfig(LLMConfigBase):
litellm_params: Optional[LiteLLMParams] = field(default=None)
@dataclass(frozen=True)
class LLMRouterModelConfig:
model_name: str
@@ -33,7 +44,7 @@ class LLMRouterModelConfig:
@dataclass(frozen=True)
class LLMRouterConfig(LLMConfig):
class LLMRouterConfig(LLMConfigBase):
model_list: list[LLMRouterModelConfig]
# All three redis parameters are required. Even if there isn't a password, it should be an empty string.
main_model_group: str