Implement LLM router (#95)
This commit is contained in:
0
skyvern/forge/sdk/api/llm/__init__.py
Normal file
0
skyvern/forge/sdk/api/llm/__init__.py
Normal file
115
skyvern/forge/sdk/api/llm/api_handler_factory.py
Normal file
115
skyvern/forge/sdk/api/llm/api_handler_factory.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
import openai
|
||||
import structlog
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
||||
from skyvern.forge.sdk.api.llm.exceptions import DuplicateCustomLLMProviderError, LLMProviderError
|
||||
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
|
||||
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LLMAPIHandlerFactory:
    """Builds async LLMAPIHandler callables bound to a registered LLM config.

    Beyond routing the completion through litellm, the returned handler
    persists prompt/request/response artifacts for the given step and
    attributes the completion cost to that step.
    """

    # Custom handlers registered at runtime via register_custom_handler, keyed by llm_key.
    _custom_handlers: dict[str, LLMAPIHandler] = {}

    @staticmethod
    def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
        """Return an async handler for the LLM identified by llm_key.

        Raises an error from LLMConfigRegistry.get_config when llm_key is not
        a registered configuration.
        """
        llm_config = LLMConfigRegistry.get_config(llm_key)

        async def llm_api_handler(
            prompt: str,
            step: Step | None = None,
            screenshots: list[bytes] | None = None,
            parameters: dict[str, Any] | None = None,
        ) -> dict[str, Any]:
            # Fall back to the globally configured max_tokens/temperature.
            if parameters is None:
                parameters = LLMAPIHandlerFactory.get_api_parameters()

            if step:
                # Persist the raw prompt and every screenshot as step artifacts.
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_PROMPT,
                    data=prompt.encode("utf-8"),
                )
                for screenshot in screenshots or []:
                    await app.ARTIFACT_MANAGER.create_artifact(
                        step=step,
                        artifact_type=ArtifactType.SCREENSHOT_LLM,
                        data=screenshot,
                    )

            # TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
            if not llm_config.supports_vision:
                screenshots = None

            messages = await llm_messages_builder(prompt, screenshots)
            if step:
                # Record the exact request payload sent to the provider.
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_REQUEST,
                    data=json.dumps(
                        {
                            "model": llm_config.model_name,
                            "messages": messages,
                            **parameters,
                        }
                    ).encode("utf-8"),
                )
            try:
                # TODO (kerem): add a timeout to this call
                # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
                # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
                response = await litellm.acompletion(
                    model=llm_config.model_name,
                    messages=messages,
                    **parameters,
                )
            except openai.OpenAIError as e:
                raise LLMProviderError(llm_key) from e
            except Exception as e:
                # Any other provider/transport failure is normalized to LLMProviderError.
                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
                raise LLMProviderError(llm_key) from e
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_RESPONSE,
                    data=response.model_dump_json(indent=2).encode("utf-8"),
                )
                # Attribute the cost of this completion to the step (uses step.* fields,
                # so this is only done when a step is provided).
                llm_cost = litellm.completion_cost(completion_response=response)
                await app.DATABASE.update_step(
                    task_id=step.task_id,
                    step_id=step.step_id,
                    organization_id=step.organization_id,
                    incremental_cost=llm_cost,
                )
            parsed_response = parse_api_response(response)
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
                )
            return parsed_response

        return llm_api_handler

    @staticmethod
    def get_api_parameters() -> dict[str, Any]:
        """Default completion parameters sourced from application settings."""
        return {
            "max_tokens": SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS,
            "temperature": SettingsManager.get_settings().LLM_CONFIG_TEMPERATURE,
        }

    @classmethod
    def register_custom_handler(cls, llm_key: str, handler: LLMAPIHandler) -> None:
        """Register a caller-provided handler under llm_key.

        Raises DuplicateCustomLLMProviderError if llm_key is already taken.
        """
        if llm_key in cls._custom_handlers:
            raise DuplicateCustomLLMProviderError(llm_key)
        cls._custom_handlers[llm_key] = handler
