From e5d094493e8421f0f04f1108b6a2df837674a42c Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Wed, 8 May 2024 02:07:18 -0700 Subject: [PATCH] Define haiku + prompt engine takes a directory arg (#279) --- skyvern/forge/sdk/api/llm/config_registry.py | 11 +++++++++++ skyvern/forge/sdk/prompting.py | 7 +++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/skyvern/forge/sdk/api/llm/config_registry.py b/skyvern/forge/sdk/api/llm/config_registry.py index 24ffec42..6c750fed 100644 --- a/skyvern/forge/sdk/api/llm/config_registry.py +++ b/skyvern/forge/sdk/api/llm/config_registry.py @@ -65,6 +65,9 @@ if SettingsManager.get_settings().ENABLE_ANTHROPIC: LLMConfigRegistry.register_config( "ANTHROPIC_CLAUDE3_SONNET", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True) ) + LLMConfigRegistry.register_config( + "ANTHROPIC_CLAUDE3_HAIKU", LLMConfig("anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], True) + ) if SettingsManager.get_settings().ENABLE_BEDROCK: # Supported through AWS IAM authentication @@ -84,6 +87,14 @@ if SettingsManager.get_settings().ENABLE_BEDROCK: True, ), ) + LLMConfigRegistry.register_config( + "BEDROCK_ANTHROPIC_CLAUDE3_HAIKU", + LLMConfig( + "bedrock/anthropic.claude-3-haiku-20240307-v1:0", + ["AWS_REGION"], + True, + ), + ) if SettingsManager.get_settings().ENABLE_AZURE: LLMConfigRegistry.register_config( diff --git a/skyvern/forge/sdk/prompting.py b/skyvern/forge/sdk/prompting.py index 9aa9f947..2d2c3fe0 100644 --- a/skyvern/forge/sdk/prompting.py +++ b/skyvern/forge/sdk/prompting.py @@ -21,11 +21,14 @@ class PromptEngine: import glob import os from difflib import get_close_matches +from pathlib import Path from typing import Any, List import structlog from jinja2 import Environment, FileSystemLoader +from skyvern.constants import SKYVERN_DIR + LOG = structlog.get_logger() @@ -34,7 +37,7 @@ class PromptEngine: Class to handle loading and populating Jinja2 templates for prompts. """ - def __init__(self, model: str): + def __init__(self, model: str, prompts_dir: Path = SKYVERN_DIR / "forge" / "prompts") -> None: """ Initialize the PromptEngine with the specified model. @@ -45,7 +48,7 @@ class PromptEngine: try: # Get the list of all model directories - models_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../prompts")) + models_dir = os.path.abspath(prompts_dir) model_names = [ os.path.basename(os.path.normpath(d)) for d in glob.glob(os.path.join(models_dir, "*/"))