Implement LLM router (#95)

This commit is contained in:
Kerem Yilmaz
2024-03-16 23:13:18 -07:00
committed by GitHub
parent 0e34bfa2bd
commit d1de19556e
16 changed files with 485 additions and 308 deletions

View File

@@ -1,25 +1,54 @@
# Environment that the agent will run in.
ENV=local
# Your OpenAI API Keys. Separate multiple keys with commas. Keys will be used in order until the rate limit is reached for all keys
OPENAI_API_KEYS=["abc","def","ghi"]
# can be either "chromium-headless" or "chromium-headful".
# LLM Provider Configurations:
# ENABLE_OPENAI: Set to true to enable OpenAI as a language model provider.
ENABLE_OPENAI=false
# OPENAI_API_KEY: Your OpenAI API key for accessing models like GPT-4.
OPENAI_API_KEY=""
# ENABLE_ANTHROPIC: Set to true to enable Anthropic as a language model provider.
ENABLE_ANTHROPIC=false
# ANTHROPIC_API_KEY: Your Anthropic API key for accessing models like Claude-3.
ANTHROPIC_API_KEY=""
# ENABLE_AZURE: Set to true to enable Azure as a language model provider.
ENABLE_AZURE=false
# AZURE_DEPLOYMENT: Your Azure deployment name for accessing specific models.
AZURE_DEPLOYMENT=""
# AZURE_API_KEY: Your API key for accessing Azure's language models.
AZURE_API_KEY=""
# AZURE_API_BASE: The base URL for Azure's API.
AZURE_API_BASE=""
# AZURE_API_VERSION: The version of Azure's API to use.
AZURE_API_VERSION=""
# LLM_MODEL: The chosen language model to use. This should be one of the models
# provided by the enabled LLM providers (e.g., OPENAI_GPT4_TURBO, OPENAI_GPT4V, ANTHROPIC_CLAUDE3, AZURE_OPENAI_GPT4V).
LLM_MODEL=""
# Web browser configuration for scraping:
# BROWSER_TYPE: Can be either "chromium-headless" or "chromium-headful".
BROWSER_TYPE="chromium-headful"
# number of times to retry scraping a page before giving up, currently set to 0
# MAX_SCRAPING_RETRIES: Number of times to retry scraping a page before giving up, currently set to 0.
MAX_SCRAPING_RETRIES=0
# path to the directory where videos will be saved
# VIDEO_PATH: Path to the directory where videos will be saved.
VIDEO_PATH=./videos
# timeout for browser actions in milliseconds
# BROWSER_ACTION_TIMEOUT_MS: Timeout for browser actions in milliseconds.
BROWSER_ACTION_TIMEOUT_MS=5000
# maximum number of steps to execute per run unless the agent finishes with a terminal state (last step or error)
MAX_STEPS_PER_RUN = 50
# Control log level
# Agent run configuration:
# MAX_STEPS_PER_RUN: Maximum number of steps to execute per run unless the agent finishes with a terminal state (last step or error).
MAX_STEPS_PER_RUN=50
# Logging and database configuration:
# LOG_LEVEL: Control log level (e.g., INFO, DEBUG).
LOG_LEVEL=INFO
# Database connection string
# DATABASE_STRING: Database connection string.
DATABASE_STRING="postgresql+psycopg://skyvern@localhost/skyvern"
# Port to run the agent on
# PORT: Port to run the agent on.
PORT=8000
# Distinct analytics ID
ANALYTICS_ID="anonymous"
# Analytics configuration:
# Distinct analytics ID (a UUID is generated if left blank).
ANALYTICS_ID="anonymous"

111
setup.sh
View File

@@ -3,7 +3,7 @@
# Call function to send telemetry event
# Send a telemetry event (best-effort; no-op when no event name is given).
log_event() {
    # Quote "$1": unquoted, `[ -n $1 ]` collapses to the one-argument form
    # `[ -n ]`, which is always true, so the guard never actually skipped
    # the call when the argument was empty.
    if [ -n "$1" ]; then
        poetry run python skyvern/analytics.py "$1"
    fi
}
@@ -20,28 +20,121 @@ for cmd in poetry python3.11; do
fi
done
# Update an existing KEY=value line in .env, or append it if missing.
# $1 = variable name, $2 = value.
update_or_add_env_var() {
    local key=$1
    local value=$2
    if grep -q "^$key=" .env; then
        # Update existing variable.
        # Use '|' as the sed delimiter: values such as AZURE_API_BASE are
        # URLs containing '/', which would terminate a '/'-delimited
        # expression early and corrupt the substitution.
        sed -i.bak "s|^$key=.*|$key=$value|" .env && rm -f .env.bak
    else
        # Add new variable
        echo "$key=$value" >> .env
    fi
}
# Function to set up LLM provider environment variables.
# Interactively enables OpenAI / Anthropic / Azure, persists keys via
# update_or_add_env_var, and lets the user pick one LLM_MODEL from the
# models unlocked by the enabled providers.
setup_llm_providers() {
    echo "Configuring Large Language Model (LLM) Providers..."
    echo "Note: All information provided here will be stored only on your local machine."
    # Accumulates the model keys made available by each enabled provider.
    local model_options=()

    # OpenAI Configuration
    echo "To enable OpenAI, you must have an OpenAI API key."
    read -p "Do you want to enable OpenAI (y/n)? " enable_openai
    if [[ "$enable_openai" == "y" ]]; then
        read -p "Enter your OpenAI API key: " openai_api_key
        if [ -z "$openai_api_key" ]; then
            echo "Error: OpenAI API key is required."
            echo "OpenAI will not be enabled."
        else
            update_or_add_env_var "OPENAI_API_KEY" "$openai_api_key"
            update_or_add_env_var "ENABLE_OPENAI" "true"
            model_options+=("OPENAI_GPT4_TURBO" "OPENAI_GPT4V")
        fi
    else
        update_or_add_env_var "ENABLE_OPENAI" "false"
    fi

    # Anthropic Configuration
    echo "To enable Anthropic, you must have an Anthropic API key."
    read -p "Do you want to enable Anthropic (y/n)? " enable_anthropic
    if [[ "$enable_anthropic" == "y" ]]; then
        read -p "Enter your Anthropic API key: " anthropic_api_key
        if [ -z "$anthropic_api_key" ]; then
            echo "Error: Anthropic API key is required."
            echo "Anthropic will not be enabled."
        else
            update_or_add_env_var "ANTHROPIC_API_KEY" "$anthropic_api_key"
            update_or_add_env_var "ENABLE_ANTHROPIC" "true"
            model_options+=("ANTHROPIC_CLAUDE3")
        fi
    else
        update_or_add_env_var "ENABLE_ANTHROPIC" "false"
    fi

    # Azure Configuration — all four fields are required together.
    echo "To enable Azure, you must have an Azure deployment name, API key, base URL, and API version."
    read -p "Do you want to enable Azure (y/n)? " enable_azure
    if [[ "$enable_azure" == "y" ]]; then
        read -p "Enter your Azure deployment name: " azure_deployment
        read -p "Enter your Azure API key: " azure_api_key
        read -p "Enter your Azure API base URL: " azure_api_base
        read -p "Enter your Azure API version: " azure_api_version
        if [ -z "$azure_deployment" ] || [ -z "$azure_api_key" ] || [ -z "$azure_api_base" ] || [ -z "$azure_api_version" ]; then
            echo "Error: All Azure fields must be populated."
            echo "Azure will not be enabled."
        else
            update_or_add_env_var "AZURE_DEPLOYMENT" "$azure_deployment"
            update_or_add_env_var "AZURE_API_KEY" "$azure_api_key"
            update_or_add_env_var "AZURE_API_BASE" "$azure_api_base"
            update_or_add_env_var "AZURE_API_VERSION" "$azure_api_version"
            update_or_add_env_var "ENABLE_AZURE" "true"
            model_options+=("AZURE_OPENAI_GPT4V")
        fi
    else
        update_or_add_env_var "ENABLE_AZURE" "false"
    fi

    # Model Selection
    if [ ${#model_options[@]} -eq 0 ]; then
        echo "No LLM providers enabled. You won't be able to run Skyvern unless you enable at least one provider. You can re-run this script to enable providers or manually update the .env file."
    else
        echo "Available LLM models based on your selections:"
        for i in "${!model_options[@]}"; do
            echo "$((i+1)). ${model_options[$i]}"
        done
        read -p "Choose a model by number (e.g., 1 for ${model_options[0]}): " model_choice
        # NOTE(review): model_choice is not validated; a non-numeric or
        # out-of-range entry yields an empty chosen_model — consider re-prompting.
        chosen_model=${model_options[$((model_choice-1))]}
        echo "Chosen LLM Model: $chosen_model"
        update_or_add_env_var "LLM_MODEL" "$chosen_model"
    fi

    echo "LLM provider configurations updated in .env."
}
# Create .env from .env.example on first run and walk the user through the
# LLM provider and analytics setup. If .env already exists, only offer to
# re-run the LLM provider setup.
initialize_env_file() {
    if [ -f ".env" ]; then
        echo ".env file already exists, skipping initialization."
        read -p "Do you want to go through LLM provider setup again (y/n)? " redo_llm_setup
        if [[ "$redo_llm_setup" == "y" ]]; then
            setup_llm_providers
        fi
        return
    fi

    echo "Initializing .env file..."
    cp .env.example .env
    # LLM provider keys (including OpenAI) are collected by setup_llm_providers.
    # The old standalone OpenAI prompt and its awk rewrite were removed: the
    # awk pattern targeted OPENAI_API_KEYS=["abc","def","ghi"], a line this
    # commit deletes from .env.example, so it silently matched nothing and
    # the user was asked for the same key twice.
    setup_llm_providers

    # Ask for email or generate UUID
    read -p "Please enter your email for analytics (press enter to skip): " analytics_id
    if [ -z "$analytics_id" ]; then
        analytics_id=$(uuidgen)
    fi
    # Single write via update_or_add_env_var; the previous awk substitution
    # duplicated this work.
    update_or_add_env_var "ANALYTICS_ID" "$analytics_id"
    echo ".env file has been initialized."
}
@@ -144,7 +237,7 @@ run_alembic_upgrade() {
create_organization() {
echo "Creating organization and API token..."
local org_output api_token
org_output=$(python scripts/create_organization.py Skyvern-Open-Source)
org_output=$(poetry run python scripts/create_organization.py Skyvern-Open-Source)
api_token=$(echo "$org_output" | awk '/token=/{gsub(/.*token='\''|'\''.*/, ""); print}')
# Ensure .streamlit directory exists

View File

@@ -24,7 +24,6 @@ class Settings(BaseSettings):
DATABASE_STRING: str = "postgresql+psycopg://skyvern@localhost/skyvern"
PROMPT_ACTION_HISTORY_WINDOW: int = 5
OPENAI_API_KEYS: list[str] = []
ENV: str = "local"
EXECUTE_ALL_STEPS: bool = True
JSON_LOGGING: bool = False
@@ -48,6 +47,28 @@ class Settings(BaseSettings):
BROWSER_LOCALE: str = "en-US"
BROWSER_TIMEZONE: str = "America/New_York"
#####################
# LLM Configuration #
#####################
# ACTIVE LLM PROVIDER
LLM_KEY: str = "OPENAI_GPT4V"
# COMMON
LLM_CONFIG_MAX_TOKENS: int = 4096
LLM_CONFIG_TEMPERATURE: float = 0
# LLM PROVIDER SPECIFIC
ENABLE_OPENAI: bool = True
ENABLE_ANTHROPIC: bool = False
ENABLE_AZURE: bool = False
# OPENAI
OPENAI_API_KEY: str | None = None
# ANTHROPIC
ANTHROPIC_API_KEY: str | None = None
# AZURE
AZURE_DEPLOYMENT: str | None = None
AZURE_API_KEY: str | None = None
AZURE_API_BASE: str | None = None
AZURE_API_VERSION: str | None = None
def is_cloud_environment(self) -> bool:
"""
:return: True if env is not local, else False

View File

@@ -4,21 +4,11 @@ class SkyvernException(Exception):
super().__init__(message)
class NoAvailableOpenAIClients(SkyvernException):
    """Raised when every configured OpenAI API key is currently unavailable."""

    def __init__(self) -> None:
        super().__init__("No available OpenAI API clients found.")


class InvalidOpenAIResponseFormat(SkyvernException):
    """Raised when an OpenAI response body cannot be parsed into the expected format."""

    def __init__(self, message: str | None = None):
        super().__init__(f"Invalid response format: {message}")


class OpenAIRequestTooBigError(SkyvernException):
    """Raised for OpenAI 429 errors where the request itself is too large to serve."""

    def __init__(self, message: str | None = None):
        super().__init__(f"OpenAI request 429 error: {message}")
class FailedToSendWebhook(SkyvernException):
def __init__(self, task_id: str | None = None, workflow_run_id: str | None = None, workflow_id: str | None = None):
workflow_run_str = f"workflow_run_id={workflow_run_id}" if workflow_run_id else ""

View File

@@ -332,9 +332,9 @@ class ForgeAgent(Agent):
json_response = None
actions: list[Action]
if task.navigation_goal:
json_response = await app.OPENAI_CLIENT.chat_completion(
step=step,
json_response = await app.LLM_API_HANDLER(
prompt=extract_action_prompt,
step=step,
screenshots=scraped_page.screenshots,
)
detailed_agent_step_output.llm_response = json_response

View File

@@ -2,7 +2,7 @@ from ddtrace import tracer
from ddtrace.filters import FilterRequestsOnUrl
from skyvern.forge.agent import ForgeAgent
from skyvern.forge.sdk.api.open_ai import OpenAIClientManager
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
from skyvern.forge.sdk.artifact.manager import ArtifactManager
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
from skyvern.forge.sdk.db.client import AgentDB
@@ -28,7 +28,7 @@ DATABASE = AgentDB(
STORAGE = StorageFactory.get_storage()
ARTIFACT_MANAGER = ArtifactManager()
BROWSER_MANAGER = BrowserManager()
OPENAI_CLIENT = OpenAIClientManager()
LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY)
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
WORKFLOW_SERVICE = WorkflowService()
agent = ForgeAgent()

View File

@@ -1,27 +0,0 @@
from typing import Callable
from pydantic import BaseModel
# Per-1K-token USD prices keyed by OpenAI model name.
# NOTE(review): despite the name, the values are (input_price, output_price)
# tuples, not lambdas — the lambda is built in ChatCompletionPrice.__init__.
openai_model_to_price_lambdas = {
    "gpt-4-vision-preview": (0.01, 0.03),
    "gpt-4-1106-preview": (0.01, 0.03),
    "gpt-4-0125-preview": (0.01, 0.03),
    "gpt-3.5-turbo": (0.001, 0.002),
    "gpt-3.5-turbo-1106": (0.001, 0.002),
    "gpt-3.5-turbo-0125": (0.0005, 0.0015),
}


class ChatCompletionPrice(BaseModel):
    """Token counts for one chat completion plus a pricing function for its model.

    The custom __init__ looks up the model's per-1K-token prices and captures
    them in `openai_model_to_price_lambda`, which maps (input_tokens,
    output_tokens) to a USD cost. Raises KeyError for unknown model names.
    """

    # Prompt tokens consumed by the request.
    input_token_count: int
    # Completion tokens produced by the response.
    output_token_count: int
    # (input_tokens, output_tokens) -> USD cost for this model.
    openai_model_to_price_lambda: Callable[[int, int], float]

    def __init__(self, input_token_count: int, output_token_count: int, model_name: str):
        input_token_price, output_token_price = openai_model_to_price_lambdas[model_name]
        super().__init__(
            input_token_count=input_token_count,
            output_token_count=output_token_count,
            # Prices in the table are per 1K tokens, hence the /1000.
            openai_model_to_price_lambda=lambda input_token, output_token: input_token_price * input_token / 1000
            + output_token_price * output_token / 1000,
        )

View File

View File

@@ -0,0 +1,115 @@
import json
from typing import Any
import litellm
import openai
import structlog
from skyvern.forge import app
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.exceptions import DuplicateCustomLLMProviderError, LLMProviderError
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class LLMAPIHandlerFactory:
    """Builds async LLM request handlers bound to a registered LLM config.

    `get_llm_api_handler` returns a closure that sends a prompt (and optional
    screenshots) through litellm, records request/response artifacts when a
    `step` is given, tracks the incremental cost on the step, and returns the
    parsed JSON response.
    """

    # User-registered handlers keyed by llm_key (not consulted by
    # get_llm_api_handler in this revision).
    _custom_handlers: dict[str, LLMAPIHandler] = {}

    @staticmethod
    def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
        """Return an async handler for `llm_key`; raises if the key is not registered."""
        llm_config = LLMConfigRegistry.get_config(llm_key)

        async def llm_api_handler(
            prompt: str,
            step: Step | None = None,
            screenshots: list[bytes] | None = None,
            parameters: dict[str, Any] | None = None,
        ) -> dict[str, Any]:
            if parameters is None:
                parameters = LLMAPIHandlerFactory.get_api_parameters()

            if step:
                # Persist the prompt and each screenshot as artifacts tied to the step.
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_PROMPT,
                    data=prompt.encode("utf-8"),
                )
                for screenshot in screenshots or []:
                    await app.ARTIFACT_MANAGER.create_artifact(
                        step=step,
                        artifact_type=ArtifactType.SCREENSHOT_LLM,
                        data=screenshot,
                    )

            # TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
            if not llm_config.supports_vision:
                screenshots = None

            messages = await llm_messages_builder(prompt, screenshots)
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_REQUEST,
                    data=json.dumps(
                        {
                            "model": llm_config.model_name,
                            "messages": messages,
                            **parameters,
                        }
                    ).encode("utf-8"),
                )

            try:
                # TODO (kerem): add a timeout to this call
                # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
                # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
                response = await litellm.acompletion(
                    model=llm_config.model_name,
                    messages=messages,
                    **parameters,
                )
            except openai.OpenAIError as e:
                raise LLMProviderError(llm_key) from e
            except Exception as e:
                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
                raise LLMProviderError(llm_key) from e

            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_RESPONSE,
                    data=response.model_dump_json(indent=2).encode("utf-8"),
                )
                # Cost bookkeeping must stay inside the `if step:` guard:
                # update_step dereferences step.task_id/step_id/organization_id
                # and would raise AttributeError for step=None.
                llm_cost = litellm.completion_cost(completion_response=response)
                await app.DATABASE.update_step(
                    task_id=step.task_id,
                    step_id=step.step_id,
                    organization_id=step.organization_id,
                    incremental_cost=llm_cost,
                )

            parsed_response = parse_api_response(response)
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
                )
            return parsed_response

        return llm_api_handler

    @staticmethod
    def get_api_parameters() -> dict[str, Any]:
        """Default completion parameters sourced from settings."""
        return {
            "max_tokens": SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS,
            "temperature": SettingsManager.get_settings().LLM_CONFIG_TEMPERATURE,
        }

    @classmethod
    def register_custom_handler(cls, llm_key: str, handler: LLMAPIHandler) -> None:
        """Register a user-provided handler; keys are write-once."""
        if llm_key in cls._custom_handlers:
            raise DuplicateCustomLLMProviderError(llm_key)
        cls._custom_handlers[llm_key] = handler

View File

@@ -0,0 +1,70 @@
import structlog
from skyvern.forge.sdk.api.llm.exceptions import (
DuplicateLLMConfigError,
InvalidLLMConfigError,
MissingLLMProviderEnvVarsError,
NoProviderEnabledError,
)
from skyvern.forge.sdk.api.llm.models import LLMConfig
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class LLMConfigRegistry:
    """Process-wide, write-once registry mapping LLM keys to LLMConfig entries."""

    # Registered configs keyed by llm_key.
    _configs: dict[str, LLMConfig] = {}

    @staticmethod
    def validate_config(llm_key: str, config: LLMConfig) -> None:
        """Raise MissingLLMProviderEnvVarsError if the config's env vars are unset."""
        missing = config.get_missing_env_vars()
        if missing:
            raise MissingLLMProviderEnvVarsError(llm_key, missing)

    @classmethod
    def register_config(cls, llm_key: str, config: LLMConfig) -> None:
        """Validate and store a config under `llm_key`; duplicate keys are rejected."""
        if llm_key in cls._configs:
            raise DuplicateLLMConfigError(llm_key)
        cls.validate_config(llm_key, config)
        LOG.info("Registering LLM config", llm_key=llm_key)
        cls._configs[llm_key] = config

    @classmethod
    def get_config(cls, llm_key: str) -> LLMConfig:
        """Return the config registered for `llm_key`, or raise InvalidLLMConfigError."""
        if llm_key not in cls._configs:
            raise InvalidLLMConfigError(llm_key)
        return cls._configs[llm_key]
# if none of the LLM providers are enabled, raise an error
# (module import fails fast so the app never starts without a usable model)
if not any(
    [
        SettingsManager.get_settings().ENABLE_OPENAI,
        SettingsManager.get_settings().ENABLE_ANTHROPIC,
        SettingsManager.get_settings().ENABLE_AZURE,
    ]
):
    raise NoProviderEnabledError()

# OpenAI: GPT-4 Turbo (text-only) and GPT-4V (vision-capable).
if SettingsManager.get_settings().ENABLE_OPENAI:
    LLMConfigRegistry.register_config("OPENAI_GPT4_TURBO", LLMConfig("gpt-4-turbo-preview", ["OPENAI_API_KEY"], False))
    LLMConfigRegistry.register_config("OPENAI_GPT4V", LLMConfig("gpt-4-vision-preview", ["OPENAI_API_KEY"], True))

# Anthropic: Claude 3 Opus via litellm's "anthropic/" route; vision-capable.
if SettingsManager.get_settings().ENABLE_ANTHROPIC:
    LLMConfigRegistry.register_config(
        "ANTHROPIC_CLAUDE3", LLMConfig("anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], True)
    )

# Azure OpenAI: the model name embeds the deployment; all four env vars required.
if SettingsManager.get_settings().ENABLE_AZURE:
    LLMConfigRegistry.register_config(
        "AZURE_OPENAI_GPT4V",
        LLMConfig(
            f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}",
            ["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
            True,
        ),
    )

View File

@@ -0,0 +1,48 @@
from skyvern.exceptions import SkyvernException
class BaseLLMError(SkyvernException):
    """Root of the LLM-layer exception hierarchy."""

    pass


class MissingLLMProviderEnvVarsError(BaseLLMError):
    """Raised when a provider config's required environment variables are unset."""

    def __init__(self, llm_key: str, missing_env_vars: list[str]) -> None:
        super().__init__(f"Environment variables {','.join(missing_env_vars)} are required for LLMProvider {llm_key}")


class EmptyLLMResponseError(BaseLLMError):
    """Raised when the LLM response has no content to parse."""

    def __init__(self, response: str) -> None:
        super().__init__(f"LLM response content is empty: {response}")


class InvalidLLMResponseFormat(BaseLLMError):
    """Raised when the LLM response content is not parseable JSON."""

    def __init__(self, response: str) -> None:
        super().__init__(f"LLM response content is not a valid JSON: {response}")


class DuplicateCustomLLMProviderError(BaseLLMError):
    """Raised when registering a custom handler under an already-used key."""

    def __init__(self, llm_key: str) -> None:
        super().__init__(f"Custom LLMProvider {llm_key} is already registered")


class DuplicateLLMConfigError(BaseLLMError):
    """Raised when registering a config under an already-used key."""

    def __init__(self, llm_key: str) -> None:
        super().__init__(f"LLM config with key {llm_key} is already registered")


class InvalidLLMConfigError(BaseLLMError):
    """Raised when looking up an llm_key that has no registered config."""

    def __init__(self, llm_key: str) -> None:
        super().__init__(f"LLM config with key {llm_key} is not a valid LLMConfig")


class LLMProviderError(BaseLLMError):
    """Raised when the underlying provider call fails."""

    def __init__(self, llm_key: str) -> None:
        super().__init__(f"Error while using LLMProvider {llm_key}")


class NoProviderEnabledError(BaseLLMError):
    """Raised at startup when no LLM provider is enabled in the environment."""

    def __init__(self) -> None:
        super().__init__(
            "At least one LLM provider must be enabled. Run setup.sh and follow through the LLM provider setup, or "
            "update the .env file (check out .env.example to see the required environment variables)."
        )

View File

@@ -0,0 +1,32 @@
from dataclasses import dataclass
from typing import Any, Awaitable, Protocol
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
@dataclass(frozen=True)
class LLMConfig:
    """Immutable description of one selectable LLM: litellm model name,
    the settings attributes it requires, and whether it accepts images."""

    # litellm model identifier (may include a provider prefix like "azure/").
    model_name: str
    # Settings attribute names that must be non-empty for this model to work.
    required_env_vars: list[str]
    # True when the model accepts screenshot/image input.
    supports_vision: bool

    def get_missing_env_vars(self) -> list[str]:
        """Return the required env vars whose settings value is unset or empty."""
        settings = SettingsManager.get_settings()
        return [name for name in self.required_env_vars if not getattr(settings, name, None)]
class LLMAPIHandler(Protocol):
    """Structural type for an async LLM request handler.

    Callables matching this protocol take a prompt plus optional step context,
    screenshots, and completion parameters, and resolve to the parsed JSON
    response as a dict.
    """

    def __call__(
        self,
        prompt: str,
        step: Step | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
    ) -> Awaitable[dict[str, Any]]:
        ...

View File

@@ -0,0 +1,45 @@
import base64
from typing import Any
import commentjson
import litellm
from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidLLMResponseFormat
async def llm_messages_builder(
prompt: str,
screenshots: list[bytes] | None = None,
) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = [
{
"type": "text",
"text": prompt,
}
]
if screenshots:
for screenshot in screenshots:
encoded_image = base64.b64encode(screenshot).decode("utf-8")
messages.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encoded_image}",
},
}
)
return [{"role": "user", "content": messages}]
def parse_api_response(response: litellm.ModelResponse) -> dict[str, str]:
    """Extract and parse the JSON body from an LLM completion response.

    Strips markdown code fences before parsing with commentjson (tolerates
    comments/trailing commas). Raises EmptyLLMResponseError when the content
    is missing or empty, InvalidLLMResponseFormat when it is not valid JSON.
    """
    content = None
    try:
        content = response.choices[0].message.content
        # Check emptiness BEFORE calling .replace(): content may be None, and
        # the old order turned that into an AttributeError that was then
        # misreported as a format error.
        if not content:
            raise EmptyLLMResponseError(str(response))
        content = content.replace("```json", "")
        content = content.replace("```", "")
        if not content:
            raise EmptyLLMResponseError(str(response))
        return commentjson.loads(content)
    except EmptyLLMResponseError:
        # Previously this was swallowed by the broad except below and
        # re-raised as InvalidLLMResponseFormat; let it propagate as-is.
        raise
    except Exception as e:
        raise InvalidLLMResponseFormat(str(response)) from e

View File

@@ -1,228 +0,0 @@
import base64
import json
import random
from datetime import datetime, timedelta
from typing import Any
import commentjson
import openai
import structlog
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from skyvern.exceptions import InvalidOpenAIResponseFormat, NoAvailableOpenAIClients, OpenAIRequestTooBigError
from skyvern.forge import app
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
LOG = structlog.get_logger()
class OpenAIKeyClientWrapper:
    """Pairs one OpenAI API key with an AsyncOpenAI client and tracks whether
    the key still has request quota available."""

    # Async client bound to `key`.
    client: AsyncOpenAI
    # The raw API key string.
    key: str
    # Last observed x-ratelimit-remaining-requests; None = never checked.
    remaining_requests: int | None

    def __init__(self, key: str, remaining_requests: int | None) -> None:
        self.key = key
        self.remaining_requests = remaining_requests
        # Timestamp of the last quota observation (naive UTC).
        self.updated_at = datetime.utcnow()
        self.client = AsyncOpenAI(api_key=self.key)

    def update_remaining_requests(self, remaining_requests: int | None) -> None:
        """Record a fresh quota observation and bump the timestamp."""
        self.remaining_requests = remaining_requests
        self.updated_at = datetime.utcnow()

    def is_available(self) -> bool:
        """Return True if this key can be tried for the next request."""
        # If remaining_requests is None, then it's the first time we're trying this key
        # so we can assume it's available, otherwise we check if it's greater than 0
        if self.remaining_requests is None:
            return True
        if self.remaining_requests > 0:
            return True
        # If we haven't checked this in over 1 minutes, check it again
        # Most of our failures are because of Tokens-per-minute (TPM) limits
        if self.updated_at < (datetime.utcnow() - timedelta(minutes=1)):
            return True
        return False
class OpenAIClientManager:
    """Distributes chat-completion requests across a pool of OpenAI API keys,
    retiring keys that hit rate limits and retrying on the remaining ones."""

    # TODO Support other models for requests without screenshots, track rate limits for each model and key as well if any
    # One wrapper per configured API key.
    clients: list[OpenAIKeyClientWrapper]

    def __init__(self, api_keys: list[str] = SettingsManager.get_settings().OPENAI_API_KEYS) -> None:
        # NOTE(review): the default argument is evaluated once at import time,
        # so later changes to settings.OPENAI_API_KEYS are not picked up here.
        self.clients = [OpenAIKeyClientWrapper(key, None) for key in api_keys]

    def get_available_client(self) -> OpenAIKeyClientWrapper | None:
        """Return a random client whose key still appears to have quota, or None."""
        available_clients = [client for client in self.clients if client.is_available()]
        if not available_clients:
            return None
        # Randomly select an available client to distribute requests across our accounts
        return random.choice(available_clients)

    async def content_builder(
        self,
        step: Step,
        screenshots: list[bytes] | None = None,
        prompt: str | None = None,
    ) -> list[dict[str, Any]]:
        """Build the message content parts (text + base64 PNG images) for one
        request, persisting the prompt and each screenshot as step artifacts."""
        content: list[dict[str, Any]] = []
        if prompt is not None:
            content.append(
                {
                    "type": "text",
                    "text": prompt,
                }
            )
            await app.ARTIFACT_MANAGER.create_artifact(
                step=step,
                artifact_type=ArtifactType.LLM_PROMPT,
                data=prompt.encode("utf-8"),
            )
        if screenshots:
            for screenshot in screenshots:
                encoded_image = base64.b64encode(screenshot).decode("utf-8")
                content.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{encoded_image}",
                        },
                    }
                )
                # create artifact for each image
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.SCREENSHOT_LLM,
                    data=screenshot,
                )
        return content

    async def chat_completion(
        self,
        step: Step,
        model: str = "gpt-4-vision-preview",
        max_tokens: int = 4096,
        temperature: int = 0,
        screenshots: list[bytes] | None = None,
        prompt: str | None = None,
    ) -> dict[str, Any]:
        """Send one chat completion, rotating to another key (via recursion)
        on rate-limit errors; records artifacts and per-step token cost, and
        returns the parsed JSON response."""
        LOG.info(
            f"Sending LLM request",
            task_id=step.task_id,
            step_id=step.step_id,
            num_screenshots=len(screenshots) if screenshots else 0,
        )
        messages = [
            {
                "role": "user",
                "content": await self.content_builder(
                    step=step,
                    screenshots=screenshots,
                    prompt=prompt,
                ),
            }
        ]
        chat_completion_kwargs = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_REQUEST,
            data=json.dumps(chat_completion_kwargs).encode("utf-8"),
        )
        available_client = self.get_available_client()
        if available_client is None:
            raise NoAvailableOpenAIClients()
        try:
            # with_raw_response keeps HTTP headers accessible for rate-limit tracking.
            response = await available_client.client.chat.completions.with_raw_response.create(**chat_completion_kwargs)
        except openai.RateLimitError as e:
            # If we get a RateLimitError, we can assume the key is not available anymore
            # NOTE(review): e.code == 429 is treated as "request too big" and
            # aborts instead of rotating keys — confirm this matches the SDK's
            # error-code semantics.
            if e.code == 429:
                raise OpenAIRequestTooBigError(e.message)
            LOG.warning(
                "OpenAI rate limit exceeded, marking key as unavailable.", error_code=e.code, error_message=e.message
            )
            available_client.update_remaining_requests(remaining_requests=0)
            available_client = self.get_available_client()
            if available_client is None:
                raise NoAvailableOpenAIClients()
            # Retry the whole call on another key; recursion depth is bounded
            # by the number of keys that can still be marked unavailable.
            return await self.chat_completion(
                step=step,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                screenshots=screenshots,
                prompt=prompt,
            )
        except openai.OpenAIError as e:
            LOG.error("OpenAI error", exc_info=True)
            raise e
        except Exception as e:
            LOG.error("Unknown error for chat completion", error_message=str(e), error_type=type(e))
            raise e
        # TODO: https://platform.openai.com/docs/guides/rate-limits/rate-limits-in-headers
        # use other headers, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-tokens
        # x-ratelimit-reset-requests, x-ratelimit-reset-tokens to write a more accurate algorithm for managing api keys
        # If we get a response, we can assume the key is available and update the remaining requests
        ratelimit_remaining_requests = response.headers.get("x-ratelimit-remaining-requests")
        if not ratelimit_remaining_requests:
            # NOTE(review): when the header is absent this only warns, then the
            # int() below still runs on None/"" and raises — verify upstream
            # always sends the header.
            LOG.warning("Invalid x-ratelimit-remaining-requests from OpenAI", response.headers)
        available_client.update_remaining_requests(remaining_requests=int(ratelimit_remaining_requests))
        chat_completion = response.parse()
        if chat_completion.usage is not None:
            # TODO (Suchintan): Is this bad design?
            # Accumulate token counts/cost onto the step row; rebinds `step`
            # to the updated record.
            step = await app.DATABASE.update_step(
                step_id=step.step_id,
                task_id=step.task_id,
                organization_id=step.organization_id,
                chat_completion_price=ChatCompletionPrice(
                    input_token_count=chat_completion.usage.prompt_tokens,
                    output_token_count=chat_completion.usage.completion_tokens,
                    model_name=model,
                ),
            )
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_RESPONSE,
            data=chat_completion.model_dump_json(indent=2).encode("utf-8"),
        )
        parsed_response = self.parse_response(chat_completion)
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
            data=json.dumps(parsed_response, indent=2).encode("utf-8"),
        )
        return parsed_response

    def parse_response(self, response: ChatCompletion) -> dict[str, str]:
        """Strip markdown fences and parse the completion content as JSON;
        wraps any failure in InvalidOpenAIResponseFormat."""
        try:
            content = response.choices[0].message.content
            content = content.replace("```json", "")
            content = content.replace("```", "")
            if not content:
                raise Exception("openai response content is empty")
            return commentjson.loads(content)
        except Exception as e:
            raise InvalidOpenAIResponseFormat(str(response)) from e

