use cached prompt generation (#768)
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -0,0 +1,46 @@
|
|||||||
|
"""update task_generation table - use user_prompt_hash as the index of a user prompt
|
||||||
|
|
||||||
|
Revision ID: 0de9150bc624
|
||||||
|
Revises: 6de11b2be7c8
|
||||||
|
Create Date: 2024-09-03 03:56:58.352307+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0de9150bc624"
|
||||||
|
down_revision: Union[str, None] = "6de11b2be7c8"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column("task_generations", sa.Column("user_prompt_hash", sa.String(), nullable=True))
|
||||||
|
op.add_column("task_generations", sa.Column("source_task_generation_id", sa.String(), nullable=True))
|
||||||
|
op.drop_index("ix_task_generations_user_prompt", table_name="task_generations")
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_task_generations_source_task_generation_id"),
|
||||||
|
"task_generations",
|
||||||
|
["source_task_generation_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_task_generations_user_prompt_hash"), "task_generations", ["user_prompt_hash"], unique=False
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f("ix_task_generations_user_prompt_hash"), table_name="task_generations")
|
||||||
|
op.drop_index(op.f("ix_task_generations_source_task_generation_id"), table_name="task_generations")
|
||||||
|
op.create_index("ix_task_generations_user_prompt", "task_generations", ["user_prompt"], unique=False)
|
||||||
|
op.drop_column("task_generations", "source_task_generation_id")
|
||||||
|
op.drop_column("task_generations", "user_prompt_hash")
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -77,6 +77,9 @@ class Settings(BaseSettings):
|
|||||||
BITWARDEN_TIMEOUT_SECONDS: int = 60
|
BITWARDEN_TIMEOUT_SECONDS: int = 60
|
||||||
BITWARDEN_MAX_RETRIES: int = 1
|
BITWARDEN_MAX_RETRIES: int = 1
|
||||||
|
|
||||||
|
# task generation settings
|
||||||
|
PROMPT_CACHE_WINDOW_HOURS: int = 24
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
# LLM Configuration #
|
# LLM Configuration #
|
||||||
#####################
|
#####################
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Sequence
|
from typing import Any, Sequence
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
@@ -6,6 +6,7 @@ from sqlalchemy import and_, delete, func, select, update
|
|||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from skyvern.config import settings
|
||||||
from skyvern.exceptions import WorkflowParameterNotFound
|
from skyvern.exceptions import WorkflowParameterNotFound
|
||||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||||
@@ -1386,6 +1387,7 @@ class AgentDB:
|
|||||||
self,
|
self,
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
user_prompt: str,
|
user_prompt: str,
|
||||||
|
user_prompt_hash: str,
|
||||||
url: str | None = None,
|
url: str | None = None,
|
||||||
navigation_goal: str | None = None,
|
navigation_goal: str | None = None,
|
||||||
navigation_payload: dict[str, Any] | None = None,
|
navigation_payload: dict[str, Any] | None = None,
|
||||||
@@ -1395,11 +1397,13 @@ class AgentDB:
|
|||||||
llm: str | None = None,
|
llm: str | None = None,
|
||||||
llm_prompt: str | None = None,
|
llm_prompt: str | None = None,
|
||||||
llm_response: str | None = None,
|
llm_response: str | None = None,
|
||||||
|
source_task_generation_id: str | None = None,
|
||||||
) -> TaskGeneration:
|
) -> TaskGeneration:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
new_task_generation = TaskGenerationModel(
|
new_task_generation = TaskGenerationModel(
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
user_prompt=user_prompt,
|
user_prompt=user_prompt,
|
||||||
|
user_prompt_hash=user_prompt_hash,
|
||||||
url=url,
|
url=url,
|
||||||
navigation_goal=navigation_goal,
|
navigation_goal=navigation_goal,
|
||||||
navigation_payload=navigation_payload,
|
navigation_payload=navigation_payload,
|
||||||
@@ -1409,8 +1413,27 @@ class AgentDB:
|
|||||||
llm_prompt=llm_prompt,
|
llm_prompt=llm_prompt,
|
||||||
llm_response=llm_response,
|
llm_response=llm_response,
|
||||||
suggested_title=suggested_title,
|
suggested_title=suggested_title,
|
||||||
|
source_task_generation_id=source_task_generation_id,
|
||||||
)
|
)
|
||||||
session.add(new_task_generation)
|
session.add(new_task_generation)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(new_task_generation)
|
await session.refresh(new_task_generation)
|
||||||
return TaskGeneration.model_validate(new_task_generation)
|
return TaskGeneration.model_validate(new_task_generation)
|
||||||
|
|
||||||
|
async def get_task_generation_by_prompt_hash(
|
||||||
|
self,
|
||||||
|
user_prompt_hash: str,
|
||||||
|
query_window_hours: int = settings.PROMPT_ACTION_HISTORY_WINDOW,
|
||||||
|
) -> TaskGeneration | None:
|
||||||
|
before_time = datetime.utcnow() - timedelta(hours=query_window_hours)
|
||||||
|
async with self.Session() as session:
|
||||||
|
query = (
|
||||||
|
select(TaskGenerationModel)
|
||||||
|
.filter_by(user_prompt_hash=user_prompt_hash)
|
||||||
|
.filter(TaskGenerationModel.llm.is_not(None))
|
||||||
|
.filter(TaskGenerationModel.created_at > before_time)
|
||||||
|
)
|
||||||
|
task_generation = (await session.scalars(query)).first()
|
||||||
|
if not task_generation:
|
||||||
|
return None
|
||||||
|
return TaskGeneration.model_validate(task_generation)
|
||||||
|
|||||||
@@ -374,7 +374,8 @@ class TaskGenerationModel(Base):
|
|||||||
|
|
||||||
task_generation_id = Column(String, primary_key=True, default=generate_task_generation_id)
|
task_generation_id = Column(String, primary_key=True, default=generate_task_generation_id)
|
||||||
organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=False)
|
organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=False)
|
||||||
user_prompt = Column(String, nullable=False, index=True) # The prompt from the user
|
user_prompt = Column(String, nullable=False)
|
||||||
|
user_prompt_hash = Column(String, index=True)
|
||||||
url = Column(String)
|
url = Column(String)
|
||||||
navigation_goal = Column(String)
|
navigation_goal = Column(String)
|
||||||
navigation_payload = Column(JSON)
|
navigation_payload = Column(JSON)
|
||||||
@@ -386,5 +387,7 @@ class TaskGenerationModel(Base):
|
|||||||
llm_prompt = Column(String) # The prompt sent to the language model
|
llm_prompt = Column(String) # The prompt sent to the language model
|
||||||
llm_response = Column(String) # The response from the language model
|
llm_response = Column(String) # The response from the language model
|
||||||
|
|
||||||
|
source_task_generation_id = Column(String, index=True)
|
||||||
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||||
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import hashlib
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any
|
||||||
|
|
||||||
@@ -54,6 +55,7 @@ from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
|
|||||||
base_router = APIRouter()
|
base_router = APIRouter()
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
PROMPT_CACHE_WINDOW_HOURS = 24
|
||||||
|
|
||||||
|
|
||||||
@base_router.post("/webhook", tags=["server"])
|
@base_router.post("/webhook", tags=["server"])
|
||||||
@@ -766,6 +768,32 @@ async def generate_task(
|
|||||||
data: GenerateTaskRequest,
|
data: GenerateTaskRequest,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> TaskGeneration:
|
) -> TaskGeneration:
|
||||||
|
user_prompt = data.prompt
|
||||||
|
hash_object = hashlib.sha256()
|
||||||
|
hash_object.update(user_prompt.encode("utf-8"))
|
||||||
|
user_prompt_hash = hash_object.hexdigest()
|
||||||
|
# check if there's a same user_prompt within the past x Hours
|
||||||
|
# in the future, we can use vector db to fetch similar prompts
|
||||||
|
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
|
||||||
|
user_prompt_hash=user_prompt_hash, query_window_hours=PROMPT_CACHE_WINDOW_HOURS
|
||||||
|
)
|
||||||
|
if existing_task_generation:
|
||||||
|
new_task_generation = await app.DATABASE.create_task_generation(
|
||||||
|
organization_id=current_org.organization_id,
|
||||||
|
user_prompt=data.prompt,
|
||||||
|
user_prompt_hash=user_prompt_hash,
|
||||||
|
url=existing_task_generation.url,
|
||||||
|
navigation_goal=existing_task_generation.navigation_goal,
|
||||||
|
navigation_payload=existing_task_generation.navigation_payload,
|
||||||
|
data_extraction_goal=existing_task_generation.data_extraction_goal,
|
||||||
|
extracted_information_schema=existing_task_generation.extracted_information_schema,
|
||||||
|
llm=existing_task_generation.llm,
|
||||||
|
llm_prompt=existing_task_generation.llm_prompt,
|
||||||
|
llm_response=existing_task_generation.llm_response,
|
||||||
|
source_task_generation_id=existing_task_generation.task_generation_id,
|
||||||
|
)
|
||||||
|
return new_task_generation
|
||||||
|
|
||||||
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
|
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
|
||||||
try:
|
try:
|
||||||
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt)
|
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt)
|
||||||
@@ -775,6 +803,7 @@ async def generate_task(
|
|||||||
task_generation = await app.DATABASE.create_task_generation(
|
task_generation = await app.DATABASE.create_task_generation(
|
||||||
organization_id=current_org.organization_id,
|
organization_id=current_org.organization_id,
|
||||||
user_prompt=data.prompt,
|
user_prompt=data.prompt,
|
||||||
|
user_prompt_hash=user_prompt_hash,
|
||||||
url=parsed_task_generation_obj.url,
|
url=parsed_task_generation_obj.url,
|
||||||
navigation_goal=parsed_task_generation_obj.navigation_goal,
|
navigation_goal=parsed_task_generation_obj.navigation_goal,
|
||||||
navigation_payload=parsed_task_generation_obj.navigation_payload,
|
navigation_payload=parsed_task_generation_obj.navigation_payload,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
class TaskGenerationBase(BaseModel):
|
class TaskGenerationBase(BaseModel):
|
||||||
@@ -9,6 +9,7 @@ class TaskGenerationBase(BaseModel):
|
|||||||
|
|
||||||
organization_id: str | None = None
|
organization_id: str | None = None
|
||||||
user_prompt: str | None = None
|
user_prompt: str | None = None
|
||||||
|
user_prompt_hash: str | None = None
|
||||||
url: str | None = None
|
url: str | None = None
|
||||||
navigation_goal: str | None = None
|
navigation_goal: str | None = None
|
||||||
navigation_payload: dict[str, Any] | None = None
|
navigation_payload: dict[str, Any] | None = None
|
||||||
@@ -20,19 +21,16 @@ class TaskGenerationBase(BaseModel):
|
|||||||
suggested_title: str | None = None
|
suggested_title: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class TaskGenerationCreate(TaskGenerationBase):
|
|
||||||
organization_id: str
|
|
||||||
user_prompt: str
|
|
||||||
|
|
||||||
|
|
||||||
class TaskGeneration(TaskGenerationBase):
|
class TaskGeneration(TaskGenerationBase):
|
||||||
task_generation_id: str
|
task_generation_id: str
|
||||||
organization_id: str
|
organization_id: str
|
||||||
user_prompt: str
|
user_prompt: str
|
||||||
|
user_prompt_hash: str
|
||||||
|
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
modified_at: datetime
|
modified_at: datetime
|
||||||
|
|
||||||
|
|
||||||
class GenerateTaskRequest(BaseModel):
|
class GenerateTaskRequest(BaseModel):
|
||||||
prompt: str
|
# prompt needs to be at least 1 character long
|
||||||
|
prompt: str = Field(..., min_length=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user