task generation (#450)
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
@@ -9,12 +9,15 @@ from pydantic import BaseModel
|
||||
from skyvern import analytics
|
||||
from skyvern.exceptions import StepNotFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
||||
from skyvern.forge.sdk.models import Organization, Step
|
||||
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
|
||||
from skyvern.forge.sdk.schemas.tasks import (
|
||||
CreateTaskResponse,
|
||||
ProxyLocation,
|
||||
@@ -660,3 +663,33 @@ async def get_workflow(
|
||||
organization_id=current_org.organization_id,
|
||||
version=version,
|
||||
)
|
||||
|
||||
|
||||
@base_router.post("/generate/task", include_in_schema=False)
|
||||
@base_router.post("/generate/task/")
|
||||
async def generate_task(
|
||||
data: GenerateTaskRequest,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org_with_authentication),
|
||||
) -> TaskGeneration:
|
||||
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
|
||||
try:
|
||||
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt)
|
||||
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
|
||||
|
||||
# generate a TaskGenerationModel
|
||||
task_generation = await app.DATABASE.create_task_generation(
|
||||
organization_id=current_org.organization_id,
|
||||
user_prompt=data.prompt,
|
||||
url=parsed_task_generation_obj.url,
|
||||
navigation_goal=parsed_task_generation_obj.navigation_goal,
|
||||
navigation_payload=parsed_task_generation_obj.navigation_payload,
|
||||
data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
|
||||
extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
|
||||
llm=SettingsManager.get_settings().LLM_KEY,
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response=str(llm_response),
|
||||
)
|
||||
return task_generation
|
||||
except LLMProviderError:
|
||||
LOG.error("Failed to generate task", exc_info=True)
|
||||
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
|
||||
|
||||
Reference in New Issue
Block a user