Implement LLM router (#95)
This commit is contained in:
@@ -1,27 +0,0 @@
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Pricing table keyed by OpenAI model name.
# Each value is an (input_price, output_price) pair per 1,000 tokens —
# ChatCompletionPrice divides token counts by 1000 when applying these rates.
openai_model_to_price_lambdas = {
    "gpt-4-vision-preview": (0.01, 0.03),
    "gpt-4-1106-preview": (0.01, 0.03),
    "gpt-4-0125-preview": (0.01, 0.03),
    "gpt-3.5-turbo": (0.001, 0.002),
    "gpt-3.5-turbo-1106": (0.001, 0.002),
    "gpt-3.5-turbo-0125": (0.0005, 0.0015),
}
class ChatCompletionPrice(BaseModel):
    """Token usage for one chat completion plus a callable that prices it.

    The stored callable maps (input_tokens, output_tokens) to a cost using
    the per-1K-token rates registered for ``model_name`` in
    ``openai_model_to_price_lambdas``.

    Raises:
        KeyError: if ``model_name`` has no entry in the pricing table.
    """

    input_token_count: int
    output_token_count: int
    openai_model_to_price_lambda: Callable[[int, int], float]

    def __init__(self, input_token_count: int, output_token_count: int, model_name: str):
        per_1k_input, per_1k_output = openai_model_to_price_lambdas[model_name]

        def price(input_token: int, output_token: int) -> float:
            # Rates are per 1,000 tokens, hence the division.
            return per_1k_input * input_token / 1000 + per_1k_output * output_token / 1000

        super().__init__(
            input_token_count=input_token_count,
            output_token_count=output_token_count,
            openai_model_to_price_lambda=price,
        )
0
skyvern/forge/sdk/api/llm/__init__.py
Normal file
0
skyvern/forge/sdk/api/llm/__init__.py
Normal file
115
skyvern/forge/sdk/api/llm/api_handler_factory.py
Normal file
115
skyvern/forge/sdk/api/llm/api_handler_factory.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
import openai
|
||||
import structlog
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
|
||||
from skyvern.forge.sdk.api.llm.exceptions import DuplicateCustomLLMProviderError, LLMProviderError
|
||||
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
|
||||
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LLMAPIHandlerFactory:
    """Builds async LLM handler callables for registered llm_key configs."""

    # Custom handlers registered at runtime via register_custom_handler, keyed by llm_key.
    _custom_handlers: dict[str, LLMAPIHandler] = {}

    @staticmethod
    def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
        """Build an async handler bound to the config registered for ``llm_key``.

        The returned coroutine function sends a prompt (plus optional
        screenshots) through litellm, records prompt/request/response
        artifacts and the incurred cost on the given step (when one is
        provided), and returns the parsed JSON response.

        Raises:
            InvalidLLMConfigError: if ``llm_key`` has no registered config.
        """
        llm_config = LLMConfigRegistry.get_config(llm_key)

        async def llm_api_handler(
            prompt: str,
            step: Step | None = None,
            screenshots: list[bytes] | None = None,
            parameters: dict[str, Any] | None = None,
        ) -> dict[str, Any]:
            if parameters is None:
                parameters = LLMAPIHandlerFactory.get_api_parameters()

            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_PROMPT,
                    data=prompt.encode("utf-8"),
                )
                for screenshot in screenshots or []:
                    await app.ARTIFACT_MANAGER.create_artifact(
                        step=step,
                        artifact_type=ArtifactType.SCREENSHOT_LLM,
                        data=screenshot,
                    )

            # TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
            if not llm_config.supports_vision:
                screenshots = None

            messages = await llm_messages_builder(prompt, screenshots)
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_REQUEST,
                    data=json.dumps(
                        {
                            "model": llm_config.model_name,
                            "messages": messages,
                            **parameters,
                        }
                    ).encode("utf-8"),
                )
            try:
                # TODO (kerem): add a timeout to this call
                # TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
                # TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
                response = await litellm.acompletion(
                    model=llm_config.model_name,
                    messages=messages,
                    **parameters,
                )
            except openai.OpenAIError as e:
                raise LLMProviderError(llm_key) from e
            except Exception as e:
                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
                raise LLMProviderError(llm_key) from e
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_RESPONSE,
                    data=response.model_dump_json(indent=2).encode("utf-8"),
                )
                # Bug fix: the cost update dereferences step attributes, so it
                # must only run when a step was provided; otherwise handlers
                # called without a step crash with AttributeError on None.
                llm_cost = litellm.completion_cost(completion_response=response)
                await app.DATABASE.update_step(
                    task_id=step.task_id,
                    step_id=step.step_id,
                    organization_id=step.organization_id,
                    incremental_cost=llm_cost,
                )
            parsed_response = parse_api_response(response)
            if step:
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
                )
            return parsed_response

        return llm_api_handler

    @staticmethod
    def get_api_parameters() -> dict[str, Any]:
        """Default completion parameters sourced from application settings."""
        return {
            "max_tokens": SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS,
            "temperature": SettingsManager.get_settings().LLM_CONFIG_TEMPERATURE,
        }

    @classmethod
    def register_custom_handler(cls, llm_key: str, handler: LLMAPIHandler) -> None:
        """Register a custom handler under ``llm_key``; duplicates are rejected."""
        if llm_key in cls._custom_handlers:
            raise DuplicateCustomLLMProviderError(llm_key)
        cls._custom_handlers[llm_key] = handler
