use cached prompt generation (#768)
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
import hashlib
|
||||
import uuid
|
||||
from typing import Annotated, Any
|
||||
|
||||
@@ -54,6 +55,7 @@ from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
|
||||
base_router = APIRouter()
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
PROMPT_CACHE_WINDOW_HOURS = 24
|
||||
|
||||
|
||||
@base_router.post("/webhook", tags=["server"])
|
||||
@@ -766,6 +768,32 @@ async def generate_task(
|
||||
data: GenerateTaskRequest,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> TaskGeneration:
|
||||
user_prompt = data.prompt
|
||||
hash_object = hashlib.sha256()
|
||||
hash_object.update(user_prompt.encode("utf-8"))
|
||||
user_prompt_hash = hash_object.hexdigest()
|
||||
# check if there's a same user_prompt within the past x Hours
|
||||
# in the future, we can use vector db to fetch similar prompts
|
||||
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
|
||||
user_prompt_hash=user_prompt_hash, query_window_hours=PROMPT_CACHE_WINDOW_HOURS
|
||||
)
|
||||
if existing_task_generation:
|
||||
new_task_generation = await app.DATABASE.create_task_generation(
|
||||
organization_id=current_org.organization_id,
|
||||
user_prompt=data.prompt,
|
||||
user_prompt_hash=user_prompt_hash,
|
||||
url=existing_task_generation.url,
|
||||
navigation_goal=existing_task_generation.navigation_goal,
|
||||
navigation_payload=existing_task_generation.navigation_payload,
|
||||
data_extraction_goal=existing_task_generation.data_extraction_goal,
|
||||
extracted_information_schema=existing_task_generation.extracted_information_schema,
|
||||
llm=existing_task_generation.llm,
|
||||
llm_prompt=existing_task_generation.llm_prompt,
|
||||
llm_response=existing_task_generation.llm_response,
|
||||
source_task_generation_id=existing_task_generation.task_generation_id,
|
||||
)
|
||||
return new_task_generation
|
||||
|
||||
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
|
||||
try:
|
||||
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt)
|
||||
@@ -775,6 +803,7 @@ async def generate_task(
|
||||
task_generation = await app.DATABASE.create_task_generation(
|
||||
organization_id=current_org.organization_id,
|
||||
user_prompt=data.prompt,
|
||||
user_prompt_hash=user_prompt_hash,
|
||||
url=parsed_task_generation_obj.url,
|
||||
navigation_goal=parsed_task_generation_obj.navigation_goal,
|
||||
navigation_payload=parsed_task_generation_obj.navigation_payload,
|
||||
|
||||
Reference in New Issue
Block a user