add gpt4o mini support (#666)

This commit is contained in:
LawyZheng
2024-08-02 19:35:52 +08:00
committed by GitHub
parent a9f52c4dbb
commit 98e2f7f206
4 changed files with 49 additions and 5 deletions

View File

@@ -75,6 +75,7 @@ class Settings(BaseSettings):
ENABLE_OPENAI: bool = False
ENABLE_ANTHROPIC: bool = False
ENABLE_AZURE: bool = False
ENABLE_AZURE_GPT4O_MINI: bool = False
ENABLE_BEDROCK: bool = False
# OPENAI
OPENAI_API_KEY: str | None = None
@@ -86,6 +87,12 @@ class Settings(BaseSettings):
AZURE_API_BASE: str | None = None
AZURE_API_VERSION: str | None = None
# AZURE GPT-4o mini
AZURE_GPT4O_MINI_DEPLOYMENT: str | None = None
AZURE_GPT4O_MINI_API_KEY: str | None = None
AZURE_GPT4O_MINI_API_BASE: str | None = None
AZURE_GPT4O_MINI_API_VERSION: str | None = None
def is_cloud_environment(self) -> bool:
"""
:return: True if env is not local, else False

View File

@@ -15,7 +15,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
InvalidLLMConfigError,
LLMProviderError,
)
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMRouterConfig
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouterConfig
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
@@ -147,6 +147,8 @@ class LLMAPIHandlerFactory:
if LLMConfigRegistry.is_router_config(llm_key):
return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
assert isinstance(llm_config, LLMConfig)
async def llm_api_handler(
prompt: str,
step: Step | None = None,
@@ -158,6 +160,8 @@ class LLMAPIHandlerFactory:
parameters = LLMAPIHandlerFactory.get_api_parameters()
active_parameters.update(parameters)
if llm_config.litellm_params: # type: ignore
active_parameters.update(llm_config.litellm_params) # type: ignore
if step:
await app.ARTIFACT_MANAGER.create_artifact(

View File

@@ -6,7 +6,7 @@ from skyvern.forge.sdk.api.llm.exceptions import (
MissingLLMProviderEnvVarsError,
NoProviderEnabledError,
)
from skyvern.forge.sdk.api.llm.models import LLMConfig, LLMRouterConfig
from skyvern.forge.sdk.api.llm.models import LiteLLMParams, LLMConfig, LLMRouterConfig
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
@@ -49,6 +49,7 @@ if not any(
SettingsManager.get_settings().ENABLE_OPENAI,
SettingsManager.get_settings().ENABLE_ANTHROPIC,
SettingsManager.get_settings().ENABLE_AZURE,
SettingsManager.get_settings().ENABLE_AZURE_GPT4O_MINI,
SettingsManager.get_settings().ENABLE_BEDROCK,
]
):
@@ -189,3 +190,24 @@ if SettingsManager.get_settings().ENABLE_AZURE:
add_assistant_prefix=False,
),
)
# Optionally register a dedicated Azure OpenAI GPT-4o mini deployment under
# its own key, so it can target a different Azure resource (base URL, key,
# API version) than the primary AZURE config registered above.
if SettingsManager.get_settings().ENABLE_AZURE_GPT4O_MINI:
    _azure_gpt4o_mini_config = LLMConfig(
        # litellm resolves "azure/<deployment-name>" to the Azure OpenAI provider.
        f"azure/{SettingsManager.get_settings().AZURE_GPT4O_MINI_DEPLOYMENT}",
        # Env vars that must be present for this config to be usable.
        [
            "AZURE_GPT4O_MINI_DEPLOYMENT",
            "AZURE_GPT4O_MINI_API_KEY",
            "AZURE_GPT4O_MINI_API_BASE",
            "AZURE_GPT4O_MINI_API_VERSION",
        ],
        # Deployment-specific endpoint credentials, merged into the call
        # parameters by the API handler when set.
        litellm_params=LiteLLMParams(
            api_base=SettingsManager.get_settings().AZURE_GPT4O_MINI_API_BASE,
            api_key=SettingsManager.get_settings().AZURE_GPT4O_MINI_API_KEY,
            api_version=SettingsManager.get_settings().AZURE_GPT4O_MINI_API_VERSION,
        ),
        supports_vision=True,
        add_assistant_prefix=False,
    )
    LLMConfigRegistry.register_config("AZURE_OPENAI_GPT4O_MINI", _azure_gpt4o_mini_config)

View File

@@ -1,12 +1,18 @@
from dataclasses import dataclass, field
from typing import Any, Awaitable, Literal, Protocol
from typing import Any, Awaitable, Literal, Optional, Protocol, TypedDict
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
class LiteLLMParams(TypedDict):
    """Endpoint-specific parameters that an LLM config may carry.

    When attached to a config, these entries are merged into the parameters
    of the underlying litellm API call. Every value may be None, in which
    case no override is supplied for that key.
    """

    api_base: Optional[str]
    api_key: Optional[str]
    api_version: Optional[str]
@dataclass(frozen=True)
class LLMConfig:
class LLMConfigBase:
model_name: str
required_env_vars: list[str]
supports_vision: bool
@@ -22,6 +28,11 @@ class LLMConfig:
return missing_env_vars
@dataclass(frozen=True)
class LLMConfig(LLMConfigBase):
    """Concrete configuration for a single (non-router) LLM.

    Extends LLMConfigBase with optional litellm endpoint parameters
    (api_base / api_key / api_version) for deployment-specific credentials;
    the API handler merges them into the call parameters when set.
    """

    # Plain default instead of the redundant `field(default=None)`:
    # `field` adds nothing here since there is no factory or metadata,
    # and None is an immutable, safe class-level default.
    litellm_params: Optional[LiteLLMParams] = None
@dataclass(frozen=True)
class LLMRouterModelConfig:
model_name: str
@@ -33,7 +44,7 @@ class LLMRouterModelConfig:
@dataclass(frozen=True)
class LLMRouterConfig(LLMConfig):
class LLMRouterConfig(LLMConfigBase):
model_list: list[LLMRouterModelConfig]
# All three redis parameters are required. Even if there isn't a password, it should be an empty string.
main_model_group: str