||||
70
skyvern/forge/sdk/api/llm/config_registry.py
Normal file
70
skyvern/forge/sdk/api/llm/config_registry.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import structlog
|
||||
|
||||
from skyvern.forge.sdk.api.llm.exceptions import (
|
||||
DuplicateLLMConfigError,
|
||||
InvalidLLMConfigError,
|
||||
MissingLLMProviderEnvVarsError,
|
||||
NoProviderEnabledError,
|
||||
)
|
||||
from skyvern.forge.sdk.api.llm.models import LLMConfig
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LLMConfigRegistry:
    """Process-wide registry that maps llm_key strings to LLMConfig entries."""

    _configs: dict[str, LLMConfig] = {}

    @staticmethod
    def validate_config(llm_key: str, config: LLMConfig) -> None:
        """Raise MissingLLMProviderEnvVarsError if any required env var is unset."""
        missing = config.get_missing_env_vars()
        if missing:
            raise MissingLLMProviderEnvVarsError(llm_key, missing)

    @classmethod
    def register_config(cls, llm_key: str, config: LLMConfig) -> None:
        """Validate and store ``config`` under ``llm_key``.

        Raises:
            DuplicateLLMConfigError: if the key is already registered.
            MissingLLMProviderEnvVarsError: if required env vars are missing.
        """
        if llm_key in cls._configs:
            raise DuplicateLLMConfigError(llm_key)

        cls.validate_config(llm_key, config)

        LOG.info("Registering LLM config", llm_key=llm_key)
        cls._configs[llm_key] = config

    @classmethod
    def get_config(cls, llm_key: str) -> LLMConfig:
        """Return the config registered for ``llm_key``.

        Raises:
            InvalidLLMConfigError: if no config is registered under the key.
        """
        if llm_key not in cls._configs:
            raise InvalidLLMConfigError(llm_key)

        return cls._configs[llm_key]
||||
|
||||
|
||||
# if none of the LLM providers are enabled, raise an error
if not any(
    [
        SettingsManager.get_settings().ENABLE_OPENAI,
        SettingsManager.get_settings().ENABLE_ANTHROPIC,
        SettingsManager.get_settings().ENABLE_AZURE,
    ]
):
    raise NoProviderEnabledError()


# Import-time registration of the built-in configs for each enabled provider.
# LLMConfig positional args: (model_name, required_env_vars, supports_vision).
if SettingsManager.get_settings().ENABLE_OPENAI:
    LLMConfigRegistry.register_config("OPENAI_GPT4_TURBO", LLMConfig("gpt-4-turbo-preview", ["OPENAI_API_KEY"], False))
    LLMConfigRegistry.register_config("OPENAI_GPT4V", LLMConfig("gpt-4-vision-preview", ["OPENAI_API_KEY"], True))

