Implement LLM router (#95)

This commit is contained in:
Kerem Yilmaz
2024-03-16 23:13:18 -07:00
committed by GitHub
parent 0e34bfa2bd
commit d1de19556e
16 changed files with 485 additions and 308 deletions

View File

@@ -1,27 +0,0 @@
from typing import Callable
from pydantic import BaseModel
openai_model_to_price_lambdas = {
"gpt-4-vision-preview": (0.01, 0.03),
"gpt-4-1106-preview": (0.01, 0.03),
"gpt-4-0125-preview": (0.01, 0.03),
"gpt-3.5-turbo": (0.001, 0.002),
"gpt-3.5-turbo-1106": (0.001, 0.002),
"gpt-3.5-turbo-0125": (0.0005, 0.0015),
}
class ChatCompletionPrice(BaseModel):
input_token_count: int
output_token_count: int
openai_model_to_price_lambda: Callable[[int, int], float]
def __init__(self, input_token_count: int, output_token_count: int, model_name: str):
input_token_price, output_token_price = openai_model_to_price_lambdas[model_name]
super().__init__(
input_token_count=input_token_count,
output_token_count=output_token_count,
openai_model_to_price_lambda=lambda input_token, output_token: input_token_price * input_token / 1000
+ output_token_price * output_token / 1000,
)

View File

View File

@@ -0,0 +1,115 @@
import json
from typing import Any
import litellm
import openai
import structlog
from skyvern.forge import app
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.exceptions import DuplicateCustomLLMProviderError, LLMProviderError
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class LLMAPIHandlerFactory:
_custom_handlers: dict[str, LLMAPIHandler] = {}
@staticmethod
def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
llm_config = LLMConfigRegistry.get_config(llm_key)
async def llm_api_handler(
prompt: str,
step: Step | None = None,
screenshots: list[bytes] | None = None,
parameters: dict[str, Any] | None = None,
) -> dict[str, Any]:
if parameters is None:
parameters = LLMAPIHandlerFactory.get_api_parameters()
if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_PROMPT,
data=prompt.encode("utf-8"),
)
for screenshot in screenshots or []:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.SCREENSHOT_LLM,
data=screenshot,
)
# TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
if not llm_config.supports_vision:
screenshots = None
messages = await llm_messages_builder(prompt, screenshots)
if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_REQUEST,
data=json.dumps(
{
"model": llm_config.model_name,
"messages": messages,
**parameters,
}
).encode("utf-8"),
)
try:
# TODO (kerem): add a timeout to this call
# TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
# TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
response = await litellm.acompletion(
model=llm_config.model_name,
messages=messages,
**parameters,
)
except openai.OpenAIError as e:
raise LLMProviderError(llm_key) from e
except Exception as e:
LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
raise LLMProviderError(llm_key) from e
if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE,
data=response.model_dump_json(indent=2).encode("utf-8"),
)
llm_cost = litellm.completion_cost(completion_response=response)
await app.DATABASE.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
incremental_cost=llm_cost,
)
parsed_response = parse_api_response(response)
if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
)
return parsed_response
return llm_api_handler
@staticmethod
def get_api_parameters() -> dict[str, Any]:
return {
"max_tokens": SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS,
"temperature": SettingsManager.get_settings().LLM_CONFIG_TEMPERATURE,
}
@classmethod
def register_custom_handler(cls, llm_key: str, handler: LLMAPIHandler) -> None:
if llm_key in cls._custom_handlers:
raise DuplicateCustomLLMProviderError(llm_key)
cls._custom_handlers[llm_key] = handler

View File

@@ -0,0 +1,70 @@
import structlog
from skyvern.forge.sdk.api.llm.exceptions import (
DuplicateLLMConfigError,
InvalidLLMConfigError,
MissingLLMProviderEnvVarsError,
NoProviderEnabledError,
)
from skyvern.forge.sdk.api.llm.models import LLMConfig
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class LLMConfigRegistry:
_configs: dict[str, LLMConfig] = {}
@staticmethod
def validate_config(llm_key: str, config: LLMConfig) -> None:
missing_env_vars = config.get_missing_env_vars()
if missing_env_vars:
raise MissingLLMProviderEnvVarsError(llm_key, missing_env_vars)
@classmethod
def register_config(cls, llm_key: str, config: LLMConfig) -> None:
if llm_key in cls._configs:
raise DuplicateLLMConfigError(llm_key)
cls.validate_config(llm_key, config)
LOG.info("Registering LLM config", llm_key=llm_key)
cls._configs[llm_key] = config
@classmethod
def get_config(cls, llm_key: str) -> LLMConfig:
if llm_key not in cls._configs:
raise InvalidLLMConfigError(llm_key)
return cls._configs[llm_key]
# if none of the LLM providers are enabled, raise an error
if not any(
[
SettingsManager.get_settings().ENABLE_OPENAI,
SettingsManager.get_settings().ENABLE_ANTHROPIC,
SettingsManager.get_settings().ENABLE_AZURE,
]
):
raise NoProviderEnabledError()
if SettingsManager.get_settings().ENABLE_OPENAI:
LLMConfigRegistry.register_config("OPENAI_GPT4_TURBO", LLMConfig("gpt-4-turbo-preview", ["OPENAI_API_KEY"], False))
LLMConfigRegistry.register_config("OPENAI_GPT4V", LLMConfig("gpt-4-vision-preview", ["OPENAI_API_KEY"], True))
if SettingsManager.get_settings().ENABLE_ANTHROPIC:
LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3", LLMConfig("anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], True)
)
if SettingsManager.get_settings().ENABLE_AZURE:
LLMConfigRegistry.register_config(
"AZURE_OPENAI_GPT4V",
LLMConfig(
f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}",
["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
True,
),
)

