use cached prompt generation (#768)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz
2024-09-03 07:00:15 +03:00
committed by GitHub
parent 2097d01471
commit 0d39e62df6
6 changed files with 111 additions and 9 deletions

View File

@@ -0,0 +1,46 @@
"""update task_generation table - use user_prompt_hash as the index of a user prompt
Revision ID: 0de9150bc624
Revises: 6de11b2be7c8
Create Date: 2024-09-03 03:56:58.352307+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0de9150bc624"
down_revision: Union[str, None] = "6de11b2be7c8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add prompt-hash cache columns and re-index task_generations.

    Introduces `user_prompt_hash` (cache key for prompt lookups) and
    `source_task_generation_id` (link back to the generation whose result
    was reused), drops the index on the raw prompt text, and indexes the
    two new columns instead.
    """
    table = "task_generations"
    # Both new columns are nullable so existing rows remain valid.
    for column_name in ("user_prompt_hash", "source_task_generation_id"):
        op.add_column(table, sa.Column(column_name, sa.String(), nullable=True))
    # The raw prompt is no longer indexed; lookups go through the hash.
    op.drop_index("ix_task_generations_user_prompt", table_name=table)
    for column_name in ("source_task_generation_id", "user_prompt_hash"):
        op.create_index(op.f(f"ix_task_generations_{column_name}"), table, [column_name], unique=False)
def downgrade() -> None:
    """Revert the prompt-hash cache schema changes.

    Drops the two cache indexes, restores the original index on the raw
    prompt text, and removes the columns added by upgrade().
    """
    table = "task_generations"
    for column_name in ("user_prompt_hash", "source_task_generation_id"):
        op.drop_index(op.f(f"ix_task_generations_{column_name}"), table_name=table)
    # Restore the pre-migration index on the raw prompt text.
    op.create_index("ix_task_generations_user_prompt", table, ["user_prompt"], unique=False)
    for column_name in ("source_task_generation_id", "user_prompt_hash"):
        op.drop_column(table, column_name)

View File

@@ -77,6 +77,9 @@ class Settings(BaseSettings):
BITWARDEN_TIMEOUT_SECONDS: int = 60 BITWARDEN_TIMEOUT_SECONDS: int = 60
BITWARDEN_MAX_RETRIES: int = 1 BITWARDEN_MAX_RETRIES: int = 1
# task generation settings
PROMPT_CACHE_WINDOW_HOURS: int = 24
##################### #####################
# LLM Configuration # # LLM Configuration #
##################### #####################

View File

@@ -1,4 +1,4 @@
from datetime import datetime from datetime import datetime, timedelta
from typing import Any, Sequence from typing import Any, Sequence
import structlog import structlog
@@ -6,6 +6,7 @@ from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from skyvern.config import settings
from skyvern.exceptions import WorkflowParameterNotFound from skyvern.exceptions import WorkflowParameterNotFound
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
@@ -1386,6 +1387,7 @@ class AgentDB:
self, self,
organization_id: str, organization_id: str,
user_prompt: str, user_prompt: str,
user_prompt_hash: str,
url: str | None = None, url: str | None = None,
navigation_goal: str | None = None, navigation_goal: str | None = None,
navigation_payload: dict[str, Any] | None = None, navigation_payload: dict[str, Any] | None = None,
@@ -1395,11 +1397,13 @@ class AgentDB:
llm: str | None = None, llm: str | None = None,
llm_prompt: str | None = None, llm_prompt: str | None = None,
llm_response: str | None = None, llm_response: str | None = None,
source_task_generation_id: str | None = None,
) -> TaskGeneration: ) -> TaskGeneration:
async with self.Session() as session: async with self.Session() as session:
new_task_generation = TaskGenerationModel( new_task_generation = TaskGenerationModel(
organization_id=organization_id, organization_id=organization_id,
user_prompt=user_prompt, user_prompt=user_prompt,
user_prompt_hash=user_prompt_hash,
url=url, url=url,
navigation_goal=navigation_goal, navigation_goal=navigation_goal,
navigation_payload=navigation_payload, navigation_payload=navigation_payload,
@@ -1409,8 +1413,27 @@ class AgentDB:
llm_prompt=llm_prompt, llm_prompt=llm_prompt,
llm_response=llm_response, llm_response=llm_response,
suggested_title=suggested_title, suggested_title=suggested_title,
source_task_generation_id=source_task_generation_id,
) )
session.add(new_task_generation) session.add(new_task_generation)
await session.commit() await session.commit()
await session.refresh(new_task_generation) await session.refresh(new_task_generation)
return TaskGeneration.model_validate(new_task_generation) return TaskGeneration.model_validate(new_task_generation)
async def get_task_generation_by_prompt_hash(
    self,
    user_prompt_hash: str,
    query_window_hours: int = settings.PROMPT_CACHE_WINDOW_HOURS,
) -> TaskGeneration | None:
    """Return a recent, completed task generation matching the prompt hash.

    Looks up a TaskGenerationModel row whose `user_prompt_hash` equals
    *user_prompt_hash*, that was created within the last
    *query_window_hours* hours, and that actually produced an LLM result
    (`llm` is not NULL). Returns None when no such row exists.

    Fix: the default window now reads ``settings.PROMPT_CACHE_WINDOW_HOURS``
    — the setting introduced for this cache — instead of the unrelated
    ``PROMPT_ACTION_HISTORY_WINDOW``.

    NOTE(review): the lookup is not scoped by organization_id, so a cached
    generation from one organization could be served to another — confirm
    cross-tenant sharing is intended.
    """
    # Only rows newer than this cutoff are cache candidates.
    cutoff = datetime.utcnow() - timedelta(hours=query_window_hours)
    async with self.Session() as session:
        query = (
            select(TaskGenerationModel)
            .filter_by(user_prompt_hash=user_prompt_hash)
            .filter(TaskGenerationModel.llm.is_not(None))
            .filter(TaskGenerationModel.created_at > cutoff)
            # Prefer the most recent hit; without an ORDER BY the row
            # returned by .first() is database-dependent.
            .order_by(TaskGenerationModel.created_at.desc())
        )
        task_generation = (await session.scalars(query)).first()
        if not task_generation:
            return None
        return TaskGeneration.model_validate(task_generation)