if SettingsManager.get_settings().ENABLE_ANTHROPIC:
    LLMConfigRegistry.register_config(
        "ANTHROPIC_CLAUDE3", LLMConfig("anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], True)
    )

if SettingsManager.get_settings().ENABLE_AZURE:
    LLMConfigRegistry.register_config(
        "AZURE_OPENAI_GPT4V",
        LLMConfig(
            # Azure models are addressed by deployment name, not model id.
            f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}",
            ["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
            True,
        ),
    )
||||
48
skyvern/forge/sdk/api/llm/exceptions.py
Normal file
48
skyvern/forge/sdk/api/llm/exceptions.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from skyvern.exceptions import SkyvernException
|
||||
|
||||
|
||||
class BaseLLMError(SkyvernException):
    """Base class for all LLM-related errors in this package."""

    pass


class MissingLLMProviderEnvVarsError(BaseLLMError):
    """Raised when a provider config's required environment variables are unset."""

    def __init__(self, llm_key: str, missing_env_vars: list[str]) -> None:
        super().__init__(f"Environment variables {','.join(missing_env_vars)} are required for LLMProvider {llm_key}")


class EmptyLLMResponseError(BaseLLMError):
    """Raised when the LLM returned a response with no message content."""

    def __init__(self, response: str) -> None:
        super().__init__(f"LLM response content is empty: {response}")


class InvalidLLMResponseFormat(BaseLLMError):
    """Raised when the LLM response content cannot be parsed as JSON."""

    def __init__(self, response: str) -> None:
        super().__init__(f"LLM response content is not a valid JSON: {response}")


class DuplicateCustomLLMProviderError(BaseLLMError):
    """Raised when registering a custom handler under an llm_key already in use."""

    def __init__(self, llm_key: str) -> None:
        super().__init__(f"Custom LLMProvider {llm_key} is already registered")


class DuplicateLLMConfigError(BaseLLMError):
    """Raised when registering an LLM config under an llm_key already in use."""

    def __init__(self, llm_key: str) -> None:
        super().__init__(f"LLM config with key {llm_key} is already registered")


class InvalidLLMConfigError(BaseLLMError):
    """Raised when looking up an llm_key that has no registered config."""

    def __init__(self, llm_key: str) -> None:
        super().__init__(f"LLM config with key {llm_key} is not a valid LLMConfig")


class LLMProviderError(BaseLLMError):
    """Raised when the underlying provider call fails; wraps the original error."""

    def __init__(self, llm_key: str) -> None:
        super().__init__(f"Error while using LLMProvider {llm_key}")


class NoProviderEnabledError(BaseLLMError):
    """Raised at import time when no LLM provider is enabled in settings."""

    def __init__(self) -> None:
        super().__init__(
            "At least one LLM provider must be enabled. Run setup.sh and follow through the LLM provider setup, or "
            "update the .env file (check out .env.example to see the required environment variables)."
        )
||||
32
skyvern/forge/sdk/api/llm/models.py
Normal file
32
skyvern/forge/sdk/api/llm/models.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Protocol
|
||||
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMConfig:
    """Immutable description of one LLM provider configuration.

    Attributes:
        model_name: litellm model identifier to send requests to.
        required_env_vars: settings attributes that must be non-empty.
        supports_vision: whether the model accepts image inputs.
    """

    model_name: str
    required_env_vars: list[str]
    supports_vision: bool

    def get_missing_env_vars(self) -> list[str]:
        """Return the required env vars that are unset or empty on settings."""
        return [
            env_var
            for env_var in self.required_env_vars
            if not getattr(SettingsManager.get_settings(), env_var, None)
        ]
||||
|
||||
|
||||
class LLMAPIHandler(Protocol):
    """Structural type for the async handler callables built by
    LLMAPIHandlerFactory: they take a prompt plus optional step, screenshots,
    and completion parameters, and resolve to the parsed JSON response."""

    def __call__(
        self,
        prompt: str,
        step: Step | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
    ) -> Awaitable[dict[str, Any]]:
        ...
||||
45
skyvern/forge/sdk/api/llm/utils.py
Normal file
45
skyvern/forge/sdk/api/llm/utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import base64
|
||||
from typing import Any
|
||||
|
||||
import commentjson
|
||||
import litellm
|
||||
|
||||
from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidLLMResponseFormat
|
||||
|
||||
|
||||
async def llm_messages_builder(
|
||||
prompt: str,
|
||||
screenshots: list[bytes] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt,
|
||||
}
|
||||
]
|
||||
|
||||
if screenshots:
|
||||
for screenshot in screenshots:
|
||||
encoded_image = base64.b64encode(screenshot).decode("utf-8")
|
||||
messages.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{encoded_image}",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return [{"role": "user", "content": messages}]
|
||||
|
||||
|
||||
def parse_api_response(response: litellm.ModelResponse) -> dict[str, str]:
    """Extract the first choice's message content and parse it as JSON.

    Markdown code fences (```json ... ```) around the payload are stripped
    before parsing; commentjson tolerates comments in the JSON body.

    Raises:
        EmptyLLMResponseError: if the message content is missing or empty.
        InvalidLLMResponseFormat: if the content cannot be parsed as JSON.
    """
    try:
        content = response.choices[0].message.content
    except Exception as e:
        raise InvalidLLMResponseFormat(str(response)) from e

    # Bug fix: the emptiness check used to sit inside the try block after the
    # .replace() calls, so a None content raised AttributeError and an empty
    # content's EmptyLLMResponseError was immediately swallowed and re-wrapped
    # as InvalidLLMResponseFormat. Check emptiness outside the try instead.
    if not content:
        raise EmptyLLMResponseError(str(response))

    content = content.replace("```json", "")
    content = content.replace("```", "")
    if not content:
        raise EmptyLLMResponseError(str(response))

    try:
        return commentjson.loads(content)
    except Exception as e:
        raise InvalidLLMResponseFormat(str(response)) from e