View File

@@ -0,0 +1,48 @@
from skyvern.exceptions import SkyvernException
class BaseLLMError(SkyvernException):
pass
class MissingLLMProviderEnvVarsError(BaseLLMError):
def __init__(self, llm_key: str, missing_env_vars: list[str]) -> None:
super().__init__(f"Environment variables {','.join(missing_env_vars)} are required for LLMProvider {llm_key}")
class EmptyLLMResponseError(BaseLLMError):
def __init__(self, response: str) -> None:
super().__init__(f"LLM response content is empty: {response}")
class InvalidLLMResponseFormat(BaseLLMError):
def __init__(self, response: str) -> None:
super().__init__(f"LLM response content is not a valid JSON: {response}")
class DuplicateCustomLLMProviderError(BaseLLMError):
def __init__(self, llm_key: str) -> None:
super().__init__(f"Custom LLMProvider {llm_key} is already registered")
class DuplicateLLMConfigError(BaseLLMError):
def __init__(self, llm_key: str) -> None:
super().__init__(f"LLM config with key {llm_key} is already registered")
class InvalidLLMConfigError(BaseLLMError):
def __init__(self, llm_key: str) -> None:
super().__init__(f"LLM config with key {llm_key} is not a valid LLMConfig")
class LLMProviderError(BaseLLMError):
def __init__(self, llm_key: str) -> None:
super().__init__(f"Error while using LLMProvider {llm_key}")
class NoProviderEnabledError(BaseLLMError):
def __init__(self) -> None:
super().__init__(
"At least one LLM provider must be enabled. Run setup.sh and follow through the LLM provider setup, or "
"update the .env file (check out .env.example to see the required environment variables)."
)

View File

@@ -0,0 +1,32 @@
from dataclasses import dataclass
from typing import Any, Awaitable, Protocol
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
@dataclass(frozen=True)
class LLMConfig:
model_name: str
required_env_vars: list[str]
supports_vision: bool
def get_missing_env_vars(self) -> list[str]:
missing_env_vars = []
for env_var in self.required_env_vars:
env_var_value = getattr(SettingsManager.get_settings(), env_var, None)
if not env_var_value:
missing_env_vars.append(env_var)
return missing_env_vars
class LLMAPIHandler(Protocol):
def __call__(
self,
prompt: str,
step: Step | None = None,
screenshots: list[bytes] | None = None,
parameters: dict[str, Any] | None = None,
) -> Awaitable[dict[str, Any]]:
...

View File

@@ -0,0 +1,45 @@
import base64
from typing import Any
import commentjson
import litellm
from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidLLMResponseFormat
async def llm_messages_builder(
prompt: str,
screenshots: list[bytes] | None = None,
) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = [
{
"type": "text",
"text": prompt,
}
]
if screenshots:
for screenshot in screenshots:
encoded_image = base64.b64encode(screenshot).decode("utf-8")
messages.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encoded_image}",
},
}
)
return [{"role": "user", "content": messages}]
def parse_api_response(response: litellm.ModelResponse) -> dict[str, str]:
try:
content = response.choices[0].message.content
content = content.replace("```json", "")
content = content.replace("```", "")
if not content:
raise EmptyLLMResponseError(str(response))
return commentjson.loads(content)
except Exception as e:
raise InvalidLLMResponseFormat(str(response)) from e

View File

