unified run_task api (#2012)
This commit is contained in:
@@ -1,18 +1,11 @@
|
|||||||
from enum import StrEnum
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.exceptions import SkyvernClientException
|
from skyvern.exceptions import SkyvernClientException
|
||||||
from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse
|
|
||||||
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse
|
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse
|
||||||
from skyvern.schemas.runs import ProxyLocation
|
from skyvern.schemas.runs import ProxyLocation, RunEngine, TaskRunResponse
|
||||||
|
|
||||||
|
|
||||||
class RunEngine(StrEnum):
|
|
||||||
skyvern_v1 = "skyvern-1.0"
|
|
||||||
skyvern_v2 = "skyvern-2.0"
|
|
||||||
|
|
||||||
|
|
||||||
class SkyvernClient:
|
class SkyvernClient:
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -20,7 +19,6 @@ from fastapi import (
|
|||||||
status,
|
status,
|
||||||
)
|
)
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from sqlalchemy.exc import OperationalError
|
|
||||||
|
|
||||||
from skyvern import analytics
|
from skyvern import analytics
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
@@ -30,7 +28,6 @@ from skyvern.forge.sdk.api.aws import aws_client
|
|||||||
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||||
from skyvern.forge.sdk.artifact.models import Artifact
|
from skyvern.forge.sdk.artifact.models import Artifact
|
||||||
from skyvern.forge.sdk.core import skyvern_context
|
from skyvern.forge.sdk.core import skyvern_context
|
||||||
from skyvern.forge.sdk.core.hashing import generate_url_hash
|
|
||||||
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
|
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
|
||||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||||
@@ -43,8 +40,8 @@ from skyvern.forge.sdk.schemas.organizations import (
|
|||||||
Organization,
|
Organization,
|
||||||
OrganizationUpdate,
|
OrganizationUpdate,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
|
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration
|
||||||
from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse, TaskRunType
|
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
|
||||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
|
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
|
||||||
from skyvern.forge.sdk.schemas.tasks import (
|
from skyvern.forge.sdk.schemas.tasks import (
|
||||||
CreateTaskResponse,
|
CreateTaskResponse,
|
||||||
@@ -74,9 +71,12 @@ from skyvern.forge.sdk.workflow.models.workflow import (
|
|||||||
WorkflowStatus,
|
WorkflowStatus,
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
|
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
|
||||||
|
from skyvern.schemas.runs import RunEngine, TaskRunRequest, TaskRunResponse, TaskRunStatus
|
||||||
|
from skyvern.services import task_v1_service
|
||||||
from skyvern.webeye.actions.actions import Action
|
from skyvern.webeye.actions.actions import Action
|
||||||
from skyvern.webeye.schemas import BrowserSessionResponse
|
from skyvern.webeye.schemas import BrowserSessionResponse
|
||||||
|
|
||||||
|
official_router = APIRouter()
|
||||||
base_router = APIRouter()
|
base_router = APIRouter()
|
||||||
v2_router = APIRouter()
|
v2_router = APIRouter()
|
||||||
|
|
||||||
@@ -190,31 +190,13 @@ async def run_task_v1(
|
|||||||
analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url})
|
analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url})
|
||||||
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=task.browser_session_id)
|
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=task.browser_session_id)
|
||||||
|
|
||||||
created_task = await app.agent.create_task(task, current_org.organization_id)
|
created_task = await task_v1_service.run_task(
|
||||||
url_hash = generate_url_hash(task.url)
|
task=task,
|
||||||
await app.DATABASE.create_task_run(
|
organization=current_org,
|
||||||
task_run_type=TaskRunType.task_v1,
|
x_max_steps_override=x_max_steps_override,
|
||||||
organization_id=current_org.organization_id,
|
x_api_key=x_api_key,
|
||||||
run_id=created_task.task_id,
|
|
||||||
title=task.title,
|
|
||||||
url=task.url,
|
|
||||||
url_hash=url_hash,
|
|
||||||
)
|
|
||||||
if x_max_steps_override:
|
|
||||||
LOG.info(
|
|
||||||
"Overriding max steps per run",
|
|
||||||
max_steps_override=x_max_steps_override,
|
|
||||||
organization_id=current_org.organization_id,
|
|
||||||
task_id=created_task.task_id,
|
|
||||||
)
|
|
||||||
await AsyncExecutorFactory.get_executor().execute_task(
|
|
||||||
request=request,
|
request=request,
|
||||||
background_tasks=background_tasks,
|
background_tasks=background_tasks,
|
||||||
task_id=created_task.task_id,
|
|
||||||
organization_id=current_org.organization_id,
|
|
||||||
max_steps_override=x_max_steps_override,
|
|
||||||
browser_session_id=task.browser_session_id,
|
|
||||||
api_key=x_api_key,
|
|
||||||
)
|
)
|
||||||
return CreateTaskResponse(task_id=created_task.task_id)
|
return CreateTaskResponse(task_id=created_task.task_id)
|
||||||
|
|
||||||
@@ -1148,59 +1130,11 @@ async def generate_task(
|
|||||||
data: GenerateTaskRequest,
|
data: GenerateTaskRequest,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> TaskGeneration:
|
) -> TaskGeneration:
|
||||||
user_prompt = data.prompt
|
analytics.capture("skyvern-oss-agent-generate-task")
|
||||||
hash_object = hashlib.sha256()
|
return await task_v1_service.generate_task(
|
||||||
hash_object.update(user_prompt.encode("utf-8"))
|
user_prompt=data.prompt,
|
||||||
user_prompt_hash = hash_object.hexdigest()
|
organization=current_org,
|
||||||
# check if there's a same user_prompt within the past x Hours
|
|
||||||
# in the future, we can use vector db to fetch similar prompts
|
|
||||||
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
|
|
||||||
user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS
|
|
||||||
)
|
)
|
||||||
if existing_task_generation:
|
|
||||||
new_task_generation = await app.DATABASE.create_task_generation(
|
|
||||||
organization_id=current_org.organization_id,
|
|
||||||
user_prompt=data.prompt,
|
|
||||||
user_prompt_hash=user_prompt_hash,
|
|
||||||
url=existing_task_generation.url,
|
|
||||||
navigation_goal=existing_task_generation.navigation_goal,
|
|
||||||
navigation_payload=existing_task_generation.navigation_payload,
|
|
||||||
data_extraction_goal=existing_task_generation.data_extraction_goal,
|
|
||||||
extracted_information_schema=existing_task_generation.extracted_information_schema,
|
|
||||||
llm=existing_task_generation.llm,
|
|
||||||
llm_prompt=existing_task_generation.llm_prompt,
|
|
||||||
llm_response=existing_task_generation.llm_response,
|
|
||||||
source_task_generation_id=existing_task_generation.task_generation_id,
|
|
||||||
)
|
|
||||||
return new_task_generation
|
|
||||||
|
|
||||||
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
|
|
||||||
try:
|
|
||||||
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task")
|
|
||||||
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
|
|
||||||
|
|
||||||
# generate a TaskGenerationModel
|
|
||||||
task_generation = await app.DATABASE.create_task_generation(
|
|
||||||
organization_id=current_org.organization_id,
|
|
||||||
user_prompt=data.prompt,
|
|
||||||
user_prompt_hash=user_prompt_hash,
|
|
||||||
url=parsed_task_generation_obj.url,
|
|
||||||
navigation_goal=parsed_task_generation_obj.navigation_goal,
|
|
||||||
navigation_payload=parsed_task_generation_obj.navigation_payload,
|
|
||||||
data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
|
|
||||||
extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
|
|
||||||
suggested_title=parsed_task_generation_obj.suggested_title,
|
|
||||||
llm=settings.LLM_KEY,
|
|
||||||
llm_prompt=llm_prompt,
|
|
||||||
llm_response=str(llm_response),
|
|
||||||
)
|
|
||||||
return task_generation
|
|
||||||
except LLMProviderError:
|
|
||||||
LOG.error("Failed to generate task", exc_info=True)
|
|
||||||
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
|
|
||||||
except OperationalError:
|
|
||||||
LOG.error("Database error when generating task", exc_info=True, user_prompt=data.prompt)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to generate task. Please try again later.")
|
|
||||||
|
|
||||||
|
|
||||||
@base_router.put(
|
@base_router.put(
|
||||||
@@ -1561,3 +1495,81 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
|
|||||||
final_workflow_run_block_timeline.extend(thought_timeline)
|
final_workflow_run_block_timeline.extend(thought_timeline)
|
||||||
final_workflow_run_block_timeline.sort(key=lambda x: x.created_at, reverse=True)
|
final_workflow_run_block_timeline.sort(key=lambda x: x.created_at, reverse=True)
|
||||||
return final_workflow_run_block_timeline
|
return final_workflow_run_block_timeline
|
||||||
|
|
||||||
|
|
||||||
|
@official_router.post("/tasks")
|
||||||
|
@official_router.post("/tasks/", include_in_schema=False)
|
||||||
|
async def run_task(
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
run_request: TaskRunRequest,
|
||||||
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
|
x_api_key: Annotated[str | None, Header()] = None,
|
||||||
|
) -> TaskRunResponse:
|
||||||
|
if run_request.engine == RunEngine.skyvern_v1:
|
||||||
|
# create task v1
|
||||||
|
# if there's no url, call task generation first to generate the url, data schema if any
|
||||||
|
url = run_request.url
|
||||||
|
data_extraction_goal = None
|
||||||
|
data_extraction_schema = run_request.data_extraction_schema
|
||||||
|
navigation_goal = run_request.goal
|
||||||
|
navigation_payload = None
|
||||||
|
if not url:
|
||||||
|
task_generation = await task_v1_service.generate_task(
|
||||||
|
user_prompt=run_request.goal,
|
||||||
|
organization=current_org,
|
||||||
|
)
|
||||||
|
url = task_generation.url
|
||||||
|
navigation_goal = task_generation.navigation_goal or run_request.goal
|
||||||
|
navigation_payload = task_generation.navigation_payload
|
||||||
|
data_extraction_goal = task_generation.data_extraction_goal
|
||||||
|
data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema
|
||||||
|
|
||||||
|
task_v1_request = TaskRequest(
|
||||||
|
title=run_request.title,
|
||||||
|
url=url,
|
||||||
|
navigation_goal=navigation_goal,
|
||||||
|
navigation_payload=navigation_payload,
|
||||||
|
data_extraction_goal=data_extraction_goal,
|
||||||
|
extracted_information_schema=data_extraction_schema,
|
||||||
|
error_code_mapping=run_request.error_code_mapping,
|
||||||
|
proxy_location=run_request.proxy_location,
|
||||||
|
browser_session_id=run_request.browser_session_id,
|
||||||
|
)
|
||||||
|
task_v1_response = await task_v1_service.run_task(
|
||||||
|
task=task_v1_request,
|
||||||
|
organization=current_org,
|
||||||
|
x_max_steps_override=run_request.max_steps,
|
||||||
|
x_api_key=x_api_key,
|
||||||
|
request=request,
|
||||||
|
background_tasks=background_tasks,
|
||||||
|
)
|
||||||
|
# build the task run response
|
||||||
|
return TaskRunResponse(
|
||||||
|
run_id=task_v1_response.task_id,
|
||||||
|
title=task_v1_response.title,
|
||||||
|
status=str(task_v1_response.status),
|
||||||
|
created_at=task_v1_response.created_at,
|
||||||
|
updated_at=task_v1_response.modified_at,
|
||||||
|
engine=RunEngine.skyvern_v1,
|
||||||
|
goal=task_v1_response.navigation_goal,
|
||||||
|
url=task_v1_response.url,
|
||||||
|
output=task_v1_response.extracted_information,
|
||||||
|
failure_reason=task_v1_response.failure_reason,
|
||||||
|
data_extraction_schema=task_v1_response.extracted_information_schema,
|
||||||
|
error_code_mapping=task_v1_response.error_code_mapping,
|
||||||
|
proxy_location=task_v1_response.proxy_location,
|
||||||
|
totp_identifier=task_v1_response.totp_identifier,
|
||||||
|
totp_url=task_v1_response.totp_verification_url,
|
||||||
|
webhook_url=task_v1_response.webhook_callback_url,
|
||||||
|
max_steps=task_v1_response.max_steps_per_run,
|
||||||
|
)
|
||||||
|
if run_request.engine == RunEngine.skyvern_v2:
|
||||||
|
# create task v2
|
||||||
|
raise NotImplementedError("Skyvern v2 is not implemented")
|
||||||
|
return TaskRunResponse(
|
||||||
|
run_id="run_id",
|
||||||
|
status=TaskRunStatus.queued,
|
||||||
|
created_at=datetime.datetime.now(datetime.UTC),
|
||||||
|
updated_at=datetime.datetime.now(datetime.UTC),
|
||||||
|
)
|
||||||
|
|||||||
@@ -3,24 +3,6 @@ from enum import StrEnum
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from skyvern.schemas.runs import ProxyLocation
|
|
||||||
|
|
||||||
|
|
||||||
class TaskRunStatus(StrEnum):
|
|
||||||
created = "created"
|
|
||||||
queued = "queued"
|
|
||||||
running = "running"
|
|
||||||
timed_out = "timed_out"
|
|
||||||
failed = "failed"
|
|
||||||
terminated = "terminated"
|
|
||||||
completed = "completed"
|
|
||||||
canceled = "canceled"
|
|
||||||
|
|
||||||
|
|
||||||
class RunEngine(StrEnum):
|
|
||||||
skyvern_v1 = "skyvern-1.0"
|
|
||||||
skyvern_v2 = "skyvern-2.0"
|
|
||||||
|
|
||||||
|
|
||||||
class TaskRunType(StrEnum):
|
class TaskRunType(StrEnum):
|
||||||
task_v1 = "task_v1"
|
task_v1 = "task_v1"
|
||||||
@@ -40,22 +22,3 @@ class TaskRun(BaseModel):
|
|||||||
cached: bool = False
|
cached: bool = False
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
modified_at: datetime
|
modified_at: datetime
|
||||||
|
|
||||||
|
|
||||||
class TaskRunResponse(BaseModel):
|
|
||||||
run_id: str
|
|
||||||
engine: RunEngine = RunEngine.skyvern_v1
|
|
||||||
status: TaskRunStatus
|
|
||||||
goal: str | None = None
|
|
||||||
url: str | None = None
|
|
||||||
output: dict | list | str | None = None
|
|
||||||
failure_reason: str | None = None
|
|
||||||
webhook_url: str | None = None
|
|
||||||
totp_identifier: str | None = None
|
|
||||||
totp_url: str | None = None
|
|
||||||
proxy_location: ProxyLocation | None = None
|
|
||||||
error_code_mapping: dict[str, str] | None = None
|
|
||||||
title: str | None = None
|
|
||||||
max_steps: int | None = None
|
|
||||||
created_at: datetime
|
|
||||||
modified_at: datetime
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
from skyvern.forge.sdk.schemas.task_runs import RunEngine, TaskRun, TaskRunResponse, TaskRunType
|
from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType
|
||||||
|
from skyvern.schemas.runs import RunEngine, TaskRunResponse
|
||||||
|
|
||||||
|
|
||||||
async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None:
|
async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None:
|
||||||
|
|||||||
0
skyvern/schemas/__init__.py
Normal file
0
skyvern/schemas/__init__.py
Normal file
@@ -1,6 +1,9 @@
|
|||||||
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class ProxyLocation(StrEnum):
|
class ProxyLocation(StrEnum):
|
||||||
US_CA = "US-CA"
|
US_CA = "US-CA"
|
||||||
@@ -79,3 +82,54 @@ def get_tzinfo_from_proxy(proxy_location: ProxyLocation) -> ZoneInfo | None:
|
|||||||
return ZoneInfo("America/New_York")
|
return ZoneInfo("America/New_York")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class RunEngine(StrEnum):
|
||||||
|
skyvern_v1 = "skyvern-1.0"
|
||||||
|
skyvern_v2 = "skyvern-2.0"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRunStatus(StrEnum):
|
||||||
|
created = "created"
|
||||||
|
queued = "queued"
|
||||||
|
running = "running"
|
||||||
|
timed_out = "timed_out"
|
||||||
|
failed = "failed"
|
||||||
|
terminated = "terminated"
|
||||||
|
completed = "completed"
|
||||||
|
canceled = "canceled"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRunRequest(BaseModel):
|
||||||
|
goal: str
|
||||||
|
url: str | None = None
|
||||||
|
title: str | None = None
|
||||||
|
engine: RunEngine = RunEngine.skyvern_v1
|
||||||
|
proxy_location: ProxyLocation | None = None
|
||||||
|
data_extraction_schema: dict | list | str | None = None
|
||||||
|
error_code_mapping: dict[str, str] | None = None
|
||||||
|
max_steps: int | None = None
|
||||||
|
webhook_url: str | None = None
|
||||||
|
totp_identifier: str | None = None
|
||||||
|
totp_url: str | None = None
|
||||||
|
browser_session_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRunResponse(BaseModel):
|
||||||
|
run_id: str
|
||||||
|
engine: RunEngine = RunEngine.skyvern_v1
|
||||||
|
status: TaskRunStatus
|
||||||
|
goal: str | None = None
|
||||||
|
url: str | None = None
|
||||||
|
output: dict | list | str | None = None
|
||||||
|
failure_reason: str | None = None
|
||||||
|
webhook_url: str | None = None
|
||||||
|
totp_identifier: str | None = None
|
||||||
|
totp_url: str | None = None
|
||||||
|
proxy_location: ProxyLocation | None = None
|
||||||
|
error_code_mapping: dict[str, str] | None = None
|
||||||
|
data_extraction_schema: dict | list | str | None = None
|
||||||
|
title: str | None = None
|
||||||
|
max_steps: int | None = None
|
||||||
|
created_at: datetime
|
||||||
|
modified_at: datetime
|
||||||
|
|||||||
0
skyvern/services/__init__.py
Normal file
0
skyvern/services/__init__.py
Normal file
110
skyvern/services/task_v1_service.py
Normal file
110
skyvern/services/task_v1_service.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import hashlib
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import BackgroundTasks, HTTPException, Request
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
|
from skyvern.config import settings
|
||||||
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.prompts import prompt_engine
|
||||||
|
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||||
|
from skyvern.forge.sdk.core.hashing import generate_url_hash
|
||||||
|
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
||||||
|
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||||
|
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase
|
||||||
|
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
|
||||||
|
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest
|
||||||
|
|
||||||
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_task(user_prompt: str, organization: Organization) -> TaskGeneration:
|
||||||
|
hash_object = hashlib.sha256()
|
||||||
|
hash_object.update(user_prompt.encode("utf-8"))
|
||||||
|
user_prompt_hash = hash_object.hexdigest()
|
||||||
|
# check if there's a same user_prompt within the past x Hours
|
||||||
|
# in the future, we can use vector db to fetch similar prompts
|
||||||
|
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
|
||||||
|
user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS
|
||||||
|
)
|
||||||
|
if existing_task_generation:
|
||||||
|
new_task_generation = await app.DATABASE.create_task_generation(
|
||||||
|
organization_id=organization.organization_id,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
user_prompt_hash=user_prompt_hash,
|
||||||
|
url=existing_task_generation.url,
|
||||||
|
navigation_goal=existing_task_generation.navigation_goal,
|
||||||
|
navigation_payload=existing_task_generation.navigation_payload,
|
||||||
|
data_extraction_goal=existing_task_generation.data_extraction_goal,
|
||||||
|
extracted_information_schema=existing_task_generation.extracted_information_schema,
|
||||||
|
llm=existing_task_generation.llm,
|
||||||
|
llm_prompt=existing_task_generation.llm_prompt,
|
||||||
|
llm_response=existing_task_generation.llm_response,
|
||||||
|
source_task_generation_id=existing_task_generation.task_generation_id,
|
||||||
|
)
|
||||||
|
return new_task_generation
|
||||||
|
|
||||||
|
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=user_prompt)
|
||||||
|
try:
|
||||||
|
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task")
|
||||||
|
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
|
||||||
|
|
||||||
|
# generate a TaskGenerationModel
|
||||||
|
task_generation = await app.DATABASE.create_task_generation(
|
||||||
|
organization_id=organization.organization_id,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
user_prompt_hash=user_prompt_hash,
|
||||||
|
url=parsed_task_generation_obj.url,
|
||||||
|
navigation_goal=parsed_task_generation_obj.navigation_goal,
|
||||||
|
navigation_payload=parsed_task_generation_obj.navigation_payload,
|
||||||
|
data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
|
||||||
|
extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
|
||||||
|
suggested_title=parsed_task_generation_obj.suggested_title,
|
||||||
|
llm=settings.LLM_KEY,
|
||||||
|
llm_prompt=llm_prompt,
|
||||||
|
llm_response=str(llm_response),
|
||||||
|
)
|
||||||
|
return task_generation
|
||||||
|
except LLMProviderError:
|
||||||
|
LOG.error("Failed to generate task", exc_info=True)
|
||||||
|
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
|
||||||
|
except OperationalError:
|
||||||
|
LOG.error("Database error when generating task", exc_info=True, user_prompt=user_prompt)
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to generate task. Please try again later.")
|
||||||
|
|
||||||
|
|
||||||
|
async def run_task(
|
||||||
|
task: TaskRequest,
|
||||||
|
organization: Organization,
|
||||||
|
x_max_steps_override: int | None = None,
|
||||||
|
x_api_key: str | None = None,
|
||||||
|
request: Request | None = None,
|
||||||
|
background_tasks: BackgroundTasks | None = None,
|
||||||
|
) -> Task:
|
||||||
|
created_task = await app.agent.create_task(task, organization.organization_id)
|
||||||
|
url_hash = generate_url_hash(task.url)
|
||||||
|
await app.DATABASE.create_task_run(
|
||||||
|
task_run_type=TaskRunType.task_v1,
|
||||||
|
organization_id=organization.organization_id,
|
||||||
|
run_id=created_task.task_id,
|
||||||
|
title=task.title,
|
||||||
|
url=task.url,
|
||||||
|
url_hash=url_hash,
|
||||||
|
)
|
||||||
|
if x_max_steps_override:
|
||||||
|
LOG.info(
|
||||||
|
"Overriding max steps per run",
|
||||||
|
max_steps_override=x_max_steps_override,
|
||||||
|
organization_id=organization.organization_id,
|
||||||
|
task_id=created_task.task_id,
|
||||||
|
)
|
||||||
|
await AsyncExecutorFactory.get_executor().execute_task(
|
||||||
|
request=request,
|
||||||
|
background_tasks=background_tasks,
|
||||||
|
task_id=created_task.task_id,
|
||||||
|
organization_id=organization.organization_id,
|
||||||
|
max_steps_override=x_max_steps_override,
|
||||||
|
browser_session_id=task.browser_session_id,
|
||||||
|
api_key=x_api_key,
|
||||||
|
)
|
||||||
|
return created_task
|
||||||
Reference in New Issue
Block a user