||||
@@ -1,228 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import commentjson
|
||||
import openai
|
||||
import structlog
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from skyvern.exceptions import InvalidOpenAIResponseFormat, NoAvailableOpenAIClients, OpenAIRequestTooBigError
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.api.chat_completion_price import ChatCompletionPrice
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class OpenAIKeyClientWrapper:
    """Pairs an AsyncOpenAI client with its API key and rate-limit bookkeeping."""

    client: AsyncOpenAI
    key: str
    remaining_requests: int | None

    def __init__(self, key: str, remaining_requests: int | None) -> None:
        self.key = key
        self.remaining_requests = remaining_requests
        self.updated_at = datetime.utcnow()
        self.client = AsyncOpenAI(api_key=self.key)

    def update_remaining_requests(self, remaining_requests: int | None) -> None:
        """Record the latest remaining-request count and refresh the timestamp."""
        self.remaining_requests = remaining_requests
        self.updated_at = datetime.utcnow()

    def is_available(self) -> bool:
        """Whether this key should be tried for the next request.

        A key counts as available when it has never been checked
        (remaining_requests is None), still has quota left, or was last
        marked exhausted more than a minute ago — most failures are
        tokens-per-minute limits, so a stale exhaustion reading is retried.
        """
        # Never checked: optimistically assume it works.
        if self.remaining_requests is None:
            return True
        if self.remaining_requests > 0:
            return True
        # Exhausted, but the reading is over a minute old: try again.
        return self.updated_at < (datetime.utcnow() - timedelta(minutes=1))
