Implement LLM router (#95)

Author: Kerem Yilmaz
Date: 2024-03-16 23:13:18 -07:00
Committed by: GitHub
Parent: 0e34bfa2bd
Commit: d1de19556e
16 changed files with 485 additions and 308 deletions

@@ -0,0 +1,115 @@
import json
from typing import Any

import litellm
import openai
import structlog

from skyvern.forge import app
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.exceptions import DuplicateCustomLLMProviderError, LLMProviderError
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager

LOG = structlog.get_logger()


class LLMAPIHandlerFactory:
    _custom_handlers: dict[str, LLMAPIHandler] = {}

    @staticmethod
    def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
        llm_config = LLMConfigRegistry.get_config(llm_key)

        async def llm_api_handler(
            prompt: str,
            step: Step | None = None,
            screenshots: list[bytes] | None = None,
            parameters: dict[str, Any] | None = None,
        ) -> dict[str, Any]:
            if parameters is None:
                parameters = LLMAPIHandlerFactory.get_api_parameters()

            # Record the prompt and any screenshots as artifacts for debugging and replay.
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_PROMPT,
                    data=prompt.encode("utf-8"),
                )
                for screenshot in screenshots or []:
                    await app.ARTIFACT_MANAGER.create_artifact(
                        step=step,
                        artifact_type=ArtifactType.SCREENSHOT_LLM,
                        data=screenshot,
                    )

            # TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
            if not llm_config.supports_vision:
                screenshots = None

            messages = await llm_messages_builder(prompt, screenshots)
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_REQUEST,
                    data=json.dumps(
                        {
                            "model": llm_config.model_name,
                            "messages": messages,
                            **parameters,
                        }
                    ).encode("utf-8"),
                )

            try:
                # TODO (kerem): add a timeout to this call
                # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
                # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
                response = await litellm.acompletion(
                    model=llm_config.model_name,
                    messages=messages,
                    **parameters,
                )
            except openai.OpenAIError as e:
                raise LLMProviderError(llm_key) from e
            except Exception as e:
                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
                raise LLMProviderError(llm_key) from e

            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_RESPONSE,
                    data=response.model_dump_json(indent=2).encode("utf-8"),
                )
                # Track the cost of the call against the step that triggered it.
                llm_cost = litellm.completion_cost(completion_response=response)
                await app.DATABASE.update_step(
                    task_id=step.task_id,
                    step_id=step.step_id,
                    organization_id=step.organization_id,
                    incremental_cost=llm_cost,
                )

            parsed_response = parse_api_response(response)
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
                )
            return parsed_response

        return llm_api_handler

    @staticmethod
    def get_api_parameters() -> dict[str, Any]:
        return {
            "max_tokens": SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS,
            "temperature": SettingsManager.get_settings().LLM_CONFIG_TEMPERATURE,
        }

    @classmethod
    def register_custom_handler(cls, llm_key: str, handler: LLMAPIHandler) -> None:
        if llm_key in cls._custom_handlers:
            raise DuplicateCustomLLMProviderError(llm_key)
        cls._custom_handlers[llm_key] = handler
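
A handler returned by the factory is just an awaitable callable, so calling it directly exercises the whole pipeline. A minimal usage sketch (the factory module path and the async entrypoint are illustrative; "OPENAI_GPT4V" is one of the keys registered in the config registry below, and the call assumes ENABLE_OPENAI and OPENAI_API_KEY are configured):

import asyncio

from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory  # hypothetical module path


async def main() -> None:
    handler = LLMAPIHandlerFactory.get_llm_api_handler("OPENAI_GPT4V")
    # Without a step, no artifacts are recorded and no cost is tracked.
    parsed = await handler('Respond with JSON: {"status": "ok"}')
    print(parsed)


asyncio.run(main())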


@@ -0,0 +1,70 @@
import structlog

from skyvern.forge.sdk.api.llm.exceptions import (
    DuplicateLLMConfigError,
    InvalidLLMConfigError,
    MissingLLMProviderEnvVarsError,
    NoProviderEnabledError,
)
from skyvern.forge.sdk.api.llm.models import LLMConfig
from skyvern.forge.sdk.settings_manager import SettingsManager

LOG = structlog.get_logger()


class LLMConfigRegistry:
    _configs: dict[str, LLMConfig] = {}

    @staticmethod
    def validate_config(llm_key: str, config: LLMConfig) -> None:
        missing_env_vars = config.get_missing_env_vars()
        if missing_env_vars:
            raise MissingLLMProviderEnvVarsError(llm_key, missing_env_vars)

    @classmethod
    def register_config(cls, llm_key: str, config: LLMConfig) -> None:
        if llm_key in cls._configs:
            raise DuplicateLLMConfigError(llm_key)
        cls.validate_config(llm_key, config)
        LOG.info("Registering LLM config", llm_key=llm_key)
        cls._configs[llm_key] = config

    @classmethod
    def get_config(cls, llm_key: str) -> LLMConfig:
        if llm_key not in cls._configs:
            raise InvalidLLMConfigError(llm_key)
        return cls._configs[llm_key]


# if none of the LLM providers are enabled, raise an error
if not any(
    [
        SettingsManager.get_settings().ENABLE_OPENAI,
        SettingsManager.get_settings().ENABLE_ANTHROPIC,
        SettingsManager.get_settings().ENABLE_AZURE,
    ]
):
    raise NoProviderEnabledError()

if SettingsManager.get_settings().ENABLE_OPENAI:
    LLMConfigRegistry.register_config("OPENAI_GPT4_TURBO", LLMConfig("gpt-4-turbo-preview", ["OPENAI_API_KEY"], False))
    LLMConfigRegistry.register_config("OPENAI_GPT4V", LLMConfig("gpt-4-vision-preview", ["OPENAI_API_KEY"], True))

if SettingsManager.get_settings().ENABLE_ANTHROPIC:
    LLMConfigRegistry.register_config(
        "ANTHROPIC_CLAUDE3", LLMConfig("anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], True)
    )

if SettingsManager.get_settings().ENABLE_AZURE:
    LLMConfigRegistry.register_config(
        "AZURE_OPENAI_GPT4V",
        LLMConfig(
            f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}",
            ["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
            True,
        ),
    )
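
Additional providers can be registered through the same code path. A short sketch, assuming the import paths above; the "OPENAI_GPT35_TURBO" key is made up for illustration:

from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.models import LLMConfig

# Fails fast: DuplicateLLMConfigError if the key is already registered,
# MissingLLMProviderEnvVarsError if OPENAI_API_KEY is not set.
LLMConfigRegistry.register_config(
    "OPENAI_GPT35_TURBO",
    LLMConfig("gpt-3.5-turbo", ["OPENAI_API_KEY"], False),
)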


@@ -0,0 +1,48 @@
from skyvern.exceptions import SkyvernException


class BaseLLMError(SkyvernException):
    pass


class MissingLLMProviderEnvVarsError(BaseLLMError):
    def __init__(self, llm_key: str, missing_env_vars: list[str]) -> None:
        super().__init__(f"Environment variables {','.join(missing_env_vars)} are required for LLMProvider {llm_key}")


class EmptyLLMResponseError(BaseLLMError):
    def __init__(self, response: str) -> None:
        super().__init__(f"LLM response content is empty: {response}")


class InvalidLLMResponseFormat(BaseLLMError):
    def __init__(self, response: str) -> None:
        super().__init__(f"LLM response content is not valid JSON: {response}")


class DuplicateCustomLLMProviderError(BaseLLMError):
    def __init__(self, llm_key: str) -> None:
        super().__init__(f"Custom LLMProvider {llm_key} is already registered")


class DuplicateLLMConfigError(BaseLLMError):
    def __init__(self, llm_key: str) -> None:
        super().__init__(f"LLM config with key {llm_key} is already registered")


class InvalidLLMConfigError(BaseLLMError):
    def __init__(self, llm_key: str) -> None:
        super().__init__(f"No LLM config registered with key {llm_key}")


class LLMProviderError(BaseLLMError):
    def __init__(self, llm_key: str) -> None:
        super().__init__(f"Error while using LLMProvider {llm_key}")


class NoProviderEnabledError(BaseLLMError):
    def __init__(self) -> None:
        super().__init__(
            "At least one LLM provider must be enabled. Run setup.sh and follow the LLM provider setup, or "
            "update the .env file (see .env.example for the required environment variables)."
        )
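
Since every error above derives from BaseLLMError (itself a SkyvernException), callers can handle the whole family with a single except clause. A hypothetical sketch of that pattern:

from skyvern.forge.sdk.api.llm.exceptions import BaseLLMError, LLMProviderError
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler


async def call_or_none(handler: LLMAPIHandler, prompt: str) -> dict | None:
    try:
        return await handler(prompt)
    except LLMProviderError:
        # Provider-side failure (e.g. a wrapped OpenAIError); safe to retry or fall back.
        return None
    except BaseLLMError:
        # Configuration or parsing errors point at a setup bug; let them surface.
        raise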


@@ -0,0 +1,32 @@
from dataclasses import dataclass
from typing import Any, Awaitable, Protocol

from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager


@dataclass(frozen=True)
class LLMConfig:
    model_name: str
    required_env_vars: list[str]
    supports_vision: bool

    def get_missing_env_vars(self) -> list[str]:
        missing_env_vars = []
        for env_var in self.required_env_vars:
            env_var_value = getattr(SettingsManager.get_settings(), env_var, None)
            if not env_var_value:
                missing_env_vars.append(env_var)
        return missing_env_vars


class LLMAPIHandler(Protocol):
    def __call__(
        self,
        prompt: str,
        step: Step | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
    ) -> Awaitable[dict[str, Any]]:
        ...
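
Because LLMAPIHandler is a Protocol, any async callable with a matching signature satisfies it; no subclassing is needed. An illustrative stub (the "ECHO" key, the canned response, and the factory module path are made up):

from typing import Any

from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory  # hypothetical module path
from skyvern.forge.sdk.models import Step


async def echo_handler(
    prompt: str,
    step: Step | None = None,
    screenshots: list[bytes] | None = None,
    parameters: dict[str, Any] | None = None,
) -> dict[str, Any]:
    # A network-free stand-in, handy for tests.
    return {"action": "noop", "prompt_length": len(prompt)}


LLMAPIHandlerFactory.register_custom_handler("ECHO", echo_handler)

Note that in this commit the factory only stores custom handlers; get_llm_api_handler does not yet route to them.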


@@ -0,0 +1,45 @@
import base64
from typing import Any

import commentjson
import litellm

from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidLLMResponseFormat


async def llm_messages_builder(
    prompt: str,
    screenshots: list[bytes] | None = None,
) -> list[dict[str, Any]]:
    # Build the multimodal content parts: the prompt text first, then one
    # image part per screenshot, inlined as a base64 data URL.
    content: list[dict[str, Any]] = [
        {
            "type": "text",
            "text": prompt,
        }
    ]
    if screenshots:
        for screenshot in screenshots:
            encoded_image = base64.b64encode(screenshot).decode("utf-8")
            content.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{encoded_image}",
                    },
                }
            )
    return [{"role": "user", "content": content}]


def parse_api_response(response: litellm.ModelResponse) -> dict[str, str]:
    try:
        content = response.choices[0].message.content
        # Check for an empty response before manipulating it; calling .replace()
        # on a missing content field would raise and mask the real error.
        if not content:
            raise EmptyLLMResponseError(str(response))
        # Strip markdown code fences so the remaining body parses as JSON.
        content = content.replace("```json", "")
        content = content.replace("```", "")
        return commentjson.loads(content)
    except EmptyLLMResponseError:
        raise
    except Exception as e:
        raise InvalidLLMResponseFormat(str(response)) from e
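
llm_messages_builder always returns a single user message whose content is a list of parts, the multimodal shape litellm forwards to vision models. A quick sketch of the output, assuming the utils module path above; the screenshot bytes are placeholders:

import asyncio

from skyvern.forge.sdk.api.llm.utils import llm_messages_builder  # hypothetical module path

messages = asyncio.run(llm_messages_builder("Describe the page", [b"fake-png-bytes"]))
# [{"role": "user", "content": [
#     {"type": "text", "text": "Describe the page"},
#     {"type": "image_url", "image_url": {"url": "data:image/png;base64,ZmFrZS1wbmctYnl0ZXM="}},
# ]}]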