Implement LLM router (#95)
This commit is contained in:
55
.env.example
55
.env.example
@@ -1,25 +1,54 @@
|
||||
# Environment that the agent will run in.
|
||||
ENV=local
|
||||
# Your OpenAI API Keys. Separate multiple keys with commas. Keys will be used in order until the rate limit is reached for all keys
|
||||
OPENAI_API_KEYS=["abc","def","ghi"]
|
||||
|
||||
# can be either "chromium-headless" or "chromium-headful".
|
||||
# LLM Provider Configurations:
|
||||
# ENABLE_OPENAI: Set to true to enable OpenAI as a language model provider.
|
||||
ENABLE_OPENAI=false
|
||||
# OPENAI_API_KEY: Your OpenAI API key for accessing models like GPT-4.
|
||||
OPENAI_API_KEY=""
|
||||
|
||||
# ENABLE_ANTHROPIC: Set to true to enable Anthropic as a language model provider.
|
||||
ENABLE_ANTHROPIC=false
|
||||
# ANTHROPIC_API_KEY: Your Anthropic API key for accessing models like Claude-3.
|
||||
ANTHROPIC_API_KEY=""
|
||||
|
||||
# ENABLE_AZURE: Set to true to enable Azure as a language model provider.
|
||||
ENABLE_AZURE=false
|
||||
# AZURE_DEPLOYMENT: Your Azure deployment name for accessing specific models.
|
||||
AZURE_DEPLOYMENT=""
|
||||
# AZURE_API_KEY: Your API key for accessing Azure's language models.
|
||||
AZURE_API_KEY=""
|
||||
# AZURE_API_BASE: The base URL for Azure's API.
|
||||
AZURE_API_BASE=""
|
||||
# AZURE_API_VERSION: The version of Azure's API to use.
|
||||
AZURE_API_VERSION=""
|
||||
|
||||
# LLM_MODEL: The chosen language model to use. This should be one of the models
|
||||
# provided by the enabled LLM providers (e.g., OPENAI_GPT4_TURBO, OPENAI_GPT4V, ANTHROPIC_CLAUDE3, AZURE_OPENAI_GPT4V).
|
||||
LLM_MODEL=""
|
||||
|
||||
# Web browser configuration for scraping:
|
||||
# BROWSER_TYPE: Can be either "chromium-headless" or "chromium-headful".
|
||||
BROWSER_TYPE="chromium-headful"
|
||||
# number of times to retry scraping a page before giving up, currently set to 0
|
||||
# MAX_SCRAPING_RETRIES: Number of times to retry scraping a page before giving up, currently set to 0.
|
||||
MAX_SCRAPING_RETRIES=0
|
||||
# path to the directory where videos will be saved
|
||||
# VIDEO_PATH: Path to the directory where videos will be saved.
|
||||
VIDEO_PATH=./videos
|
||||
# timeout for browser actions in milliseconds
|
||||
# BROWSER_ACTION_TIMEOUT_MS: Timeout for browser actions in milliseconds.
|
||||
BROWSER_ACTION_TIMEOUT_MS=5000
|
||||
# maximum number of steps to execute per run unless the agent finishes with a terminal state (last step or error)
|
||||
MAX_STEPS_PER_RUN = 50
|
||||
|
||||
# Control log level
|
||||
# Agent run configuration:
|
||||
# MAX_STEPS_PER_RUN: Maximum number of steps to execute per run unless the agent finishes with a terminal state (last step or error).
|
||||
MAX_STEPS_PER_RUN=50
|
||||
|
||||
# Logging and database configuration:
|
||||
# LOG_LEVEL: Control log level (e.g., INFO, DEBUG).
|
||||
LOG_LEVEL=INFO
|
||||
# Database connection string
|
||||
# DATABASE_STRING: Database connection string.
|
||||
DATABASE_STRING="postgresql+psycopg://skyvern@localhost/skyvern"
|
||||
# Port to run the agent on
|
||||
# PORT: Port to run the agent on.
|
||||
PORT=8000
|
||||
|
||||
# Distinct analytics ID
|
||||
ANALYTICS_ID="anonymous"
|
||||
# Analytics configuration:
|
||||
# Distinct analytics ID (a UUID is generated if left blank).
|
||||
ANALYTICS_ID="anonymous"
|
||||
|
||||
111
setup.sh
111
setup.sh
@@ -3,7 +3,7 @@
|
||||
# Call function to send telemetry event
|
||||
log_event() {
|
||||
if [ -n $1 ]; then
|
||||
python skyvern/analytics.py $1
|
||||
poetry run python skyvern/analytics.py $1
|
||||
fi
|
||||
}
|
||||
|
||||
@@ -20,28 +20,121 @@ for cmd in poetry python3.11; do
|
||||
fi
|
||||
done
|
||||
|
||||
# Function to update or add environment variable in .env file
|
||||
update_or_add_env_var() {
|
||||
local key=$1
|
||||
local value=$2
|
||||
if grep -q "^$key=" .env; then
|
||||
# Update existing variable
|
||||
sed -i.bak "s/^$key=.*/$key=$value/" .env && rm -f .env.bak
|
||||
else
|
||||
# Add new variable
|
||||
echo "$key=$value" >> .env
|
||||
fi
|
||||
}
|
||||
|
||||
# Function to set up LLM provider environment variables
|
||||
setup_llm_providers() {
|
||||
echo "Configuring Large Language Model (LLM) Providers..."
|
||||
echo "Note: All information provided here will be stored only on your local machine."
|
||||
local model_options=()
|
||||
|
||||
# OpenAI Configuration
|
||||
echo "To enable OpenAI, you must have an OpenAI API key."
|
||||
read -p "Do you want to enable OpenAI (y/n)? " enable_openai
|
||||
if [[ "$enable_openai" == "y" ]]; then
|
||||
read -p "Enter your OpenAI API key: " openai_api_key
|
||||
if [ -z "$openai_api_key" ]; then
|
||||
echo "Error: OpenAI API key is required."
|
||||
echo "OpenAI will not be enabled."
|
||||
else
|
||||
update_or_add_env_var "OPENAI_API_KEY" "$openai_api_key"
|
||||
update_or_add_env_var "ENABLE_OPENAI" "true"
|
||||
model_options+=("OPENAI_GPT4_TURBO" "OPENAI_GPT4V")
|
||||
fi
|
||||
else
|
||||
update_or_add_env_var "ENABLE_OPENAI" "false"
|
||||
fi
|
||||
|
||||
# Anthropic Configuration
|
||||
echo "To enable Anthropic, you must have an Anthropic API key."
|
||||
read -p "Do you want to enable Anthropic (y/n)? " enable_anthropic
|
||||
if [[ "$enable_anthropic" == "y" ]]; then
|
||||
read -p "Enter your Anthropic API key: " anthropic_api_key
|
||||
if [ -z "$anthropic_api_key" ]; then
|
||||
echo "Error: Anthropic API key is required."
|
||||
echo "Anthropic will not be enabled."
|
||||
else
|
||||
update_or_add_env_var "ANTHROPIC_API_KEY" "$anthropic_api_key"
|
||||
update_or_add_env_var "ENABLE_ANTHROPIC" "true"
|
||||
model_options+=("ANTHROPIC_CLAUDE3")
|
||||
fi
|
||||
else
|
||||
update_or_add_env_var "ENABLE_ANTHROPIC" "false"
|
||||
fi
|
||||
|
||||
# Azure Configuration
|
||||
echo "To enable Azure, you must have an Azure deployment name, API key, base URL, and API version."
|
||||
read -p "Do you want to enable Azure (y/n)? " enable_azure
|
||||
if [[ "$enable_azure" == "y" ]]; then
|
||||
read -p "Enter your Azure deployment name: " azure_deployment
|
||||
read -p "Enter your Azure API key: " azure_api_key
|
||||
read -p "Enter your Azure API base URL: " azure_api_base
|
||||
read -p "Enter your Azure API version: " azure_api_version
|
||||
if [ -z "$azure_deployment" ] || [ -z "$azure_api_key" ] || [ -z "$azure_api_base" ] || [ -z "$azure_api_version" ]; then
|
||||
echo "Error: All Azure fields must be populated."
|
||||
echo "Azure will not be enabled."
|
||||
else
|
||||
update_or_add_env_var "AZURE_DEPLOYMENT" "$azure_deployment"
|
||||
update_or_add_env_var "AZURE_API_KEY" "$azure_api_key"
|
||||
update_or_add_env_var "AZURE_API_BASE" "$azure_api_base"
|
||||
update_or_add_env_var "AZURE_API_VERSION" "$azure_api_version"
|
||||
update_or_add_env_var "ENABLE_AZURE" "true"
|
||||
model_options+=("AZURE_OPENAI_GPT4V")
|
||||
fi
|
||||
else
|
||||
update_or_add_env_var "ENABLE_AZURE" "false"
|
||||
fi
|
||||
|
||||
# Model Selection
|
||||
if [ ${#model_options[@]} -eq 0 ]; then
|
||||
echo "No LLM providers enabled. You won't be able to run Skyvern unless you enable at least one provider. You can re-run this script to enable providers or manually update the .env file."
|
||||
else
|
||||
echo "Available LLM models based on your selections:"
|
||||
for i in "${!model_options[@]}"; do
|
||||
echo "$((i+1)). ${model_options[$i]}"
|
||||
done
|
||||
read -p "Choose a model by number (e.g., 1 for ${model_options[0]}): " model_choice
|
||||
chosen_model=${model_options[$((model_choice-1))]}
|
||||
echo "Chosen LLM Model: $chosen_model"
|
||||
update_or_add_env_var "LLM_MODEL" "$chosen_model"
|
||||
fi
|
||||
|
||||
echo "LLM provider configurations updated in .env."
|
||||
}
|
||||
|
||||
|
||||
# Function to initialize .env file
|
||||
initialize_env_file() {
|
||||
if [ -f ".env" ]; then
|
||||
echo ".env file already exists, skipping initialization."
|
||||
read -p "Do you want to go through LLM provider setup again (y/n)? " redo_llm_setup
|
||||
if [[ "$redo_llm_setup" == "y" ]]; then
|
||||
setup_llm_providers
|
||||
fi
|
||||
return
|
||||
fi
|
||||
|
||||
echo "Initializing .env file..."
|
||||
cp .env.example .env
|
||||
|
||||
# Ask for OpenAI API key
|
||||
read -p "Please enter your OpenAI API key for GPT4V (this will be stored only in your local .env file): " openai_api_key
|
||||
awk -v key="$openai_api_key" '{gsub(/OPENAI_API_KEYS=\["abc","def","ghi"\]/, "OPENAI_API_KEYS=[\"" key "\"]"); print}' .env > .env.tmp && mv .env.tmp .env
|
||||
|
||||
setup_llm_providers
|
||||
|
||||
# Ask for email or generate UUID
|
||||
read -p "Please enter your email for analytics (press enter to skip): " analytics_id
|
||||
if [ -z "$analytics_id" ]; then
|
||||
analytics_id=$(uuidgen)
|
||||
fi
|
||||
awk -v id="$analytics_id" '{gsub(/ANALYTICS_ID="anonymous"/, "ANALYTICS_ID=\"" id "\""); print}' .env > .env.tmp && mv .env.tmp .env
|
||||
|
||||
update_or_add_env_var "ANALYTICS_ID" "$analytics_id"
|
||||
echo ".env file has been initialized."
|
||||
}
|
||||
|
||||
@@ -144,7 +237,7 @@ run_alembic_upgrade() {
|
||||
create_organization() {
|
||||
echo "Creating organization and API token..."
|
||||
local org_output api_token
|
||||
org_output=$(python scripts/create_organization.py Skyvern-Open-Source)
|
||||
org_output=$(poetry run python scripts/create_organization.py Skyvern-Open-Source)
|
||||
api_token=$(echo "$org_output" | awk '/token=/{gsub(/.*token='\''|'\''.*/, ""); print}')
|
||||
|
||||
# Ensure .streamlit directory exists
|
||||
|
||||
@@ -24,7 +24,6 @@ class Settings(BaseSettings):
|
||||
DATABASE_STRING: str = "postgresql+psycopg://skyvern@localhost/skyvern"
|
||||
PROMPT_ACTION_HISTORY_WINDOW: int = 5
|
||||
|
||||
OPENAI_API_KEYS: list[str] = []
|
||||
ENV: str = "local"
|
||||
EXECUTE_ALL_STEPS: bool = True
|
||||
JSON_LOGGING: bool = False
|
||||
@@ -48,6 +47,28 @@ class Settings(BaseSettings):
|
||||
BROWSER_LOCALE: str = "en-US"
|
||||
BROWSER_TIMEZONE: str = "America/New_York"
|
||||
|
||||
#####################
|
||||
# LLM Configuration #
|
||||
#####################
|
||||
# ACTIVE LLM PROVIDER
|
||||
LLM_KEY: str = "OPENAI_GPT4V"
|
||||
# COMMON
|
||||
LLM_CONFIG_MAX_TOKENS: int = 4096
|
||||
LLM_CONFIG_TEMPERATURE: float = 0
|
||||
# LLM PROVIDER SPECIFIC
|
||||
ENABLE_OPENAI: bool = True
|
||||
ENABLE_ANTHROPIC: bool = False
|
||||
ENABLE_AZURE: bool = False
|
||||
# OPENAI
|
||||
OPENAI_API_KEY: str | None = None
|
||||
# ANTHROPIC
|
||||
ANTHROPIC_API_KEY: str | None = None
|
||||
# AZURE
|
||||
AZURE_DEPLOYMENT: str | None = None
|
||||
AZURE_API_KEY: str | None = None
|
||||
AZURE_API_BASE: str | None = None
|
||||
AZURE_API_VERSION: str | None = None
|
||||
|
||||
def is_cloud_environment(self) -> bool:
|
||||
"""
|
||||
:return: True if env is not local, else False
|
||||
|
||||
@@ -4,21 +4,11 @@ class SkyvernException(Exception):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class NoAvailableOpenAIClients(SkyvernException):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("No available OpenAI API clients found.")
|
||||
|
||||
|
||||
class InvalidOpenAIResponseFormat(SkyvernException):
|
||||
def __init__(self, message: str | None = None):
|
||||
super().__init__(f"Invalid response format: {message}")
|
||||
|
||||
|
||||
class OpenAIRequestTooBigError(SkyvernException):
|
||||
def __init__(self, message: str | None = None):
|
||||
super().__init__(f"OpenAI request 429 error: {message}")
|
||||
|
||||
|
||||
class FailedToSendWebhook(SkyvernException):
|
||||
def __init__(self, task_id: str | None = None, workflow_run_id: str | None = None, workflow_id: str | None = None):
|
||||
workflow_run_str = f"workflow_run_id={workflow_run_id}" if workflow_run_id else ""
|
||||
|
||||
@@ -332,9 +332,9 @@ class ForgeAgent(Agent):
|
||||
json_response = None
|
||||
actions: list[Action]
|
||||
if task.navigation_goal:
|
||||
json_response = await app.OPENAI_CLIENT.chat_completion(
|
||||
step=step,
|
||||
json_response = await app.LLM_API_HANDLER(
|
||||
prompt=extract_action_prompt,
|
||||
step=step,
|
||||
screenshots=scraped_page.screenshots,
|
||||
)
|
||||
detailed_agent_step_output.llm_response = json_response
|
||||
|
||||
@@ -2,7 +2,7 @@ from ddtrace import tracer
|
||||
from ddtrace.filters import FilterRequestsOnUrl
|
||||
|
||||
from skyvern.forge.agent import ForgeAgent
|
||||
from skyvern.forge.sdk.api.open_ai import OpenAIClientManager
|
||||
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
||||
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
||||
from skyvern.forge.sdk.db.client import AgentDB
|
||||
@@ -28,7 +28,7 @@ DATABASE = AgentDB(
|
||||
STORAGE = StorageFactory.get_storage()
|
||||
ARTIFACT_MANAGER = ArtifactManager()
|
||||
BROWSER_MANAGER = BrowserManager()
|
||||
OPENAI_CLIENT = OpenAIClientManager()
|
||||
LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY)
|
||||
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
||||
WORKFLOW_SERVICE = WorkflowService()
|
||||
agent = ForgeAgent()
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
openai_model_to_price_lambdas = {
|
||||
"gpt-4-vision-preview": (0.01, 0.03),
|
||||
"gpt-4-1106-preview": (0.01, 0.03),
|
||||
"gpt-4-0125-preview": (0.01, 0.03),
|
||||
"gpt-3.5-turbo": (0.001, 0.002),
|
||||
"gpt-3.5-turbo-1106": (0.001, 0.002),
|
||||
"gpt-3.5-turbo-0125": (0.0005, 0.0015),
|
||||
}
|
||||
|
||||
|
||||
class ChatCompletionPrice(BaseModel):
|
||||
input_token_count: int
|
||||
output_token_count: int
|
||||
openai_model_to_price_lambda: Callable[[int, int], float]
|
||||
|
||||
def __init__(self, input_token_count: int, output_token_count: int, model_name: str):
|
||||
input_token_price, output_token_price = openai_model_to_price_lambdas[model_name]
|
||||
super().__init__(
|
||||
input_token_count=input_token_count,
|
||||
output_token_count=output_token_count,
|
||||
openai_model_to_price_lambda=lambda input_token, output_token: input_token_price * input_token / 1000
|
||||
+ output_token_price * output_token / 1000,
|
||||
)
|
||||
0
skyvern/forge/sdk/api/llm/__init__.py
Normal file
0
skyvern/forge/sdk/api/llm/__init__.py
Normal file
115
skyvern/forge/sdk/api/llm/api_handler_factory.py
Normal file
115
skyvern/forge/sdk/api/llm/api_handler_factory.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
import openai
|
||||
import structlog
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
||||
from skyvern.forge.sdk.api.llm.exceptions import DuplicateCustomLLMProviderError, LLMProviderError
|
||||
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
|
||||
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LLMAPIHandlerFactory:
|
||||
_custom_handlers: dict[str, LLMAPIHandler] = {}
|
||||
|
||||
@staticmethod
|
||||
def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
|
||||
llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||
|
||||
async def llm_api_handler(
|
||||
prompt: str,
|
||||
step: Step | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if parameters is None:
|
||||
parameters = LLMAPIHandlerFactory.get_api_parameters()
|
||||
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_PROMPT,
|
||||
data=prompt.encode("utf-8"),
|
||||
)
|
||||
for screenshot in screenshots or []:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.SCREENSHOT_LLM,
|
||||
data=screenshot,
|
||||
)
|
||||
|
||||
# TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
|
||||
if not llm_config.supports_vision:
|
||||
screenshots = None
|
||||
|
||||
messages = await llm_messages_builder(prompt, screenshots)
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_REQUEST,
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": llm_config.model_name,
|
||||
"messages": messages,
|
||||
**parameters,
|
||||
}
|
||||
).encode("utf-8"),
|
||||
)
|
||||
try:
|
||||
# TODO (kerem): add a timeout to this call
|
||||
# TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
|
||||
# TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
|
||||
response = await litellm.acompletion(
|
||||
model=llm_config.model_name,
|
||||
messages=messages,
|
||||
**parameters,
|
||||
)
|
||||
except openai.OpenAIError as e:
|
||||
raise LLMProviderError(llm_key) from e
|
||||
except Exception as e:
|
||||
LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
|
||||
raise LLMProviderError(llm_key) from e
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_RESPONSE,
|
||||
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||
)
|
||||
llm_cost = litellm.completion_cost(completion_response=response)
|
||||
await app.DATABASE.update_step(
|
||||
task_id=step.task_id,
|
||||
step_id=step.step_id,
|
||||
organization_id=step.organization_id,
|
||||
incremental_cost=llm_cost,
|
||||
)
|
||||
parsed_response = parse_api_response(response)
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||
)
|
||||
return parsed_response
|
||||
|
||||
return llm_api_handler
|
||||
|
||||
@staticmethod
|
||||
def get_api_parameters() -> dict[str, Any]:
|
||||
return {
|
||||
"max_tokens": SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS,
|
||||
"temperature": SettingsManager.get_settings().LLM_CONFIG_TEMPERATURE,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_custom_handler(cls, llm_key: str, handler: LLMAPIHandler) -> None:
|
||||
if llm_key in cls._custom_handlers:
|
||||
raise DuplicateCustomLLMProviderError(llm_key)
|
||||
cls._custom_handlers[llm_key] = handler
|
||||
70
skyvern/forge/sdk/api/llm/config_registry.py
Normal file
70
skyvern/forge/sdk/api/llm/config_registry.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.api.llm.exceptions import (
|
||||
DuplicateLLMConfigError,
|
||||
InvalidLLMConfigError,
|
||||
MissingLLMProviderEnvVarsError,
|
||||
NoProviderEnabledError,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.models import LLMConfig
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LLMConfigRegistry:
|
||||
_configs: dict[str, LLMConfig] = {}
|
||||
|
||||
@staticmethod
|
||||
def validate_config(llm_key: str, config: LLMConfig) -> None:
|
||||
missing_env_vars = config.get_missing_env_vars()
|
||||
if missing_env_vars:
|
||||
raise MissingLLMProviderEnvVarsError(llm_key, missing_env_vars)
|
||||
|
||||
@classmethod
|
||||
def register_config(cls, llm_key: str, config: LLMConfig) -> None:
|
||||
if llm_key in cls._configs:
|
||||
raise DuplicateLLMConfigError(llm_key)
|
||||
|
||||
cls.validate_config(llm_key, config)
|
||||
|
||||
LOG.info("Registering LLM config", llm_key=llm_key)
|
||||
cls._configs[llm_key] = config
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, llm_key: str) -> LLMConfig:
|
||||
if llm_key not in cls._configs:
|
||||
raise InvalidLLMConfigError(llm_key)
|
||||
|
||||
return cls._configs[llm_key]
|
||||
|
||||
|
||||
# if none of the LLM providers are enabled, raise an error
|
||||
if not any(
|
||||
[
|
||||
SettingsManager.get_settings().ENABLE_OPENAI,
|
||||
SettingsManager.get_settings().ENABLE_ANTHROPIC,
|
||||
SettingsManager.get_settings().ENABLE_AZURE,
|
||||
]
|
||||
):
|
||||
raise NoProviderEnabledError()
|
||||
|
||||
|
||||
if SettingsManager.get_settings().ENABLE_OPENAI:
|
||||
LLMConfigRegistry.register_config("OPENAI_GPT4_TURBO", LLMConfig("gpt-4-turbo-preview", ["OPENAI_API_KEY"], False))
|
||||
LLMConfigRegistry.register_config("OPENAI_GPT4V", LLMConfig("gpt-4-vision-preview", ["OPENAI_API_KEY"], True))
|
||||
|
||||
if SettingsManager.get_settings().ENABLE_ANTHROPIC:
|
||||
LLMConfigRegistry.register_config(
|
||||
"ANTHROPIC_CLAUDE3", LLMConfig("anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], True)
|
||||
)
|
||||
|
||||
if SettingsManager.get_settings().ENABLE_AZURE:
|
||||
LLMConfigRegistry.register_config(
|
||||
"AZURE_OPENAI_GPT4V",
|
||||
LLMConfig(
|
||||
f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}",
|
||||
["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
|
||||
True,
|
||||
),
|
||||
)
|
||||
48
skyvern/forge/sdk/api/llm/exceptions.py
Normal file
48
skyvern/forge/sdk/api/llm/exceptions.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from skyvern.exceptions import SkyvernException
|
||||
|
||||
|
||||
class BaseLLMError(SkyvernException):
|
||||
pass
|
||||
|
||||
|
||||
class MissingLLMProviderEnvVarsError(BaseLLMError):
|
||||
def __init__(self, llm_key: str, missing_env_vars: list[str]) -> None:
|
||||
super().__init__(f"Environment variables {','.join(missing_env_vars)} are required for LLMProvider {llm_key}")
|
||||
|
||||
|
||||
class EmptyLLMResponseError(BaseLLMError):
|
||||
def __init__(self, response: str) -> None:
|
||||
super().__init__(f"LLM response content is empty: {response}")
|
||||
|
||||
|
||||
class InvalidLLMResponseFormat(BaseLLMError):
|
||||
def __init__(self, response: str) -> None:
|
||||
super().__init__(f"LLM response content is not a valid JSON: {response}")
|
||||
|
||||
|
||||
class DuplicateCustomLLMProviderError(BaseLLMError):
|
||||
def __init__(self, llm_key: str) -> None:
|
||||
super().__init__(f"Custom LLMProvider {llm_key} is already registered")
|
||||
|
||||
|
||||
class DuplicateLLMConfigError(BaseLLMError):
|
||||
def __init__(self, llm_key: str) -> None:
|
||||
super().__init__(f"LLM config with key {llm_key} is already registered")
|
||||
|
||||
|
||||
class InvalidLLMConfigError(BaseLLMError):
|
||||
def __init__(self, llm_key: str) -> None:
|
||||
super().__init__(f"LLM config with key {llm_key} is not a valid LLMConfig")
|
||||
|
||||
|
||||
class LLMProviderError(BaseLLMError):
|
||||
def __init__(self, llm_key: str) -> None:
|
||||
super().__init__(f"Error while using LLMProvider {llm_key}")
|
||||
|
||||
|
||||
class NoProviderEnabledError(BaseLLMError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"At least one LLM provider must be enabled. Run setup.sh and follow through the LLM provider setup, or "
|
||||
"update the .env file (check out .env.example to see the required environment variables)."
|
||||
)
|
||||
32
skyvern/forge/sdk/api/llm/models.py
Normal file
32
skyvern/forge/sdk/api/llm/models.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Protocol
|
||||
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfig:
|
||||
model_name: str
|
||||
required_env_vars: list[str]
|
||||
supports_vision: bool
|
||||
|
||||
def get_missing_env_vars(self) -> list[str]:
|
||||
missing_env_vars = []
|
||||
for env_var in self.required_env_vars:
|
||||
env_var_value = getattr(SettingsManager.get_settings(), env_var, None)
|
||||
if not env_var_value:
|
||||
missing_env_vars.append(env_var)
|
||||
|
||||
return missing_env_vars
|
||||
|
||||
|
||||
class LLMAPIHandler(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
step: Step | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
) -> Awaitable[dict[str, Any]]:
|
||||
...
|
||||
45
skyvern/forge/sdk/api/llm/utils.py
Normal file
45
skyvern/forge/sdk/api/llm/utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import base64
|
||||
from typing import Any
|
||||
|
||||
import commentjson
|
||||
import litellm
|
||||
|
||||
from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidLLMResponseFormat
|
||||
|
||||
|
||||
async def llm_messages_builder(
|
||||
prompt: str,
|
||||
screenshots: list[bytes] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt,
|
||||
}
|
||||
]
|
||||
|
||||
if screenshots:
|
||||
for screenshot in screenshots:
|
||||
encoded_image = base64.b64encode(screenshot).decode("utf-8")
|
||||
messages.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{encoded_image}",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return [{"role": "user", "content": messages}]
|
||||
|
||||
|
||||
def parse_api_response(response: litellm.ModelResponse) -> dict[str, str]:
|
||||
try:
|
||||
content = response.choices[0].message.content
|
||||
content = content.replace("```json", "")
|
||||
content = content.replace("```", "")
|
||||
if not content:
|
||||
raise EmptyLLMResponseError(str(response))
|
||||
return commentjson.loads(content)
|
||||
except Exception as e:
|
||||
raise InvalidLLMResponseFormat(str(response)) from e
|
||||
@@ -1,228 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import commentjson
|
||||
import openai
|
||||
import structlog
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from skyvern.exceptions import InvalidOpenAIResponseFormat, NoAvailableOpenAIClients, OpenAIRequestTooBigError
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class OpenAIKeyClientWrapper:
|
||||
client: AsyncOpenAI
|
||||
key: str
|
||||
remaining_requests: int | None
|
||||
|
||||
def __init__(self, key: str, remaining_requests: int | None) -> None:
|
||||
self.key = key
|
||||
self.remaining_requests = remaining_requests
|
||||
self.updated_at = datetime.utcnow()
|
||||
self.client = AsyncOpenAI(api_key=self.key)
|
||||
|
||||
def update_remaining_requests(self, remaining_requests: int | None) -> None:
|
||||
self.remaining_requests = remaining_requests
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
# If remaining_requests is None, then it's the first time we're trying this key
|
||||
# so we can assume it's available, otherwise we check if it's greater than 0
|
||||
if self.remaining_requests is None:
|
||||
return True
|
||||
|
||||
if self.remaining_requests > 0:
|
||||
return True
|
||||
|
||||
# If we haven't checked this in over 1 minutes, check it again
|
||||
# Most of our failures are because of Tokens-per-minute (TPM) limits
|
||||
if self.updated_at < (datetime.utcnow() - timedelta(minutes=1)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class OpenAIClientManager:
|
||||
# TODO Support other models for requests without screenshots, track rate limits for each model and key as well if any
|
||||
clients: list[OpenAIKeyClientWrapper]
|
||||
|
||||
def __init__(self, api_keys: list[str] = SettingsManager.get_settings().OPENAI_API_KEYS) -> None:
|
||||
self.clients = [OpenAIKeyClientWrapper(key, None) for key in api_keys]
|
||||
|
||||
def get_available_client(self) -> OpenAIKeyClientWrapper | None:
|
||||
available_clients = [client for client in self.clients if client.is_available()]
|
||||
|
||||
if not available_clients:
|
||||
return None
|
||||
|
||||
# Randomly select an available client to distribute requests across our accounts
|
||||
return random.choice(available_clients)
|
||||
|
||||
async def content_builder(
|
||||
self,
|
||||
step: Step,
|
||||
screenshots: list[bytes] | None = None,
|
||||
prompt: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
content: list[dict[str, Any]] = []
|
||||
|
||||
if prompt is not None:
|
||||
content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt,
|
||||
}
|
||||
)
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_PROMPT,
|
||||
data=prompt.encode("utf-8"),
|
||||
)
|
||||
if screenshots:
|
||||
for screenshot in screenshots:
|
||||
encoded_image = base64.b64encode(screenshot).decode("utf-8")
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{encoded_image}",
|
||||
},
|
||||
}
|
||||
)
|
||||
# create artifact for each image
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.SCREENSHOT_LLM,
|
||||
data=screenshot,
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
step: Step,
|
||||
model: str = "gpt-4-vision-preview",
|
||||
max_tokens: int = 4096,
|
||||
temperature: int = 0,
|
||||
screenshots: list[bytes] | None = None,
|
||||
prompt: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
LOG.info(
|
||||
f"Sending LLM request",
|
||||
task_id=step.task_id,
|
||||
step_id=step.step_id,
|
||||
num_screenshots=len(screenshots) if screenshots else 0,
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": await self.content_builder(
|
||||
step=step,
|
||||
screenshots=screenshots,
|
||||
prompt=prompt,
|
||||
),
|
||||
}
|
||||
]
|
||||
|
||||
chat_completion_kwargs = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_REQUEST,
|
||||
data=json.dumps(chat_completion_kwargs).encode("utf-8"),
|
||||
)
|
||||
available_client = self.get_available_client()
|
||||
if available_client is None:
|
||||
raise NoAvailableOpenAIClients()
|
||||
try:
|
||||
response = await available_client.client.chat.completions.with_raw_response.create(**chat_completion_kwargs)
|
||||
except openai.RateLimitError as e:
|
||||
# If we get a RateLimitError, we can assume the key is not available anymore
|
||||
if e.code == 429:
|
||||
raise OpenAIRequestTooBigError(e.message)
|
||||
LOG.warning(
|
||||
"OpenAI rate limit exceeded, marking key as unavailable.", error_code=e.code, error_message=e.message
|
||||
)
|
||||
available_client.update_remaining_requests(remaining_requests=0)
|
||||
available_client = self.get_available_client()
|
||||
if available_client is None:
|
||||
raise NoAvailableOpenAIClients()
|
||||
return await self.chat_completion(
|
||||
step=step,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
screenshots=screenshots,
|
||||
prompt=prompt,
|
||||
)
|
||||
except openai.OpenAIError as e:
|
||||
LOG.error("OpenAI error", exc_info=True)
|
||||
raise e
|
||||
except Exception as e:
|
||||
LOG.error("Unknown error for chat completion", error_message=str(e), error_type=type(e))
|
||||
raise e
|
||||
|
||||
# TODO: https://platform.openai.com/docs/guides/rate-limits/rate-limits-in-headers
|
||||
# use other headers, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-tokens
|
||||
# x-ratelimit-reset-requests, x-ratelimit-reset-tokens to write a more accurate algorithm for managing api keys
|
||||
|
||||
# If we get a response, we can assume the key is available and update the remaining requests
|
||||
ratelimit_remaining_requests = response.headers.get("x-ratelimit-remaining-requests")
|
||||
|
||||
if not ratelimit_remaining_requests:
|
||||
LOG.warning("Invalid x-ratelimit-remaining-requests from OpenAI", response.headers)
|
||||
|
||||
available_client.update_remaining_requests(remaining_requests=int(ratelimit_remaining_requests))
|
||||
chat_completion = response.parse()
|
||||
|
||||
if chat_completion.usage is not None:
|
||||
# TODO (Suchintan): Is this bad design?
|
||||
step = await app.DATABASE.update_step(
|
||||
step_id=step.step_id,
|
||||
task_id=step.task_id,
|
||||
organization_id=step.organization_id,
|
||||
chat_completion_price=ChatCompletionPrice(
|
||||
input_token_count=chat_completion.usage.prompt_tokens,
|
||||
output_token_count=chat_completion.usage.completion_tokens,
|
||||
model_name=model,
|
||||
),
|
||||
)
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_RESPONSE,
|
||||
data=chat_completion.model_dump_json(indent=2).encode("utf-8"),
|
||||
)
|
||||
parsed_response = self.parse_response(chat_completion)
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||
)
|
||||
return parsed_response
|
||||
|
||||
def parse_response(self, response: ChatCompletion) -> dict[str, str]:
|
||||
try:
|
||||
content = response.choices[0].message.content
|
||||
content = content.replace("```json", "")
|
||||
content = content.replace("```", "")
|
||||
if not content:
|
||||
raise Exception("openai response content is empty")
|
||||
return commentjson.loads(content)
|
||||
except Exception as e:
|
||||
raise InvalidOpenAIResponseFormat(str(response)) from e
|
||||
@@ -7,7 +7,6 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from skyvern.exceptions import WorkflowParameterNotFound
|
||||
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||
@@ -264,7 +263,7 @@ class AgentDB:
|
||||
is_last: bool | None = None,
|
||||
retry_index: int | None = None,
|
||||
organization_id: str | None = None,
|
||||
chat_completion_price: ChatCompletionPrice | None = None,
|
||||
incremental_cost: float | None = None,
|
||||
) -> Step:
|
||||
try:
|
||||
with self.Session() as session:
|
||||
@@ -283,18 +282,8 @@ class AgentDB:
|
||||
step.is_last = is_last
|
||||
if retry_index is not None:
|
||||
step.retry_index = retry_index
|
||||
if chat_completion_price is not None:
|
||||
if step.input_token_count is None:
|
||||
step.input_token_count = 0
|
||||
|
||||
if step.output_token_count is None:
|
||||
step.output_token_count = 0
|
||||
|
||||
step.input_token_count += chat_completion_price.input_token_count
|
||||
step.output_token_count += chat_completion_price.output_token_count
|
||||
step.step_cost = chat_completion_price.openai_model_to_price_lambda(
|
||||
step.input_token_count, step.output_token_count
|
||||
)
|
||||
if incremental_cost is not None:
|
||||
step.step_cost = incremental_cost + float(step.step_cost or 0)
|
||||
|
||||
session.commit()
|
||||
updated_step = await self.get_step(task_id, step_id, organization_id)
|
||||
|
||||
@@ -573,9 +573,9 @@ async def extract_information_for_navigation_goal(
|
||||
error_code_mapping_str=json.dumps(task.error_code_mapping) if task.error_code_mapping else None,
|
||||
)
|
||||
|
||||
json_response = await app.OPENAI_CLIENT.chat_completion(
|
||||
step=step,
|
||||
json_response = await app.LLM_API_HANDLER(
|
||||
prompt=extract_information_prompt,
|
||||
step=step,
|
||||
screenshots=scraped_page.screenshots,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user