View File

@@ -374,7 +374,8 @@ class TaskGenerationModel(Base):
task_generation_id = Column(String, primary_key=True, default=generate_task_generation_id) task_generation_id = Column(String, primary_key=True, default=generate_task_generation_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=False) organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=False)
user_prompt = Column(String, nullable=False, index=True) # The prompt from the user user_prompt = Column(String, nullable=False)
user_prompt_hash = Column(String, index=True)
url = Column(String) url = Column(String)
navigation_goal = Column(String) navigation_goal = Column(String)
navigation_payload = Column(JSON) navigation_payload = Column(JSON)
@@ -386,5 +387,7 @@ class TaskGenerationModel(Base):
llm_prompt = Column(String) # The prompt sent to the language model llm_prompt = Column(String) # The prompt sent to the language model
llm_response = Column(String) # The response from the language model llm_response = Column(String) # The response from the language model
source_task_generation_id = Column(String, index=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)

View File

@@ -1,4 +1,5 @@
import datetime import datetime
import hashlib
import uuid import uuid
from typing import Annotated, Any from typing import Annotated, Any
@@ -54,6 +55,7 @@ from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
base_router = APIRouter() base_router = APIRouter()
LOG = structlog.get_logger() LOG = structlog.get_logger()
PROMPT_CACHE_WINDOW_HOURS = 24
@base_router.post("/webhook", tags=["server"]) @base_router.post("/webhook", tags=["server"])
@@ -766,6 +768,32 @@ async def generate_task(
data: GenerateTaskRequest, data: GenerateTaskRequest,
current_org: Organization = Depends(org_auth_service.get_current_org), current_org: Organization = Depends(org_auth_service.get_current_org),
) -> TaskGeneration: ) -> TaskGeneration:
# Cache lookup for task generation: an identical prompt seen within the
# cache window reuses the previously generated task parameters instead of
# calling the LLM again.
user_prompt = data.prompt
# SHA-256 of the raw prompt is the cache key (matches the
# user_prompt_hash index on task_generations).
hash_object = hashlib.sha256()
hash_object.update(user_prompt.encode("utf-8"))
user_prompt_hash = hash_object.hexdigest()
# check if there's a same user_prompt within the past x Hours
# in the future, we can use vector db to fetch similar prompts
# NOTE(review): PROMPT_CACHE_WINDOW_HOURS here is a module-level constant
# (24) that duplicates settings.PROMPT_CACHE_WINDOW_HOURS — consider
# reading the setting instead so the window stays configurable.
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
user_prompt_hash=user_prompt_hash, query_window_hours=PROMPT_CACHE_WINDOW_HOURS
)
if existing_task_generation:
# Cache hit: still persist a new row for this request (history/audit),
# copying the cached LLM output and linking back to the source row via
# source_task_generation_id.
new_task_generation = await app.DATABASE.create_task_generation(
organization_id=current_org.organization_id,
user_prompt=data.prompt,
user_prompt_hash=user_prompt_hash,
url=existing_task_generation.url,
navigation_goal=existing_task_generation.navigation_goal,
navigation_payload=existing_task_generation.navigation_payload,
data_extraction_goal=existing_task_generation.data_extraction_goal,
extracted_information_schema=existing_task_generation.extracted_information_schema,
llm=existing_task_generation.llm,
llm_prompt=existing_task_generation.llm_prompt,
llm_response=existing_task_generation.llm_response,
source_task_generation_id=existing_task_generation.task_generation_id,
)
return new_task_generation
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt) llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
try: try:
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt) llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt)
@@ -775,6 +803,7 @@ async def generate_task(
task_generation = await app.DATABASE.create_task_generation( task_generation = await app.DATABASE.create_task_generation(
organization_id=current_org.organization_id, organization_id=current_org.organization_id,
user_prompt=data.prompt, user_prompt=data.prompt,
user_prompt_hash=user_prompt_hash,
url=parsed_task_generation_obj.url, url=parsed_task_generation_obj.url,
navigation_goal=parsed_task_generation_obj.navigation_goal, navigation_goal=parsed_task_generation_obj.navigation_goal,
navigation_payload=parsed_task_generation_obj.navigation_payload, navigation_payload=parsed_task_generation_obj.navigation_payload,

View File

@@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field
class TaskGenerationBase(BaseModel): class TaskGenerationBase(BaseModel):
@@ -9,6 +9,7 @@ class TaskGenerationBase(BaseModel):
organization_id: str | None = None organization_id: str | None = None
user_prompt: str | None = None user_prompt: str | None = None
user_prompt_hash: str | None = None
url: str | None = None url: str | None = None
navigation_goal: str | None = None navigation_goal: str | None = None
navigation_payload: dict[str, Any] | None = None navigation_payload: dict[str, Any] | None = None
@@ -20,19 +21,16 @@ class TaskGenerationBase(BaseModel):
suggested_title: str | None = None suggested_title: str | None = None
class TaskGenerationCreate(TaskGenerationBase):
organization_id: str
user_prompt: str
class TaskGeneration(TaskGenerationBase): class TaskGeneration(TaskGenerationBase):
task_generation_id: str task_generation_id: str
organization_id: str organization_id: str
user_prompt: str user_prompt: str
user_prompt_hash: str
created_at: datetime created_at: datetime
modified_at: datetime modified_at: datetime
class GenerateTaskRequest(BaseModel): class GenerateTaskRequest(BaseModel):
prompt: str # prompt needs to be at least 1 character long
prompt: str = Field(..., min_length=1)