@@ -1,228 +0,0 @@
import base64
import json
import random
from datetime import datetime, timedelta
from typing import Any
import commentjson
import openai
import structlog
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from skyvern.exceptions import InvalidOpenAIResponseFormat, NoAvailableOpenAIClients, OpenAIRequestTooBigError
from skyvern.forge import app
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class OpenAIKeyClientWrapper:
client: AsyncOpenAI
key: str
remaining_requests: int | None
def __init__(self, key: str, remaining_requests: int | None) -> None:
self.key = key
self.remaining_requests = remaining_requests
self.updated_at = datetime.utcnow()
self.client = AsyncOpenAI(api_key=self.key)
def update_remaining_requests(self, remaining_requests: int | None) -> None:
self.remaining_requests = remaining_requests
self.updated_at = datetime.utcnow()
def is_available(self) -> bool:
# If remaining_requests is None, then it's the first time we're trying this key
# so we can assume it's available, otherwise we check if it's greater than 0
if self.remaining_requests is None:
return True
if self.remaining_requests > 0:
return True
# If we haven't checked this in over 1 minutes, check it again
# Most of our failures are because of Tokens-per-minute (TPM) limits
if self.updated_at < (datetime.utcnow() - timedelta(minutes=1)):
return True
return False
class OpenAIClientManager:
# TODO Support other models for requests without screenshots, track rate limits for each model and key as well if any
clients: list[OpenAIKeyClientWrapper]
def __init__(self, api_keys: list[str] = SettingsManager.get_settings().OPENAI_API_KEYS) -> None:
self.clients = [OpenAIKeyClientWrapper(key, None) for key in api_keys]
def get_available_client(self) -> OpenAIKeyClientWrapper | None:
available_clients = [client for client in self.clients if client.is_available()]
if not available_clients:
return None
# Randomly select an available client to distribute requests across our accounts
return random.choice(available_clients)
async def content_builder(
self,
step: Step,
screenshots: list[bytes] | None = None,
prompt: str | None = None,
) -> list[dict[str, Any]]:
content: list[dict[str, Any]] = []
if prompt is not None:
content.append(
{
"type": "text",
"text": prompt,
}
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_PROMPT,
data=prompt.encode("utf-8"),
)
if screenshots:
for screenshot in screenshots:
encoded_image = base64.b64encode(screenshot).decode("utf-8")
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encoded_image}",
},
}
)
# create artifact for each image
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.SCREENSHOT_LLM,
data=screenshot,
)
return content
async def chat_completion(
self,
step: Step,
model: str = "gpt-4-vision-preview",
max_tokens: int = 4096,
temperature: int = 0,
screenshots: list[bytes] | None = None,
prompt: str | None = None,
) -> dict[str, Any]:
LOG.info(
f"Sending LLM request",
task_id=step.task_id,
step_id=step.step_id,
num_screenshots=len(screenshots) if screenshots else 0,
)
messages = [
{
"role": "user",
"content": await self.content_builder(
step=step,
screenshots=screenshots,
prompt=prompt,
),
}
]
chat_completion_kwargs = {
"model": model,
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
}
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_REQUEST,
data=json.dumps(chat_completion_kwargs).encode("utf-8"),
)
available_client = self.get_available_client()
if available_client is None:
raise NoAvailableOpenAIClients()
try:
response = await available_client.client.chat.completions.with_raw_response.create(**chat_completion_kwargs)
except openai.RateLimitError as e:
# If we get a RateLimitError, we can assume the key is not available anymore
if e.code == 429:
raise OpenAIRequestTooBigError(e.message)
LOG.warning(
"OpenAI rate limit exceeded, marking key as unavailable.", error_code=e.code, error_message=e.message
)
available_client.update_remaining_requests(remaining_requests=0)
available_client = self.get_available_client()
if available_client is None:
raise NoAvailableOpenAIClients()
return await self.chat_completion(
step=step,
model=model,
max_tokens=max_tokens,
temperature=temperature,
screenshots=screenshots,
prompt=prompt,
)
except openai.OpenAIError as e:
LOG.error("OpenAI error", exc_info=True)
raise e
except Exception as e:
LOG.error("Unknown error for chat completion", error_message=str(e), error_type=type(e))
raise e
# TODO: https://platform.openai.com/docs/guides/rate-limits/rate-limits-in-headers
# use other headers, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-tokens
# x-ratelimit-reset-requests, x-ratelimit-reset-tokens to write a more accurate algorithm for managing api keys
# If we get a response, we can assume the key is available and update the remaining requests
ratelimit_remaining_requests = response.headers.get("x-ratelimit-remaining-requests")
if not ratelimit_remaining_requests:
LOG.warning("Invalid x-ratelimit-remaining-requests from OpenAI", response.headers)
available_client.update_remaining_requests(remaining_requests=int(ratelimit_remaining_requests))
chat_completion = response.parse()
if chat_completion.usage is not None:
# TODO (Suchintan): Is this bad design?
step = await app.DATABASE.update_step(
step_id=step.step_id,
task_id=step.task_id,
organization_id=step.organization_id,
chat_completion_price=ChatCompletionPrice(
input_token_count=chat_completion.usage.prompt_tokens,
output_token_count=chat_completion.usage.completion_tokens,
model_name=model,
),
)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE,
data=chat_completion.model_dump_json(indent=2).encode("utf-8"),
)
parsed_response = self.parse_response(chat_completion)
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
)
return parsed_response
def parse_response(self, response: ChatCompletion) -> dict[str, str]:
try:
content = response.choices[0].message.content
content = content.replace("```json", "")
content = content.replace("```", "")
if not content:
raise Exception("openai response content is empty")
return commentjson.loads(content)
except Exception as e:
raise InvalidOpenAIResponseFormat(str(response)) from e