task generation (#450)
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from skyvern.forge.sdk.db.models import (
|
||||
OrganizationModel,
|
||||
OutputParameterModel,
|
||||
StepModel,
|
||||
TaskGenerationModel,
|
||||
TaskModel,
|
||||
WorkflowModel,
|
||||
WorkflowParameterModel,
|
||||
@@ -42,6 +43,7 @@ from skyvern.forge.sdk.db.utils import (
|
||||
convert_to_workflow_run_parameter,
|
||||
)
|
||||
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
|
||||
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.parameter import (
|
||||
AWSSecretParameter,
|
||||
@@ -1236,3 +1238,34 @@ class AgentDB:
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def create_task_generation(
|
||||
self,
|
||||
organization_id: str,
|
||||
user_prompt: str,
|
||||
url: str | None = None,
|
||||
navigation_goal: str | None = None,
|
||||
navigation_payload: dict[str, Any] | None = None,
|
||||
data_extraction_goal: str | None = None,
|
||||
extracted_information_schema: dict[str, Any] | None = None,
|
||||
llm: str | None = None,
|
||||
llm_prompt: str | None = None,
|
||||
llm_response: str | None = None,
|
||||
) -> TaskGeneration:
|
||||
async with self.Session() as session:
|
||||
new_task_generation = TaskGenerationModel(
|
||||
organization_id=organization_id,
|
||||
user_prompt=user_prompt,
|
||||
url=url,
|
||||
navigation_goal=navigation_goal,
|
||||
navigation_payload=navigation_payload,
|
||||
data_extraction_goal=data_extraction_goal,
|
||||
extracted_information_schema=extracted_information_schema,
|
||||
llm=llm,
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response=llm_response,
|
||||
)
|
||||
session.add(new_task_generation)
|
||||
await session.commit()
|
||||
await session.refresh(new_task_generation)
|
||||
return TaskGeneration.model_validate(new_task_generation)
|
||||
|
||||
@@ -40,6 +40,7 @@ WORKFLOW_PARAMETER_PREFIX = "wp"
|
||||
AWS_SECRET_PARAMETER_PREFIX = "asp"
|
||||
OUTPUT_PARAMETER_PREFIX = "op"
|
||||
BITWARDEN_LOGIN_CREDENTIAL_PARAMETER_PREFIX = "blc"
|
||||
TASK_GENERATION_PREFIX = "tg"
|
||||
|
||||
|
||||
def generate_workflow_id() -> str:
|
||||
@@ -107,6 +108,11 @@ def generate_user_id() -> str:
|
||||
return f"{USER_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_task_generation_id() -> str:
|
||||
int_id = generate_id()
|
||||
return f"{TASK_GENERATION_PREFIX}_{int_id}"
|
||||
|
||||
|
||||
def generate_id() -> int:
|
||||
"""
|
||||
generate a 64-bit int ID
|
||||
|
||||
@@ -26,6 +26,7 @@ from skyvern.forge.sdk.db.id import (
|
||||
generate_organization_auth_token_id,
|
||||
generate_output_parameter_id,
|
||||
generate_step_id,
|
||||
generate_task_generation_id,
|
||||
generate_task_id,
|
||||
generate_workflow_id,
|
||||
generate_workflow_parameter_id,
|
||||
@@ -325,3 +326,27 @@ class WorkflowRunOutputParameterModel(Base):
|
||||
)
|
||||
value = Column(JSON, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class TaskGenerationModel(Base):
|
||||
"""
|
||||
Generate a task based on the prompt (natural language description of the task) from the user
|
||||
"""
|
||||
|
||||
__tablename__ = "task_generations"
|
||||
|
||||
task_generation_id = Column(String, primary_key=True, default=generate_task_generation_id)
|
||||
organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=False)
|
||||
user_prompt = Column(String, nullable=False, index=True) # The prompt from the user
|
||||
url = Column(String)
|
||||
navigation_goal = Column(String)
|
||||
navigation_payload = Column(JSON)
|
||||
data_extraction_goal = Column(String)
|
||||
extracted_information_schema = Column(JSON)
|
||||
|
||||
llm = Column(String) # language model to use
|
||||
llm_prompt = Column(String) # The prompt sent to the language model
|
||||
llm_response = Column(String) # The response from the language model
|
||||
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
@@ -9,12 +9,15 @@ from pydantic import BaseModel
|
||||
from skyvern import analytics
|
||||
from skyvern.exceptions import StepNotFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
||||
from skyvern.forge.sdk.models import Organization, Step
|
||||
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
|
||||
from skyvern.forge.sdk.schemas.tasks import (
|
||||
CreateTaskResponse,
|
||||
ProxyLocation,
|
||||
@@ -660,3 +663,33 @@ async def get_workflow(
|
||||
organization_id=current_org.organization_id,
|
||||
version=version,
|
||||
)
|
||||
|
||||
|
||||
@base_router.post("/generate/task", include_in_schema=False)
|
||||
@base_router.post("/generate/task/")
|
||||
async def generate_task(
|
||||
data: GenerateTaskRequest,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org_with_authentication),
|
||||
) -> TaskGeneration:
|
||||
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
|
||||
try:
|
||||
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt)
|
||||
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
|
||||
|
||||
# generate a TaskGenerationModel
|
||||
task_generation = await app.DATABASE.create_task_generation(
|
||||
organization_id=current_org.organization_id,
|
||||
user_prompt=data.prompt,
|
||||
url=parsed_task_generation_obj.url,
|
||||
navigation_goal=parsed_task_generation_obj.navigation_goal,
|
||||
navigation_payload=parsed_task_generation_obj.navigation_payload,
|
||||
data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
|
||||
extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
|
||||
llm=SettingsManager.get_settings().LLM_KEY,
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response=str(llm_response),
|
||||
)
|
||||
return task_generation
|
||||
except LLMProviderError:
|
||||
LOG.error("Failed to generate task", exc_info=True)
|
||||
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
|
||||
|
||||
42
skyvern/forge/sdk/schemas/task_generations.py
Normal file
42
skyvern/forge/sdk/schemas/task_generations.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class LLMType(StrEnum):
|
||||
OPENAI_GPT4O = "OPENAI_GPT4O"
|
||||
|
||||
|
||||
class TaskGenerationBase(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
organization_id: str | None = None
|
||||
user_prompt: str | None = None
|
||||
url: str | None = None
|
||||
navigation_goal: str | None = None
|
||||
navigation_payload: dict[str, Any] | None = None
|
||||
data_extraction_goal: str | None = None
|
||||
extracted_information_schema: dict[str, Any] | None = None
|
||||
llm: LLMType | None = None
|
||||
llm_prompt: str | None = None
|
||||
llm_response: str | None = None
|
||||
|
||||
|
||||
class TaskGenerationCreate(TaskGenerationBase):
|
||||
organization_id: str
|
||||
user_prompt: str
|
||||
|
||||
|
||||
class TaskGeneration(TaskGenerationBase):
|
||||
task_generation_id: str
|
||||
organization_id: str
|
||||
user_prompt: str
|
||||
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
|
||||
|
||||
class GenerateTaskRequest(BaseModel):
|
||||
prompt: str
|
||||
Reference in New Issue
Block a user