|
||||
70
skyvern/forge/sdk/api/llm/config_registry.py
Normal file
70
skyvern/forge/sdk/api/llm/config_registry.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.api.llm.exceptions import (
|
||||
DuplicateLLMConfigError,
|
||||
InvalidLLMConfigError,
|
||||
MissingLLMProviderEnvVarsError,
|
||||
NoProviderEnabledError,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.models import LLMConfig
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LLMConfigRegistry:
    """Process-wide registry of known LLM configurations, keyed by llm_key."""

    _configs: dict[str, LLMConfig] = {}

    @staticmethod
    def validate_config(llm_key: str, config: LLMConfig) -> None:
        """Ensure every env var the config requires is set.

        Raises MissingLLMProviderEnvVarsError listing the unset variables.
        """
        missing = config.get_missing_env_vars()
        if missing:
            raise MissingLLMProviderEnvVarsError(llm_key, missing)

    @classmethod
    def register_config(cls, llm_key: str, config: LLMConfig) -> None:
        """Validate and store config under llm_key.

        Raises DuplicateLLMConfigError when llm_key is already registered.
        """
        if llm_key in cls._configs:
            raise DuplicateLLMConfigError(llm_key)

        cls.validate_config(llm_key, config)

        LOG.info("Registering LLM config", llm_key=llm_key)
        cls._configs[llm_key] = config

    @classmethod
    def get_config(cls, llm_key: str) -> LLMConfig:
        """Look up a previously registered config.

        Raises InvalidLLMConfigError for unknown keys.
        """
        registered = cls._configs.get(llm_key)
        if registered is None:
            raise InvalidLLMConfigError(llm_key)

        return registered
|
||||
|
||||
|
||||
# If none of the LLM providers are enabled, fail fast at import time so a
# misconfigured deployment surfaces immediately rather than at request time.
# Settings are read once here instead of calling get_settings() per check.
_settings = SettingsManager.get_settings()
if not any(
    [
        _settings.ENABLE_OPENAI,
        _settings.ENABLE_ANTHROPIC,
        _settings.ENABLE_AZURE,
    ]
):
    raise NoProviderEnabledError()


if _settings.ENABLE_OPENAI:
    LLMConfigRegistry.register_config("OPENAI_GPT4_TURBO", LLMConfig("gpt-4-turbo-preview", ["OPENAI_API_KEY"], False))
    LLMConfigRegistry.register_config("OPENAI_GPT4V", LLMConfig("gpt-4-vision-preview", ["OPENAI_API_KEY"], True))

if _settings.ENABLE_ANTHROPIC:
    LLMConfigRegistry.register_config(
        "ANTHROPIC_CLAUDE3", LLMConfig("anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], True)
    )

if _settings.ENABLE_AZURE:
    LLMConfigRegistry.register_config(
        "AZURE_OPENAI_GPT4V",
        LLMConfig(
            f"azure/{_settings.AZURE_DEPLOYMENT}",
            ["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
            True,
        ),
    )
|
||||
48
skyvern/forge/sdk/api/llm/exceptions.py
Normal file
48
skyvern/forge/sdk/api/llm/exceptions.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from skyvern.exceptions import SkyvernException
|
||||
|
||||
|
||||
class BaseLLMError(SkyvernException):
    """Root of the LLM error hierarchy; catch this for any LLM failure."""
|
||||
|
||||
|
||||
class MissingLLMProviderEnvVarsError(BaseLLMError):
    """Raised when a provider config requires env vars that are unset."""

    def __init__(self, llm_key: str, missing_env_vars: list[str]) -> None:
        joined = ",".join(missing_env_vars)
        super().__init__(f"Environment variables {joined} are required for LLMProvider {llm_key}")
|
||||
|
||||
|
||||
class EmptyLLMResponseError(BaseLLMError):
    """Raised when the LLM returned no usable content."""

    def __init__(self, response: str) -> None:
        message = f"LLM response content is empty: {response}"
        super().__init__(message)
|
||||
|
||||
|
||||
class InvalidLLMResponseFormat(BaseLLMError):
    """Raised when the LLM's content could not be parsed as JSON."""

    def __init__(self, response: str) -> None:
        message = f"LLM response content is not a valid JSON: {response}"
        super().__init__(message)
|
||||
|
||||
|
||||
class DuplicateCustomLLMProviderError(BaseLLMError):
    """Raised when a custom handler is registered twice under the same key."""

    def __init__(self, llm_key: str) -> None:
        message = f"Custom LLMProvider {llm_key} is already registered"
        super().__init__(message)
|
||||
|
||||
|
||||
class DuplicateLLMConfigError(BaseLLMError):
    """Raised when an LLM config key is registered more than once."""

    def __init__(self, llm_key: str) -> None:
        message = f"LLM config with key {llm_key} is already registered"
        super().__init__(message)
|
||||
|
||||
|
||||
class InvalidLLMConfigError(BaseLLMError):
    """Raised when an llm_key does not map to a registered LLMConfig."""

    def __init__(self, llm_key: str) -> None:
        message = f"LLM config with key {llm_key} is not a valid LLMConfig"
        super().__init__(message)
|
||||
|
||||
|
||||
class LLMProviderError(BaseLLMError):
    """Raised when the underlying provider call fails for any reason."""

    def __init__(self, llm_key: str) -> None:
        message = f"Error while using LLMProvider {llm_key}"
        super().__init__(message)
|
||||
|
||||
|
||||
class NoProviderEnabledError(BaseLLMError):
    """Raised at import time when no LLM provider is enabled in settings."""

    def __init__(self) -> None:
        message = (
            "At least one LLM provider must be enabled. Run setup.sh and follow through the LLM provider setup, or "
            "update the .env file (check out .env.example to see the required environment variables)."
        )
        super().__init__(message)
|
||||
32
skyvern/forge/sdk/api/llm/models.py
Normal file
32
skyvern/forge/sdk/api/llm/models.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Protocol
|
||||
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMConfig:
    """Static description of a single LLM configuration.

    Attributes:
        model_name: litellm model identifier (e.g. "gpt-4-vision-preview").
        required_env_vars: settings attributes that must be set for this model.
        supports_vision: whether the model accepts image inputs.
    """

    model_name: str
    required_env_vars: list[str]
    supports_vision: bool

    def get_missing_env_vars(self) -> list[str]:
        """Return the required env var names that are unset or empty in settings."""
        return [
            env_var
            for env_var in self.required_env_vars
            if not getattr(SettingsManager.get_settings(), env_var, None)
        ]
|
||||
|
||||
|
||||
class LLMAPIHandler(Protocol):
    """Structural type for the async handlers built by the LLM handler factory.

    A handler takes a prompt (plus optional step context, screenshots, and
    completion parameters) and resolves to the parsed JSON response as a dict.
    """

    def __call__(
        self,
        prompt: str,
        step: Step | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
    ) -> Awaitable[dict[str, Any]]:
        ...
|
||||
45
skyvern/forge/sdk/api/llm/utils.py
Normal file
45
skyvern/forge/sdk/api/llm/utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import base64
|
||||
from typing import Any
|
||||
|
||||
import commentjson
|
||||
import litellm
|
||||
|
||||
from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidLLMResponseFormat
|
||||
|
||||
|
||||
async def llm_messages_builder(
|
||||
prompt: str,
|
||||
screenshots: list[bytes] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt,
|
||||
}
|
||||
]
|
||||
|
||||
if screenshots:
|
||||
for screenshot in screenshots:
|
||||
encoded_image = base64.b64encode(screenshot).decode("utf-8")
|
||||
messages.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{encoded_image}",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return [{"role": "user", "content": messages}]
|
||||
|
||||
|
||||
def parse_api_response(response: litellm.ModelResponse) -> dict[str, str]:
    """Parse the completion text of a litellm response into a dict.

    Markdown code fences (```json ... ```) are stripped before parsing with
    commentjson (tolerates comments/trailing commas).

    Raises:
        EmptyLLMResponseError: content is None or empty after stripping fences.
        InvalidLLMResponseFormat: content exists but is not valid JSON, or the
            response object has an unexpected shape.
    """
    try:
        content = response.choices[0].message.content
        # Fix: guard before .replace — the original called .replace on a
        # possibly-None content, turning "empty response" into AttributeError.
        if content:
            content = content.replace("```json", "")
            content = content.replace("```", "")
        if not content:
            raise EmptyLLMResponseError(str(response))
        return commentjson.loads(content)
    except EmptyLLMResponseError:
        # Fix: re-raise explicitly — the broad handler below used to swallow
        # this and mask it as InvalidLLMResponseFormat.
        raise
    except Exception as e:
        raise InvalidLLMResponseFormat(str(response)) from e
|
||||
Reference in New Issue
Block a user