Define haiku + prompt engine takes a directory arg (#279)

This commit is contained in:
Kerem Yilmaz
2024-05-08 02:07:18 -07:00
committed by GitHub
parent 42d652f381
commit e5d094493e
2 changed files with 16 additions and 2 deletions

View File

@@ -21,11 +21,14 @@ class PromptEngine:
import glob
import os
from difflib import get_close_matches
from pathlib import Path
from typing import Any, List
import structlog
from jinja2 import Environment, FileSystemLoader
from skyvern.constants import SKYVERN_DIR
LOG = structlog.get_logger()
@@ -34,7 +37,7 @@ class PromptEngine:
Class to handle loading and populating Jinja2 templates for prompts.
"""
def __init__(self, model: str):
def __init__(self, model: str, prompts_dir: Path = SKYVERN_DIR / "forge" / "prompts") -> None:
"""
Initialize the PromptEngine with the specified model.
@@ -45,7 +48,7 @@ class PromptEngine:
try:
# Get the list of all model directories
models_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../prompts"))
models_dir = os.path.abspath(prompts_dir)
model_names = [
os.path.basename(os.path.normpath(d))
for d in glob.glob(os.path.join(models_dir, "*/"))