Implement LLM router (#95)
This commit is contained in:
55
.env.example
55
.env.example
@@ -1,25 +1,54 @@
|
|||||||
# Environment that the agent will run in.
|
# Environment that the agent will run in.
|
||||||
ENV=local
|
ENV=local
|
||||||
# Your OpenAI API Keys. Separate multiple keys with commas. Keys will be used in order until the rate limit is reached for all keys
|
|
||||||
OPENAI_API_KEYS=["abc","def","ghi"]
|
|
||||||
|
|
||||||
# can be either "chromium-headless" or "chromium-headful".
|
# LLM Provider Configurations:
|
||||||
|
# ENABLE_OPENAI: Set to true to enable OpenAI as a language model provider.
|
||||||
|
ENABLE_OPENAI=false
|
||||||
|
# OPENAI_API_KEY: Your OpenAI API key for accessing models like GPT-4.
|
||||||
|
OPENAI_API_KEY=""
|
||||||
|
|
||||||
|
# ENABLE_ANTHROPIC: Set to true to enable Anthropic as a language model provider.
|
||||||
|
ENABLE_ANTHROPIC=false
|
||||||
|
# ANTHROPIC_API_KEY: Your Anthropic API key for accessing models like Claude-3.
|
||||||
|
ANTHROPIC_API_KEY=""
|
||||||
|
|
||||||
|
# ENABLE_AZURE: Set to true to enable Azure as a language model provider.
|
||||||
|
ENABLE_AZURE=false
|
||||||
|
# AZURE_DEPLOYMENT: Your Azure deployment name for accessing specific models.
|
||||||
|
AZURE_DEPLOYMENT=""
|
||||||
|
# AZURE_API_KEY: Your API key for accessing Azure's language models.
|
||||||
|
AZURE_API_KEY=""
|
||||||
|
# AZURE_API_BASE: The base URL for Azure's API.
|
||||||
|
AZURE_API_BASE=""
|
||||||
|
# AZURE_API_VERSION: The version of Azure's API to use.
|
||||||
|
AZURE_API_VERSION=""
|
||||||
|
|
||||||
|
# LLM_MODEL: The chosen language model to use. This should be one of the models
|
||||||
|
# provided by the enabled LLM providers (e.g., OPENAI_GPT4_TURBO, OPENAI_GPT4V, ANTHROPIC_CLAUDE3, AZURE_OPENAI_GPT4V).
|
||||||
|
LLM_MODEL=""
|
||||||
|
|
||||||
|
# Web browser configuration for scraping:
|
||||||
|
# BROWSER_TYPE: Can be either "chromium-headless" or "chromium-headful".
|
||||||
BROWSER_TYPE="chromium-headful"
|
BROWSER_TYPE="chromium-headful"
|
||||||
# number of times to retry scraping a page before giving up, currently set to 0
|
# MAX_SCRAPING_RETRIES: Number of times to retry scraping a page before giving up, currently set to 0.
|
||||||
MAX_SCRAPING_RETRIES=0
|
MAX_SCRAPING_RETRIES=0
|
||||||
# path to the directory where videos will be saved
|
# VIDEO_PATH: Path to the directory where videos will be saved.
|
||||||
VIDEO_PATH=./videos
|
VIDEO_PATH=./videos
|
||||||
# timeout for browser actions in milliseconds
|
# BROWSER_ACTION_TIMEOUT_MS: Timeout for browser actions in milliseconds.
|
||||||
BROWSER_ACTION_TIMEOUT_MS=5000
|
BROWSER_ACTION_TIMEOUT_MS=5000
|
||||||
# maximum number of steps to execute per run unless the agent finishes with a terminal state (last step or error)
|
|
||||||
MAX_STEPS_PER_RUN = 50
|
|
||||||
|
|
||||||
# Control log level
|
# Agent run configuration:
|
||||||
|
# MAX_STEPS_PER_RUN: Maximum number of steps to execute per run unless the agent finishes with a terminal state (last step or error).
|
||||||
|
MAX_STEPS_PER_RUN=50
|
||||||
|
|
||||||
|
# Logging and database configuration:
|
||||||
|
# LOG_LEVEL: Control log level (e.g., INFO, DEBUG).
|
||||||
LOG_LEVEL=INFO
|
LOG_LEVEL=INFO
|
||||||
# Database connection string
|
# DATABASE_STRING: Database connection string.
|
||||||
DATABASE_STRING="postgresql+psycopg://skyvern@localhost/skyvern"
|
DATABASE_STRING="postgresql+psycopg://skyvern@localhost/skyvern"
|
||||||
# Port to run the agent on
|
# PORT: Port to run the agent on.
|
||||||
PORT=8000
|
PORT=8000
|
||||||
|
|
||||||
# Distinct analytics ID
|
# Analytics configuration:
|
||||||
ANALYTICS_ID="anonymous"
|
# Distinct analytics ID (a UUID is generated if left blank).
|
||||||
|
ANALYTICS_ID="anonymous"
|
||||||
|
|||||||
111
setup.sh
111
setup.sh
@@ -3,7 +3,7 @@
|
|||||||
# Call function to send telemetry event
|
# Call function to send telemetry event
|
||||||
log_event() {
|
log_event() {
|
||||||
if [ -n $1 ]; then
|
if [ -n $1 ]; then
|
||||||
python skyvern/analytics.py $1
|
poetry run python skyvern/analytics.py $1
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -20,28 +20,121 @@ for cmd in poetry python3.11; do
|
|||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
|
|
||||||
|
# Function to update or add environment variable in .env file
|
||||||
|
update_or_add_env_var() {
|
||||||
|
local key=$1
|
||||||
|
local value=$2
|
||||||
|
if grep -q "^$key=" .env; then
|
||||||
|
# Update existing variable
|
||||||
|
sed -i.bak "s/^$key=.*/$key=$value/" .env && rm -f .env.bak
|
||||||
|
else
|
||||||
|
# Add new variable
|
||||||
|
echo "$key=$value" >> .env
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Function to set up LLM provider environment variables
|
||||||
|
setup_llm_providers() {
|
||||||
|
echo "Configuring Large Language Model (LLM) Providers..."
|
||||||
|
echo "Note: All information provided here will be stored only on your local machine."
|
||||||
|
local model_options=()
|
||||||
|
|
||||||
|
# OpenAI Configuration
|
||||||
|
echo "To enable OpenAI, you must have an OpenAI API key."
|
||||||
|
read -p "Do you want to enable OpenAI (y/n)? " enable_openai
|
||||||
|
if [[ "$enable_openai" == "y" ]]; then
|
||||||
|
read -p "Enter your OpenAI API key: " openai_api_key
|
||||||
|
if [ -z "$openai_api_key" ]; then
|
||||||
|
echo "Error: OpenAI API key is required."
|
||||||
|
echo "OpenAI will not be enabled."
|
||||||
|
else
|
||||||
|
update_or_add_env_var "OPENAI_API_KEY" "$openai_api_key"
|
||||||
|
update_or_add_env_var "ENABLE_OPENAI" "true"
|
||||||
|
model_options+=("OPENAI_GPT4_TURBO" "OPENAI_GPT4V")
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
update_or_add_env_var "ENABLE_OPENAI" "false"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Anthropic Configuration
|
||||||
|
echo "To enable Anthropic, you must have an Anthropic API key."
|
||||||
|
read -p "Do you want to enable Anthropic (y/n)? " enable_anthropic
|
||||||
|
if [[ "$enable_anthropic" == "y" ]]; then
|
||||||
|
read -p "Enter your Anthropic API key: " anthropic_api_key
|
||||||
|
if [ -z "$anthropic_api_key" ]; then
|
||||||
|
echo "Error: Anthropic API key is required."
|
||||||
|
echo "Anthropic will not be enabled."
|
||||||
|
else
|
||||||
|
update_or_add_env_var "ANTHROPIC_API_KEY" "$anthropic_api_key"
|
||||||
|
update_or_add_env_var "ENABLE_ANTHROPIC" "true"
|
||||||
|
model_options+=("ANTHROPIC_CLAUDE3")
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
update_or_add_env_var "ENABLE_ANTHROPIC" "false"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Azure Configuration
|
||||||
|
echo "To enable Azure, you must have an Azure deployment name, API key, base URL, and API version."
|
||||||
|
read -p "Do you want to enable Azure (y/n)? " enable_azure
|
||||||
|
if [[ "$enable_azure" == "y" ]]; then
|
||||||
|
read -p "Enter your Azure deployment name: " azure_deployment
|
||||||
|
read -p "Enter your Azure API key: " azure_api_key
|
||||||
|
read -p "Enter your Azure API base URL: " azure_api_base
|
||||||
|
read -p "Enter your Azure API version: " azure_api_version
|
||||||
|
if [ -z "$azure_deployment" ] || [ -z "$azure_api_key" ] || [ -z "$azure_api_base" ] || [ -z "$azure_api_version" ]; then
|
||||||
|
echo "Error: All Azure fields must be populated."
|
||||||
|
echo "Azure will not be enabled."
|
||||||
|
else
|
||||||
|
update_or_add_env_var "AZURE_DEPLOYMENT" "$azure_deployment"
|
||||||
|
update_or_add_env_var "AZURE_API_KEY" "$azure_api_key"
|
||||||
|
update_or_add_env_var "AZURE_API_BASE" "$azure_api_base"
|
||||||
|
update_or_add_env_var "AZURE_API_VERSION" "$azure_api_version"
|
||||||
|
update_or_add_env_var "ENABLE_AZURE" "true"
|
||||||
|
model_options+=("AZURE_OPENAI_GPT4V")
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
update_or_add_env_var "ENABLE_AZURE" "false"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Model Selection
|
||||||
|
if [ ${#model_options[@]} -eq 0 ]; then
|
||||||
|
echo "No LLM providers enabled. You won't be able to run Skyvern unless you enable at least one provider. You can re-run this script to enable providers or manually update the .env file."
|
||||||
|
else
|
||||||
|
echo "Available LLM models based on your selections:"
|
||||||
|
for i in "${!model_options[@]}"; do
|
||||||
|
echo "$((i+1)). ${model_options[$i]}"
|
||||||
|
done
|
||||||
|
read -p "Choose a model by number (e.g., 1 for ${model_options[0]}): " model_choice
|
||||||
|
chosen_model=${model_options[$((model_choice-1))]}
|
||||||
|
echo "Chosen LLM Model: $chosen_model"
|
||||||
|
update_or_add_env_var "LLM_MODEL" "$chosen_model"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "LLM provider configurations updated in .env."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Function to initialize .env file
|
# Function to initialize .env file
|
||||||
initialize_env_file() {
|
initialize_env_file() {
|
||||||
if [ -f ".env" ]; then
|
if [ -f ".env" ]; then
|
||||||
echo ".env file already exists, skipping initialization."
|
echo ".env file already exists, skipping initialization."
|
||||||
|
read -p "Do you want to go through LLM provider setup again (y/n)? " redo_llm_setup
|
||||||
|
if [[ "$redo_llm_setup" == "y" ]]; then
|
||||||
|
setup_llm_providers
|
||||||
|
fi
|
||||||
return
|
return
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "Initializing .env file..."
|
echo "Initializing .env file..."
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
|
setup_llm_providers
|
||||||
# Ask for OpenAI API key
|
|
||||||
read -p "Please enter your OpenAI API key for GPT4V (this will be stored only in your local .env file): " openai_api_key
|
|
||||||
awk -v key="$openai_api_key" '{gsub(/OPENAI_API_KEYS=\["abc","def","ghi"\]/, "OPENAI_API_KEYS=[\"" key "\"]"); print}' .env > .env.tmp && mv .env.tmp .env
|
|
||||||
|
|
||||||
|
|
||||||
# Ask for email or generate UUID
|
# Ask for email or generate UUID
|
||||||
read -p "Please enter your email for analytics (press enter to skip): " analytics_id
|
read -p "Please enter your email for analytics (press enter to skip): " analytics_id
|
||||||
if [ -z "$analytics_id" ]; then
|
if [ -z "$analytics_id" ]; then
|
||||||
analytics_id=$(uuidgen)
|
analytics_id=$(uuidgen)
|
||||||
fi
|
fi
|
||||||
awk -v id="$analytics_id" '{gsub(/ANALYTICS_ID="anonymous"/, "ANALYTICS_ID=\"" id "\""); print}' .env > .env.tmp && mv .env.tmp .env
|
update_or_add_env_var "ANALYTICS_ID" "$analytics_id"
|
||||||
|
|
||||||
echo ".env file has been initialized."
|
echo ".env file has been initialized."
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,7 +237,7 @@ run_alembic_upgrade() {
|
|||||||
create_organization() {
|
create_organization() {
|
||||||
echo "Creating organization and API token..."
|
echo "Creating organization and API token..."
|
||||||
local org_output api_token
|
local org_output api_token
|
||||||
org_output=$(python scripts/create_organization.py Skyvern-Open-Source)
|
org_output=$(poetry run python scripts/create_organization.py Skyvern-Open-Source)
|
||||||
api_token=$(echo "$org_output" | awk '/token=/{gsub(/.*token='\''|'\''.*/, ""); print}')
|
api_token=$(echo "$org_output" | awk '/token=/{gsub(/.*token='\''|'\''.*/, ""); print}')
|
||||||
|
|
||||||
# Ensure .streamlit directory exists
|
# Ensure .streamlit directory exists
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ class Settings(BaseSettings):
|
|||||||
DATABASE_STRING: str = "postgresql+psycopg://skyvern@localhost/skyvern"
|
DATABASE_STRING: str = "postgresql+psycopg://skyvern@localhost/skyvern"
|
||||||
PROMPT_ACTION_HISTORY_WINDOW: int = 5
|
PROMPT_ACTION_HISTORY_WINDOW: int = 5
|
||||||
|
|
||||||
OPENAI_API_KEYS: list[str] = []
|
|
||||||
ENV: str = "local"
|
ENV: str = "local"
|
||||||
EXECUTE_ALL_STEPS: bool = True
|
EXECUTE_ALL_STEPS: bool = True
|
||||||
JSON_LOGGING: bool = False
|
JSON_LOGGING: bool = False
|
||||||
@@ -48,6 +47,28 @@ class Settings(BaseSettings):
|
|||||||
BROWSER_LOCALE: str = "en-US"
|
BROWSER_LOCALE: str = "en-US"
|
||||||
BROWSER_TIMEZONE: str = "America/New_York"
|
BROWSER_TIMEZONE: str = "America/New_York"
|
||||||
|
|
||||||
|
#####################
|
||||||
|
# LLM Configuration #
|
||||||
|
#####################
|
||||||
|
# ACTIVE LLM PROVIDER
|
||||||
|
LLM_KEY: str = "OPENAI_GPT4V"
|
||||||
|
# COMMON
|
||||||
|
LLM_CONFIG_MAX_TOKENS: int = 4096
|
||||||
|
LLM_CONFIG_TEMPERATURE: float = 0
|
||||||
|
# LLM PROVIDER SPECIFIC
|
||||||
|
ENABLE_OPENAI: bool = True
|
||||||
|
ENABLE_ANTHROPIC: bool = False
|
||||||
|
ENABLE_AZURE: bool = False
|
||||||
|
# OPENAI
|
||||||
|
OPENAI_API_KEY: str | None = None
|
||||||
|
# ANTHROPIC
|
||||||
|
ANTHROPIC_API_KEY: str | None = None
|
||||||
|
# AZURE
|
||||||
|
AZURE_DEPLOYMENT: str | None = None
|
||||||
|
AZURE_API_KEY: str | None = None
|
||||||
|
AZURE_API_BASE: str | None = None
|
||||||
|
AZURE_API_VERSION: str | None = None
|
||||||
|
|
||||||
def is_cloud_environment(self) -> bool:
|
def is_cloud_environment(self) -> bool:
|
||||||
"""
|
"""
|
||||||
:return: True if env is not local, else False
|
:return: True if env is not local, else False
|
||||||
|
|||||||
@@ -4,21 +4,11 @@ class SkyvernException(Exception):
|
|||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class NoAvailableOpenAIClients(SkyvernException):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__("No available OpenAI API clients found.")
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidOpenAIResponseFormat(SkyvernException):
|
class InvalidOpenAIResponseFormat(SkyvernException):
|
||||||
def __init__(self, message: str | None = None):
|
def __init__(self, message: str | None = None):
|
||||||
super().__init__(f"Invalid response format: {message}")
|
super().__init__(f"Invalid response format: {message}")
|
||||||
|
|
||||||
|
|
||||||
class OpenAIRequestTooBigError(SkyvernException):
|
|
||||||
def __init__(self, message: str | None = None):
|
|
||||||
super().__init__(f"OpenAI request 429 error: {message}")
|
|
||||||
|
|
||||||
|
|
||||||
class FailedToSendWebhook(SkyvernException):
|
class FailedToSendWebhook(SkyvernException):
|
||||||
def __init__(self, task_id: str | None = None, workflow_run_id: str | None = None, workflow_id: str | None = None):
|
def __init__(self, task_id: str | None = None, workflow_run_id: str | None = None, workflow_id: str | None = None):
|
||||||
workflow_run_str = f"workflow_run_id={workflow_run_id}" if workflow_run_id else ""
|
workflow_run_str = f"workflow_run_id={workflow_run_id}" if workflow_run_id else ""
|
||||||
|
|||||||
@@ -332,9 +332,9 @@ class ForgeAgent(Agent):
|
|||||||
json_response = None
|
json_response = None
|
||||||
actions: list[Action]
|
actions: list[Action]
|
||||||
if task.navigation_goal:
|
if task.navigation_goal:
|
||||||
json_response = await app.OPENAI_CLIENT.chat_completion(
|
json_response = await app.LLM_API_HANDLER(
|
||||||
step=step,
|
|
||||||
prompt=extract_action_prompt,
|
prompt=extract_action_prompt,
|
||||||
|
step=step,
|
||||||
screenshots=scraped_page.screenshots,
|
screenshots=scraped_page.screenshots,
|
||||||
)
|
)
|
||||||
detailed_agent_step_output.llm_response = json_response
|
detailed_agent_step_output.llm_response = json_response
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from ddtrace import tracer
|
|||||||
from ddtrace.filters import FilterRequestsOnUrl
|
from ddtrace.filters import FilterRequestsOnUrl
|
||||||
|
|
||||||
from skyvern.forge.agent import ForgeAgent
|
from skyvern.forge.agent import ForgeAgent
|
||||||
from skyvern.forge.sdk.api.open_ai import OpenAIClientManager
|
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
|
||||||
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
from skyvern.forge.sdk.artifact.manager import ArtifactManager
|
||||||
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
from skyvern.forge.sdk.artifact.storage.factory import StorageFactory
|
||||||
from skyvern.forge.sdk.db.client import AgentDB
|
from skyvern.forge.sdk.db.client import AgentDB
|
||||||
@@ -28,7 +28,7 @@ DATABASE = AgentDB(
|
|||||||
STORAGE = StorageFactory.get_storage()
|
STORAGE = StorageFactory.get_storage()
|
||||||
ARTIFACT_MANAGER = ArtifactManager()
|
ARTIFACT_MANAGER = ArtifactManager()
|
||||||
BROWSER_MANAGER = BrowserManager()
|
BROWSER_MANAGER = BrowserManager()
|
||||||
OPENAI_CLIENT = OpenAIClientManager()
|
LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(SettingsManager.get_settings().LLM_KEY)
|
||||||
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
WORKFLOW_CONTEXT_MANAGER = WorkflowContextManager()
|
||||||
WORKFLOW_SERVICE = WorkflowService()
|
WORKFLOW_SERVICE = WorkflowService()
|
||||||
agent = ForgeAgent()
|
agent = ForgeAgent()
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
from typing import Callable
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
openai_model_to_price_lambdas = {
|
|
||||||
"gpt-4-vision-preview": (0.01, 0.03),
|
|
||||||
"gpt-4-1106-preview": (0.01, 0.03),
|
|
||||||
"gpt-4-0125-preview": (0.01, 0.03),
|
|
||||||
"gpt-3.5-turbo": (0.001, 0.002),
|
|
||||||
"gpt-3.5-turbo-1106": (0.001, 0.002),
|
|
||||||
"gpt-3.5-turbo-0125": (0.0005, 0.0015),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionPrice(BaseModel):
|
|
||||||
input_token_count: int
|
|
||||||
output_token_count: int
|
|
||||||
openai_model_to_price_lambda: Callable[[int, int], float]
|
|
||||||
|
|
||||||
def __init__(self, input_token_count: int, output_token_count: int, model_name: str):
|
|
||||||
input_token_price, output_token_price = openai_model_to_price_lambdas[model_name]
|
|
||||||
super().__init__(
|
|
||||||
input_token_count=input_token_count,
|
|
||||||
output_token_count=output_token_count,
|
|
||||||
openai_model_to_price_lambda=lambda input_token, output_token: input_token_price * input_token / 1000
|
|
||||||
+ output_token_price * output_token / 1000,
|
|
||||||
)
|
|
||||||
0
skyvern/forge/sdk/api/llm/__init__.py
Normal file
0
skyvern/forge/sdk/api/llm/__init__.py
Normal file
115
skyvern/forge/sdk/api/llm/api_handler_factory.py
Normal file
115
skyvern/forge/sdk/api/llm/api_handler_factory.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
import openai
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
||||||
|
from skyvern.forge.sdk.api.llm.exceptions import DuplicateCustomLLMProviderError, LLMProviderError
|
||||||
|
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
|
||||||
|
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
|
||||||
|
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||||
|
from skyvern.forge.sdk.models import Step
|
||||||
|
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class LLMAPIHandlerFactory:
|
||||||
|
_custom_handlers: dict[str, LLMAPIHandler] = {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
|
||||||
|
llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||||
|
|
||||||
|
async def llm_api_handler(
|
||||||
|
prompt: str,
|
||||||
|
step: Step | None = None,
|
||||||
|
screenshots: list[bytes] | None = None,
|
||||||
|
parameters: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if parameters is None:
|
||||||
|
parameters = LLMAPIHandlerFactory.get_api_parameters()
|
||||||
|
|
||||||
|
if step:
|
||||||
|
await app.ARTIFACT_MANAGER.create_artifact(
|
||||||
|
step=step,
|
||||||
|
artifact_type=ArtifactType.LLM_PROMPT,
|
||||||
|
data=prompt.encode("utf-8"),
|
||||||
|
)
|
||||||
|
for screenshot in screenshots or []:
|
||||||
|
await app.ARTIFACT_MANAGER.create_artifact(
|
||||||
|
step=step,
|
||||||
|
artifact_type=ArtifactType.SCREENSHOT_LLM,
|
||||||
|
data=screenshot,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
|
||||||
|
if not llm_config.supports_vision:
|
||||||
|
screenshots = None
|
||||||
|
|
||||||
|
messages = await llm_messages_builder(prompt, screenshots)
|
||||||
|
if step:
|
||||||
|
await app.ARTIFACT_MANAGER.create_artifact(
|
||||||
|
step=step,
|
||||||
|
artifact_type=ArtifactType.LLM_REQUEST,
|
||||||
|
data=json.dumps(
|
||||||
|
{
|
||||||
|
"model": llm_config.model_name,
|
||||||
|
"messages": messages,
|
||||||
|
**parameters,
|
||||||
|
}
|
||||||
|
).encode("utf-8"),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# TODO (kerem): add a timeout to this call
|
||||||
|
# TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
|
||||||
|
# TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model=llm_config.model_name,
|
||||||
|
messages=messages,
|
||||||
|
**parameters,
|
||||||
|
)
|
||||||
|
except openai.OpenAIError as e:
|
||||||
|
raise LLMProviderError(llm_key) from e
|
||||||
|
except Exception as e:
|
||||||
|
LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
|
||||||
|
raise LLMProviderError(llm_key) from e
|
||||||
|
if step:
|
||||||
|
await app.ARTIFACT_MANAGER.create_artifact(
|
||||||
|
step=step,
|
||||||
|
artifact_type=ArtifactType.LLM_RESPONSE,
|
||||||
|
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||||
|
)
|
||||||
|
llm_cost = litellm.completion_cost(completion_response=response)
|
||||||
|
await app.DATABASE.update_step(
|
||||||
|
task_id=step.task_id,
|
||||||
|
step_id=step.step_id,
|
||||||
|
organization_id=step.organization_id,
|
||||||
|
incremental_cost=llm_cost,
|
||||||
|
)
|
||||||
|
parsed_response = parse_api_response(response)
|
||||||
|
if step:
|
||||||
|
await app.ARTIFACT_MANAGER.create_artifact(
|
||||||
|
step=step,
|
||||||
|
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
||||||
|
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||||
|
)
|
||||||
|
return parsed_response
|
||||||
|
|
||||||
|
return llm_api_handler
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_api_parameters() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"max_tokens": SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS,
|
||||||
|
"temperature": SettingsManager.get_settings().LLM_CONFIG_TEMPERATURE,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_custom_handler(cls, llm_key: str, handler: LLMAPIHandler) -> None:
|
||||||
|
if llm_key in cls._custom_handlers:
|
||||||
|
raise DuplicateCustomLLMProviderError(llm_key)
|
||||||
|
cls._custom_handlers[llm_key] = handler
|
||||||
70
skyvern/forge/sdk/api/llm/config_registry.py
Normal file
70
skyvern/forge/sdk/api/llm/config_registry.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import structlog
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.api.llm.exceptions import (
|
||||||
|
DuplicateLLMConfigError,
|
||||||
|
InvalidLLMConfigError,
|
||||||
|
MissingLLMProviderEnvVarsError,
|
||||||
|
NoProviderEnabledError,
|
||||||
|
)
|
||||||
|
from skyvern.forge.sdk.api.llm.models import LLMConfig
|
||||||
|
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class LLMConfigRegistry:
|
||||||
|
_configs: dict[str, LLMConfig] = {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_config(llm_key: str, config: LLMConfig) -> None:
|
||||||
|
missing_env_vars = config.get_missing_env_vars()
|
||||||
|
if missing_env_vars:
|
||||||
|
raise MissingLLMProviderEnvVarsError(llm_key, missing_env_vars)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_config(cls, llm_key: str, config: LLMConfig) -> None:
|
||||||
|
if llm_key in cls._configs:
|
||||||
|
raise DuplicateLLMConfigError(llm_key)
|
||||||
|
|
||||||
|
cls.validate_config(llm_key, config)
|
||||||
|
|
||||||
|
LOG.info("Registering LLM config", llm_key=llm_key)
|
||||||
|
cls._configs[llm_key] = config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls, llm_key: str) -> LLMConfig:
|
||||||
|
if llm_key not in cls._configs:
|
||||||
|
raise InvalidLLMConfigError(llm_key)
|
||||||
|
|
||||||
|
return cls._configs[llm_key]
|
||||||
|
|
||||||
|
|
||||||
|
# if none of the LLM providers are enabled, raise an error
|
||||||
|
if not any(
|
||||||
|
[
|
||||||
|
SettingsManager.get_settings().ENABLE_OPENAI,
|
||||||
|
SettingsManager.get_settings().ENABLE_ANTHROPIC,
|
||||||
|
SettingsManager.get_settings().ENABLE_AZURE,
|
||||||
|
]
|
||||||
|
):
|
||||||
|
raise NoProviderEnabledError()
|
||||||
|
|
||||||
|
|
||||||
|
if SettingsManager.get_settings().ENABLE_OPENAI:
|
||||||
|
LLMConfigRegistry.register_config("OPENAI_GPT4_TURBO", LLMConfig("gpt-4-turbo-preview", ["OPENAI_API_KEY"], False))
|
||||||
|
LLMConfigRegistry.register_config("OPENAI_GPT4V", LLMConfig("gpt-4-vision-preview", ["OPENAI_API_KEY"], True))
|
||||||
|
|
||||||
|
if SettingsManager.get_settings().ENABLE_ANTHROPIC:
|
||||||
|
LLMConfigRegistry.register_config(
|
||||||
|
"ANTHROPIC_CLAUDE3", LLMConfig("anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], True)
|
||||||
|
)
|
||||||
|
|
||||||
|
if SettingsManager.get_settings().ENABLE_AZURE:
|
||||||
|
LLMConfigRegistry.register_config(
|
||||||
|
"AZURE_OPENAI_GPT4V",
|
||||||
|
LLMConfig(
|
||||||
|
f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}",
|
||||||
|
["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
)
|
||||||
48
skyvern/forge/sdk/api/llm/exceptions.py
Normal file
48
skyvern/forge/sdk/api/llm/exceptions.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
from skyvern.exceptions import SkyvernException
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLLMError(SkyvernException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MissingLLMProviderEnvVarsError(BaseLLMError):
|
||||||
|
def __init__(self, llm_key: str, missing_env_vars: list[str]) -> None:
|
||||||
|
super().__init__(f"Environment variables {','.join(missing_env_vars)} are required for LLMProvider {llm_key}")
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyLLMResponseError(BaseLLMError):
|
||||||
|
def __init__(self, response: str) -> None:
|
||||||
|
super().__init__(f"LLM response content is empty: {response}")
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidLLMResponseFormat(BaseLLMError):
|
||||||
|
def __init__(self, response: str) -> None:
|
||||||
|
super().__init__(f"LLM response content is not a valid JSON: {response}")
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateCustomLLMProviderError(BaseLLMError):
|
||||||
|
def __init__(self, llm_key: str) -> None:
|
||||||
|
super().__init__(f"Custom LLMProvider {llm_key} is already registered")
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateLLMConfigError(BaseLLMError):
|
||||||
|
def __init__(self, llm_key: str) -> None:
|
||||||
|
super().__init__(f"LLM config with key {llm_key} is already registered")
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidLLMConfigError(BaseLLMError):
|
||||||
|
def __init__(self, llm_key: str) -> None:
|
||||||
|
super().__init__(f"LLM config with key {llm_key} is not a valid LLMConfig")
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProviderError(BaseLLMError):
|
||||||
|
def __init__(self, llm_key: str) -> None:
|
||||||
|
super().__init__(f"Error while using LLMProvider {llm_key}")
|
||||||
|
|
||||||
|
|
||||||
|
class NoProviderEnabledError(BaseLLMError):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(
|
||||||
|
"At least one LLM provider must be enabled. Run setup.sh and follow through the LLM provider setup, or "
|
||||||
|
"update the .env file (check out .env.example to see the required environment variables)."
|
||||||
|
)
|
||||||
32
skyvern/forge/sdk/api/llm/models.py
Normal file
32
skyvern/forge/sdk/api/llm/models.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Awaitable, Protocol
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.models import Step
|
||||||
|
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LLMConfig:
|
||||||
|
model_name: str
|
||||||
|
required_env_vars: list[str]
|
||||||
|
supports_vision: bool
|
||||||
|
|
||||||
|
def get_missing_env_vars(self) -> list[str]:
|
||||||
|
missing_env_vars = []
|
||||||
|
for env_var in self.required_env_vars:
|
||||||
|
env_var_value = getattr(SettingsManager.get_settings(), env_var, None)
|
||||||
|
if not env_var_value:
|
||||||
|
missing_env_vars.append(env_var)
|
||||||
|
|
||||||
|
return missing_env_vars
|
||||||
|
|
||||||
|
|
||||||
|
class LLMAPIHandler(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
step: Step | None = None,
|
||||||
|
screenshots: list[bytes] | None = None,
|
||||||
|
parameters: dict[str, Any] | None = None,
|
||||||
|
) -> Awaitable[dict[str, Any]]:
|
||||||
|
...
|
||||||
45
skyvern/forge/sdk/api/llm/utils.py
Normal file
45
skyvern/forge/sdk/api/llm/utils.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
import base64
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import commentjson
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidLLMResponseFormat
|
||||||
|
|
||||||
|
|
||||||
|
async def llm_messages_builder(
|
||||||
|
prompt: str,
|
||||||
|
screenshots: list[bytes] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
messages: list[dict[str, Any]] = [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": prompt,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
if screenshots:
|
||||||
|
for screenshot in screenshots:
|
||||||
|
encoded_image = base64.b64encode(screenshot).decode("utf-8")
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/png;base64,{encoded_image}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return [{"role": "user", "content": messages}]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_api_response(response: litellm.ModelResponse) -> dict[str, str]:
|
||||||
|
try:
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
content = content.replace("```json", "")
|
||||||
|
content = content.replace("```", "")
|
||||||
|
if not content:
|
||||||
|
raise EmptyLLMResponseError(str(response))
|
||||||
|
return commentjson.loads(content)
|
||||||
|
except Exception as e:
|
||||||
|
raise InvalidLLMResponseFormat(str(response)) from e
|
||||||
@@ -1,228 +0,0 @@
|
|||||||
import base64
|
|
||||||
import json
|
|
||||||
import random
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import commentjson
|
|
||||||
import openai
|
|
||||||
import structlog
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
from openai.types.chat.chat_completion import ChatCompletion
|
|
||||||
|
|
||||||
from skyvern.exceptions import InvalidOpenAIResponseFormat, NoAvailableOpenAIClients, OpenAIRequestTooBigError
|
|
||||||
from skyvern.forge import app
|
|
||||||
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
|
|
||||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
|
||||||
from skyvern.forge.sdk.models import Step
|
|
||||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIKeyClientWrapper:
|
|
||||||
client: AsyncOpenAI
|
|
||||||
key: str
|
|
||||||
remaining_requests: int | None
|
|
||||||
|
|
||||||
def __init__(self, key: str, remaining_requests: int | None) -> None:
|
|
||||||
self.key = key
|
|
||||||
self.remaining_requests = remaining_requests
|
|
||||||
self.updated_at = datetime.utcnow()
|
|
||||||
self.client = AsyncOpenAI(api_key=self.key)
|
|
||||||
|
|
||||||
def update_remaining_requests(self, remaining_requests: int | None) -> None:
|
|
||||||
self.remaining_requests = remaining_requests
|
|
||||||
self.updated_at = datetime.utcnow()
|
|
||||||
|
|
||||||
def is_available(self) -> bool:
|
|
||||||
# If remaining_requests is None, then it's the first time we're trying this key
|
|
||||||
# so we can assume it's available, otherwise we check if it's greater than 0
|
|
||||||
if self.remaining_requests is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if self.remaining_requests > 0:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# If we haven't checked this in over 1 minutes, check it again
|
|
||||||
# Most of our failures are because of Tokens-per-minute (TPM) limits
|
|
||||||
if self.updated_at < (datetime.utcnow() - timedelta(minutes=1)):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIClientManager:
|
|
||||||
# TODO Support other models for requests without screenshots, track rate limits for each model and key as well if any
|
|
||||||
clients: list[OpenAIKeyClientWrapper]
|
|
||||||
|
|
||||||
def __init__(self, api_keys: list[str] = SettingsManager.get_settings().OPENAI_API_KEYS) -> None:
|
|
||||||
self.clients = [OpenAIKeyClientWrapper(key, None) for key in api_keys]
|
|
||||||
|
|
||||||
def get_available_client(self) -> OpenAIKeyClientWrapper | None:
|
|
||||||
available_clients = [client for client in self.clients if client.is_available()]
|
|
||||||
|
|
||||||
if not available_clients:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Randomly select an available client to distribute requests across our accounts
|
|
||||||
return random.choice(available_clients)
|
|
||||||
|
|
||||||
async def content_builder(
|
|
||||||
self,
|
|
||||||
step: Step,
|
|
||||||
screenshots: list[bytes] | None = None,
|
|
||||||
prompt: str | None = None,
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
content: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
if prompt is not None:
|
|
||||||
content.append(
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": prompt,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
await app.ARTIFACT_MANAGER.create_artifact(
|
|
||||||
step=step,
|
|
||||||
artifact_type=ArtifactType.LLM_PROMPT,
|
|
||||||
data=prompt.encode("utf-8"),
|
|
||||||
)
|
|
||||||
if screenshots:
|
|
||||||
for screenshot in screenshots:
|
|
||||||
encoded_image = base64.b64encode(screenshot).decode("utf-8")
|
|
||||||
content.append(
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/png;base64,{encoded_image}",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
# create artifact for each image
|
|
||||||
await app.ARTIFACT_MANAGER.create_artifact(
|
|
||||||
step=step,
|
|
||||||
artifact_type=ArtifactType.SCREENSHOT_LLM,
|
|
||||||
data=screenshot,
|
|
||||||
)
|
|
||||||
|
|
||||||
return content
|
|
||||||
|
|
||||||
async def chat_completion(
|
|
||||||
self,
|
|
||||||
step: Step,
|
|
||||||
model: str = "gpt-4-vision-preview",
|
|
||||||
max_tokens: int = 4096,
|
|
||||||
temperature: int = 0,
|
|
||||||
screenshots: list[bytes] | None = None,
|
|
||||||
prompt: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
LOG.info(
|
|
||||||
f"Sending LLM request",
|
|
||||||
task_id=step.task_id,
|
|
||||||
step_id=step.step_id,
|
|
||||||
num_screenshots=len(screenshots) if screenshots else 0,
|
|
||||||
)
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": await self.content_builder(
|
|
||||||
step=step,
|
|
||||||
screenshots=screenshots,
|
|
||||||
prompt=prompt,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
chat_completion_kwargs = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"max_tokens": max_tokens,
|
|
||||||
"temperature": temperature,
|
|
||||||
}
|
|
||||||
|
|
||||||
await app.ARTIFACT_MANAGER.create_artifact(
|
|
||||||
step=step,
|
|
||||||
artifact_type=ArtifactType.LLM_REQUEST,
|
|
||||||
data=json.dumps(chat_completion_kwargs).encode("utf-8"),
|
|
||||||
)
|
|
||||||
available_client = self.get_available_client()
|
|
||||||
if available_client is None:
|
|
||||||
raise NoAvailableOpenAIClients()
|
|
||||||
try:
|
|
||||||
response = await available_client.client.chat.completions.with_raw_response.create(**chat_completion_kwargs)
|
|
||||||
except openai.RateLimitError as e:
|
|
||||||
# If we get a RateLimitError, we can assume the key is not available anymore
|
|
||||||
if e.code == 429:
|
|
||||||
raise OpenAIRequestTooBigError(e.message)
|
|
||||||
LOG.warning(
|
|
||||||
"OpenAI rate limit exceeded, marking key as unavailable.", error_code=e.code, error_message=e.message
|
|
||||||
)
|
|
||||||
available_client.update_remaining_requests(remaining_requests=0)
|
|
||||||
available_client = self.get_available_client()
|
|
||||||
if available_client is None:
|
|
||||||
raise NoAvailableOpenAIClients()
|
|
||||||
return await self.chat_completion(
|
|
||||||
step=step,
|
|
||||||
model=model,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
screenshots=screenshots,
|
|
||||||
prompt=prompt,
|
|
||||||
)
|
|
||||||
except openai.OpenAIError as e:
|
|
||||||
LOG.error("OpenAI error", exc_info=True)
|
|
||||||
raise e
|
|
||||||
except Exception as e:
|
|
||||||
LOG.error("Unknown error for chat completion", error_message=str(e), error_type=type(e))
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# TODO: https://platform.openai.com/docs/guides/rate-limits/rate-limits-in-headers
|
|
||||||
# use other headers, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-tokens
|
|
||||||
# x-ratelimit-reset-requests, x-ratelimit-reset-tokens to write a more accurate algorithm for managing api keys
|
|
||||||
|
|
||||||
# If we get a response, we can assume the key is available and update the remaining requests
|
|
||||||
ratelimit_remaining_requests = response.headers.get("x-ratelimit-remaining-requests")
|
|
||||||
|
|
||||||
if not ratelimit_remaining_requests:
|
|
||||||
LOG.warning("Invalid x-ratelimit-remaining-requests from OpenAI", response.headers)
|
|
||||||
|
|
||||||
available_client.update_remaining_requests(remaining_requests=int(ratelimit_remaining_requests))
|
|
||||||
chat_completion = response.parse()
|
|
||||||
|
|
||||||
if chat_completion.usage is not None:
|
|
||||||
# TODO (Suchintan): Is this bad design?
|
|
||||||
step = await app.DATABASE.update_step(
|
|
||||||
step_id=step.step_id,
|
|
||||||
task_id=step.task_id,
|
|
||||||
organization_id=step.organization_id,
|
|
||||||
chat_completion_price=ChatCompletionPrice(
|
|
||||||
input_token_count=chat_completion.usage.prompt_tokens,
|
|
||||||
output_token_count=chat_completion.usage.completion_tokens,
|
|
||||||
model_name=model,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
await app.ARTIFACT_MANAGER.create_artifact(
|
|
||||||
step=step,
|
|
||||||
artifact_type=ArtifactType.LLM_RESPONSE,
|
|
||||||
data=chat_completion.model_dump_json(indent=2).encode("utf-8"),
|
|
||||||
)
|
|
||||||
parsed_response = self.parse_response(chat_completion)
|
|
||||||
await app.ARTIFACT_MANAGER.create_artifact(
|
|
||||||
step=step,
|
|
||||||
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
|
||||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
|
||||||
)
|
|
||||||
return parsed_response
|
|
||||||
|
|
||||||
def parse_response(self, response: ChatCompletion) -> dict[str, str]:
|
|
||||||
try:
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
content = content.replace("```json", "")
|
|
||||||
content = content.replace("```", "")
|
|
||||||
if not content:
|
|
||||||
raise Exception("openai response content is empty")
|
|
||||||
return commentjson.loads(content)
|
|
||||||
except Exception as e:
|
|
||||||
raise InvalidOpenAIResponseFormat(str(response)) from e
|
|
||||||
@@ -7,7 +7,6 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from skyvern.exceptions import WorkflowParameterNotFound
|
from skyvern.exceptions import WorkflowParameterNotFound
|
||||||
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
|
|
||||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||||
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
||||||
@@ -264,7 +263,7 @@ class AgentDB:
|
|||||||
is_last: bool | None = None,
|
is_last: bool | None = None,
|
||||||
retry_index: int | None = None,
|
retry_index: int | None = None,
|
||||||
organization_id: str | None = None,
|
organization_id: str | None = None,
|
||||||
chat_completion_price: ChatCompletionPrice | None = None,
|
incremental_cost: float | None = None,
|
||||||
) -> Step:
|
) -> Step:
|
||||||
try:
|
try:
|
||||||
with self.Session() as session:
|
with self.Session() as session:
|
||||||
@@ -283,18 +282,8 @@ class AgentDB:
|
|||||||
step.is_last = is_last
|
step.is_last = is_last
|
||||||
if retry_index is not None:
|
if retry_index is not None:
|
||||||
step.retry_index = retry_index
|
step.retry_index = retry_index
|
||||||
if chat_completion_price is not None:
|
if incremental_cost is not None:
|
||||||
if step.input_token_count is None:
|
step.step_cost = incremental_cost + float(step.step_cost or 0)
|
||||||
step.input_token_count = 0
|
|
||||||
|
|
||||||
if step.output_token_count is None:
|
|
||||||
step.output_token_count = 0
|
|
||||||
|
|
||||||
step.input_token_count += chat_completion_price.input_token_count
|
|
||||||
step.output_token_count += chat_completion_price.output_token_count
|
|
||||||
step.step_cost = chat_completion_price.openai_model_to_price_lambda(
|
|
||||||
step.input_token_count, step.output_token_count
|
|
||||||
)
|
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
updated_step = await self.get_step(task_id, step_id, organization_id)
|
updated_step = await self.get_step(task_id, step_id, organization_id)
|
||||||
|
|||||||
@@ -573,9 +573,9 @@ async def extract_information_for_navigation_goal(
|
|||||||
error_code_mapping_str=json.dumps(task.error_code_mapping) if task.error_code_mapping else None,
|
error_code_mapping_str=json.dumps(task.error_code_mapping) if task.error_code_mapping else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
json_response = await app.OPENAI_CLIENT.chat_completion(
|
json_response = await app.LLM_API_HANDLER(
|
||||||
step=step,
|
|
||||||
prompt=extract_information_prompt,
|
prompt=extract_information_prompt,
|
||||||
|
step=step,
|
||||||
screenshots=scraped_page.screenshots,
|
screenshots=scraped_page.screenshots,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user