View File

@@ -7,7 +7,6 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from skyvern.exceptions import WorkflowParameterNotFound
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.db.exceptions import NotFoundError
@@ -264,7 +263,7 @@ class AgentDB:
is_last: bool | None = None,
retry_index: int | None = None,
organization_id: str | None = None,
chat_completion_price: ChatCompletionPrice | None = None,
incremental_cost: float | None = None,
) -> Step:
try:
with self.Session() as session:
@@ -283,18 +282,8 @@ class AgentDB:
step.is_last = is_last
if retry_index is not None:
step.retry_index = retry_index
if chat_completion_price is not None:
if step.input_token_count is None:
step.input_token_count = 0
if step.output_token_count is None:
step.output_token_count = 0
step.input_token_count += chat_completion_price.input_token_count
step.output_token_count += chat_completion_price.output_token_count
step.step_cost = chat_completion_price.openai_model_to_price_lambda(
step.input_token_count, step.output_token_count
)
if incremental_cost is not None:
step.step_cost = incremental_cost + float(step.step_cost or 0)
session.commit()
updated_step = await self.get_step(task_id, step_id, organization_id)

View File

@@ -573,9 +573,9 @@ async def extract_information_for_navigation_goal(
error_code_mapping_str=json.dumps(task.error_code_mapping) if task.error_code_mapping else None,
)
json_response = await app.OPENAI_CLIENT.chat_completion(
step=step,
json_response = await app.LLM_API_HANDLER(
prompt=extract_information_prompt,
step=step,
screenshots=scraped_page.screenshots,
)