||||
|
||||
|
||||
class OpenAIClientManager:
    """Distributes chat-completion requests across a pool of OpenAI API keys.

    Each key is wrapped in an OpenAIKeyClientWrapper that tracks remaining
    request quota; requests are routed to a randomly chosen available key,
    and rate-limited keys are retried via a different key.
    """

    # TODO Support other models for requests without screenshots, track rate limits for each model and key as well if any
    clients: list[OpenAIKeyClientWrapper]

    def __init__(self, api_keys: list[str] | None = None) -> None:
        """Build one client wrapper per API key.

        Bug fix: the keys were previously read from settings in the default
        argument, which is evaluated once at import time; resolve them at
        construction time instead. Passing an explicit list is unchanged.
        """
        if api_keys is None:
            api_keys = SettingsManager.get_settings().OPENAI_API_KEYS
        self.clients = [OpenAIKeyClientWrapper(key, None) for key in api_keys]

    def get_available_client(self) -> OpenAIKeyClientWrapper | None:
        """Return a random available key wrapper, or None if all are exhausted."""
        available_clients = [client for client in self.clients if client.is_available()]

        if not available_clients:
            return None

        # Randomly select an available client to distribute requests across our accounts
        return random.choice(available_clients)

    async def content_builder(
        self,
        step: Step,
        screenshots: list[bytes] | None = None,
        prompt: str | None = None,
    ) -> list[dict[str, Any]]:
        """Build the message content parts (text + base64 images) and record
        a prompt artifact plus one artifact per screenshot on the step."""
        content: list[dict[str, Any]] = []

        if prompt is not None:
            content.append(
                {
                    "type": "text",
                    "text": prompt,
                }
            )

            await app.ARTIFACT_MANAGER.create_artifact(
                step=step,
                artifact_type=ArtifactType.LLM_PROMPT,
                data=prompt.encode("utf-8"),
            )
        if screenshots:
            for screenshot in screenshots:
                encoded_image = base64.b64encode(screenshot).decode("utf-8")
                content.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{encoded_image}",
                        },
                    }
                )
                # create artifact for each image
                await app.ARTIFACT_MANAGER.create_artifact(
                    step=step,
                    artifact_type=ArtifactType.SCREENSHOT_LLM,
                    data=screenshot,
                )

        return content

    async def chat_completion(
        self,
        step: Step,
        model: str = "gpt-4-vision-preview",
        max_tokens: int = 4096,
        temperature: int = 0,
        screenshots: list[bytes] | None = None,
        prompt: str | None = None,
    ) -> dict[str, Any]:
        """Send a chat completion for this step, recording request/response
        artifacts and per-step pricing; retries with another key on 429."""
        LOG.info(
            "Sending LLM request",
            task_id=step.task_id,
            step_id=step.step_id,
            num_screenshots=len(screenshots) if screenshots else 0,
        )
        messages = [
            {
                "role": "user",
                "content": await self.content_builder(
                    step=step,
                    screenshots=screenshots,
                    prompt=prompt,
                ),
            }
        ]

        chat_completion_kwargs = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_REQUEST,
            data=json.dumps(chat_completion_kwargs).encode("utf-8"),
        )
        available_client = self.get_available_client()
        if available_client is None:
            raise NoAvailableOpenAIClients()
        try:
            response = await available_client.client.chat.completions.with_raw_response.create(**chat_completion_kwargs)
        except openai.RateLimitError as e:
            # NOTE(review): RateLimitError is always an HTTP 429 — confirm that
            # e.code == 429 actually distinguishes a too-big request here.
            if e.code == 429:
                raise OpenAIRequestTooBigError(e.message)
            # If we get a RateLimitError, we can assume the key is not available anymore
            LOG.warning(
                "OpenAI rate limit exceeded, marking key as unavailable.", error_code=e.code, error_message=e.message
            )
            available_client.update_remaining_requests(remaining_requests=0)
            available_client = self.get_available_client()
            if available_client is None:
                raise NoAvailableOpenAIClients()
            # Retry the whole call; a different key will be chosen above.
            return await self.chat_completion(
                step=step,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                screenshots=screenshots,
                prompt=prompt,
            )
        except openai.OpenAIError as e:
            LOG.error("OpenAI error", exc_info=True)
            raise e
        except Exception as e:
            LOG.error("Unknown error for chat completion", error_message=str(e), error_type=type(e))
            raise e

        # TODO: https://platform.openai.com/docs/guides/rate-limits/rate-limits-in-headers
        # use other headers, x-ratelimit-limit-requests, x-ratelimit-limit-tokens, x-ratelimit-remaining-tokens
        # x-ratelimit-reset-requests, x-ratelimit-reset-tokens to write a more accurate algorithm for managing api keys

        # If we get a response, we can assume the key is available and update the remaining requests
        ratelimit_remaining_requests = response.headers.get("x-ratelimit-remaining-requests")

        if not ratelimit_remaining_requests:
            # Bug fix: the headers were passed positionally to structlog (a
            # %-formatting misuse), and int(None) below crashed when the
            # header was missing; skip the update in that case.
            LOG.warning("Invalid x-ratelimit-remaining-requests from OpenAI", headers=response.headers)
        else:
            available_client.update_remaining_requests(remaining_requests=int(ratelimit_remaining_requests))
        chat_completion = response.parse()

        if chat_completion.usage is not None:
            # TODO (Suchintan): Is this bad design?
            step = await app.DATABASE.update_step(
                step_id=step.step_id,
                task_id=step.task_id,
                organization_id=step.organization_id,
                chat_completion_price=ChatCompletionPrice(
                    input_token_count=chat_completion.usage.prompt_tokens,
                    output_token_count=chat_completion.usage.completion_tokens,
                    model_name=model,
                ),
            )
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_RESPONSE,
            data=chat_completion.model_dump_json(indent=2).encode("utf-8"),
        )
        parsed_response = self.parse_response(chat_completion)
        await app.ARTIFACT_MANAGER.create_artifact(
            step=step,
            artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
            data=json.dumps(parsed_response, indent=2).encode("utf-8"),
        )
        return parsed_response

    def parse_response(self, response: ChatCompletion) -> dict[str, str]:
        """Strip markdown code fences from the response content and parse it as JSON."""
        try:
            content = response.choices[0].message.content
            content = content.replace("```json", "")
            content = content.replace("```", "")
            if not content:
                raise Exception("openai response content is empty")
            return commentjson.loads(content)
        except Exception as e:
            raise InvalidOpenAIResponseFormat(str(response)) from e
||||
Reference in New Issue
Block a user