Define haiku + prompt engine takes a directory arg (#279)
This commit is contained in:
@@ -21,11 +21,14 @@ class PromptEngine:
|
||||
import glob
|
||||
import os
|
||||
from difflib import get_close_matches
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
import structlog
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
from skyvern.constants import SKYVERN_DIR
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
@@ -34,7 +37,7 @@ class PromptEngine:
|
||||
Class to handle loading and populating Jinja2 templates for prompts.
|
||||
"""
|
||||
|
||||
def __init__(self, model: str):
|
||||
def __init__(self, model: str, prompts_dir: Path = SKYVERN_DIR / "forge" / "prompts") -> None:
|
||||
"""
|
||||
Initialize the PromptEngine with the specified model.
|
||||
|
||||
@@ -45,7 +48,7 @@ class PromptEngine:
|
||||
|
||||
try:
|
||||
# Get the list of all model directories
|
||||
models_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../prompts"))
|
||||
models_dir = os.path.abspath(prompts_dir)
|
||||
model_names = [
|
||||
os.path.basename(os.path.normpath(d))
|
||||
for d in glob.glob(os.path.join(models_dir, "*/"))
|
||||
|
||||
Reference in New Issue
Block a user