unified run_task api (#2012)
This commit is contained in:
@@ -1,18 +1,11 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import SkyvernClientException
|
||||
from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse
|
||||
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse
|
||||
from skyvern.schemas.runs import ProxyLocation
|
||||
|
||||
|
||||
class RunEngine(StrEnum):
|
||||
skyvern_v1 = "skyvern-1.0"
|
||||
skyvern_v2 = "skyvern-2.0"
|
||||
from skyvern.schemas.runs import ProxyLocation, RunEngine, TaskRunResponse
|
||||
|
||||
|
||||
class SkyvernClient:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import datetime
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from enum import Enum
|
||||
@@ -20,7 +19,6 @@ from fastapi import (
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from skyvern import analytics
|
||||
from skyvern.config import settings
|
||||
@@ -30,7 +28,6 @@ from skyvern.forge.sdk.api.aws import aws_client
|
||||
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||
from skyvern.forge.sdk.artifact.models import Artifact
|
||||
from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.core.hashing import generate_url_hash
|
||||
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_signature
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
@@ -43,8 +40,8 @@ from skyvern.forge.sdk.schemas.organizations import (
|
||||
Organization,
|
||||
OrganizationUpdate,
|
||||
)
|
||||
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
|
||||
from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse, TaskRunType
|
||||
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration
|
||||
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
|
||||
from skyvern.forge.sdk.schemas.tasks import (
|
||||
CreateTaskResponse,
|
||||
@@ -74,9 +71,12 @@ from skyvern.forge.sdk.workflow.models.workflow import (
|
||||
WorkflowStatus,
|
||||
)
|
||||
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
|
||||
from skyvern.schemas.runs import RunEngine, TaskRunRequest, TaskRunResponse, TaskRunStatus
|
||||
from skyvern.services import task_v1_service
|
||||
from skyvern.webeye.actions.actions import Action
|
||||
from skyvern.webeye.schemas import BrowserSessionResponse
|
||||
|
||||
official_router = APIRouter()
|
||||
base_router = APIRouter()
|
||||
v2_router = APIRouter()
|
||||
|
||||
@@ -190,31 +190,13 @@ async def run_task_v1(
|
||||
analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url})
|
||||
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=task.browser_session_id)
|
||||
|
||||
created_task = await app.agent.create_task(task, current_org.organization_id)
|
||||
url_hash = generate_url_hash(task.url)
|
||||
await app.DATABASE.create_task_run(
|
||||
task_run_type=TaskRunType.task_v1,
|
||||
organization_id=current_org.organization_id,
|
||||
run_id=created_task.task_id,
|
||||
title=task.title,
|
||||
url=task.url,
|
||||
url_hash=url_hash,
|
||||
)
|
||||
if x_max_steps_override:
|
||||
LOG.info(
|
||||
"Overriding max steps per run",
|
||||
max_steps_override=x_max_steps_override,
|
||||
organization_id=current_org.organization_id,
|
||||
task_id=created_task.task_id,
|
||||
)
|
||||
await AsyncExecutorFactory.get_executor().execute_task(
|
||||
created_task = await task_v1_service.run_task(
|
||||
task=task,
|
||||
organization=current_org,
|
||||
x_max_steps_override=x_max_steps_override,
|
||||
x_api_key=x_api_key,
|
||||
request=request,
|
||||
background_tasks=background_tasks,
|
||||
task_id=created_task.task_id,
|
||||
organization_id=current_org.organization_id,
|
||||
max_steps_override=x_max_steps_override,
|
||||
browser_session_id=task.browser_session_id,
|
||||
api_key=x_api_key,
|
||||
)
|
||||
return CreateTaskResponse(task_id=created_task.task_id)
|
||||
|
||||
@@ -1148,59 +1130,11 @@ async def generate_task(
|
||||
data: GenerateTaskRequest,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> TaskGeneration:
|
||||
user_prompt = data.prompt
|
||||
hash_object = hashlib.sha256()
|
||||
hash_object.update(user_prompt.encode("utf-8"))
|
||||
user_prompt_hash = hash_object.hexdigest()
|
||||
# check if there's a same user_prompt within the past x Hours
|
||||
# in the future, we can use vector db to fetch similar prompts
|
||||
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
|
||||
user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS
|
||||
analytics.capture("skyvern-oss-agent-generate-task")
|
||||
return await task_v1_service.generate_task(
|
||||
user_prompt=data.prompt,
|
||||
organization=current_org,
|
||||
)
|
||||
if existing_task_generation:
|
||||
new_task_generation = await app.DATABASE.create_task_generation(
|
||||
organization_id=current_org.organization_id,
|
||||
user_prompt=data.prompt,
|
||||
user_prompt_hash=user_prompt_hash,
|
||||
url=existing_task_generation.url,
|
||||
navigation_goal=existing_task_generation.navigation_goal,
|
||||
navigation_payload=existing_task_generation.navigation_payload,
|
||||
data_extraction_goal=existing_task_generation.data_extraction_goal,
|
||||
extracted_information_schema=existing_task_generation.extracted_information_schema,
|
||||
llm=existing_task_generation.llm,
|
||||
llm_prompt=existing_task_generation.llm_prompt,
|
||||
llm_response=existing_task_generation.llm_response,
|
||||
source_task_generation_id=existing_task_generation.task_generation_id,
|
||||
)
|
||||
return new_task_generation
|
||||
|
||||
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
|
||||
try:
|
||||
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task")
|
||||
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
|
||||
|
||||
# generate a TaskGenerationModel
|
||||
task_generation = await app.DATABASE.create_task_generation(
|
||||
organization_id=current_org.organization_id,
|
||||
user_prompt=data.prompt,
|
||||
user_prompt_hash=user_prompt_hash,
|
||||
url=parsed_task_generation_obj.url,
|
||||
navigation_goal=parsed_task_generation_obj.navigation_goal,
|
||||
navigation_payload=parsed_task_generation_obj.navigation_payload,
|
||||
data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
|
||||
extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
|
||||
suggested_title=parsed_task_generation_obj.suggested_title,
|
||||
llm=settings.LLM_KEY,
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response=str(llm_response),
|
||||
)
|
||||
return task_generation
|
||||
except LLMProviderError:
|
||||
LOG.error("Failed to generate task", exc_info=True)
|
||||
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
|
||||
except OperationalError:
|
||||
LOG.error("Database error when generating task", exc_info=True, user_prompt=data.prompt)
|
||||
raise HTTPException(status_code=500, detail="Failed to generate task. Please try again later.")
|
||||
|
||||
|
||||
@base_router.put(
|
||||
@@ -1561,3 +1495,81 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
|
||||
final_workflow_run_block_timeline.extend(thought_timeline)
|
||||
final_workflow_run_block_timeline.sort(key=lambda x: x.created_at, reverse=True)
|
||||
return final_workflow_run_block_timeline
|
||||
|
||||
|
||||
@official_router.post("/tasks")
|
||||
@official_router.post("/tasks/", include_in_schema=False)
|
||||
async def run_task(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
run_request: TaskRunRequest,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
x_api_key: Annotated[str | None, Header()] = None,
|
||||
) -> TaskRunResponse:
|
||||
if run_request.engine == RunEngine.skyvern_v1:
|
||||
# create task v1
|
||||
# if there's no url, call task generation first to generate the url, data schema if any
|
||||
url = run_request.url
|
||||
data_extraction_goal = None
|
||||
data_extraction_schema = run_request.data_extraction_schema
|
||||
navigation_goal = run_request.goal
|
||||
navigation_payload = None
|
||||
if not url:
|
||||
task_generation = await task_v1_service.generate_task(
|
||||
user_prompt=run_request.goal,
|
||||
organization=current_org,
|
||||
)
|
||||
url = task_generation.url
|
||||
navigation_goal = task_generation.navigation_goal or run_request.goal
|
||||
navigation_payload = task_generation.navigation_payload
|
||||
data_extraction_goal = task_generation.data_extraction_goal
|
||||
data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema
|
||||
|
||||
task_v1_request = TaskRequest(
|
||||
title=run_request.title,
|
||||
url=url,
|
||||
navigation_goal=navigation_goal,
|
||||
navigation_payload=navigation_payload,
|
||||
data_extraction_goal=data_extraction_goal,
|
||||
extracted_information_schema=data_extraction_schema,
|
||||
error_code_mapping=run_request.error_code_mapping,
|
||||
proxy_location=run_request.proxy_location,
|
||||
browser_session_id=run_request.browser_session_id,
|
||||
)
|
||||
task_v1_response = await task_v1_service.run_task(
|
||||
task=task_v1_request,
|
||||
organization=current_org,
|
||||
x_max_steps_override=run_request.max_steps,
|
||||
x_api_key=x_api_key,
|
||||
request=request,
|
||||
background_tasks=background_tasks,
|
||||
)
|
||||
# build the task run response
|
||||
return TaskRunResponse(
|
||||
run_id=task_v1_response.task_id,
|
||||
title=task_v1_response.title,
|
||||
status=str(task_v1_response.status),
|
||||
created_at=task_v1_response.created_at,
|
||||
updated_at=task_v1_response.modified_at,
|
||||
engine=RunEngine.skyvern_v1,
|
||||
goal=task_v1_response.navigation_goal,
|
||||
url=task_v1_response.url,
|
||||
output=task_v1_response.extracted_information,
|
||||
failure_reason=task_v1_response.failure_reason,
|
||||
data_extraction_schema=task_v1_response.extracted_information_schema,
|
||||
error_code_mapping=task_v1_response.error_code_mapping,
|
||||
proxy_location=task_v1_response.proxy_location,
|
||||
totp_identifier=task_v1_response.totp_identifier,
|
||||
totp_url=task_v1_response.totp_verification_url,
|
||||
webhook_url=task_v1_response.webhook_callback_url,
|
||||
max_steps=task_v1_response.max_steps_per_run,
|
||||
)
|
||||
if run_request.engine == RunEngine.skyvern_v2:
|
||||
# create task v2
|
||||
raise NotImplementedError("Skyvern v2 is not implemented")
|
||||
return TaskRunResponse(
|
||||
run_id="run_id",
|
||||
status=TaskRunStatus.queued,
|
||||
created_at=datetime.datetime.now(datetime.UTC),
|
||||
updated_at=datetime.datetime.now(datetime.UTC),
|
||||
)
|
||||
|
||||
@@ -3,24 +3,6 @@ from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from skyvern.schemas.runs import ProxyLocation
|
||||
|
||||
|
||||
class TaskRunStatus(StrEnum):
|
||||
created = "created"
|
||||
queued = "queued"
|
||||
running = "running"
|
||||
timed_out = "timed_out"
|
||||
failed = "failed"
|
||||
terminated = "terminated"
|
||||
completed = "completed"
|
||||
canceled = "canceled"
|
||||
|
||||
|
||||
class RunEngine(StrEnum):
|
||||
skyvern_v1 = "skyvern-1.0"
|
||||
skyvern_v2 = "skyvern-2.0"
|
||||
|
||||
|
||||
class TaskRunType(StrEnum):
|
||||
task_v1 = "task_v1"
|
||||
@@ -40,22 +22,3 @@ class TaskRun(BaseModel):
|
||||
cached: bool = False
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
|
||||
|
||||
class TaskRunResponse(BaseModel):
|
||||
run_id: str
|
||||
engine: RunEngine = RunEngine.skyvern_v1
|
||||
status: TaskRunStatus
|
||||
goal: str | None = None
|
||||
url: str | None = None
|
||||
output: dict | list | str | None = None
|
||||
failure_reason: str | None = None
|
||||
webhook_url: str | None = None
|
||||
totp_identifier: str | None = None
|
||||
totp_url: str | None = None
|
||||
proxy_location: ProxyLocation | None = None
|
||||
error_code_mapping: dict[str, str] | None = None
|
||||
title: str | None = None
|
||||
max_steps: int | None = None
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.task_runs import RunEngine, TaskRun, TaskRunResponse, TaskRunType
|
||||
from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType
|
||||
from skyvern.schemas.runs import RunEngine, TaskRunResponse
|
||||
|
||||
|
||||
async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None:
|
||||
|
||||
0
skyvern/schemas/__init__.py
Normal file
0
skyvern/schemas/__init__.py
Normal file
@@ -1,6 +1,9 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ProxyLocation(StrEnum):
|
||||
US_CA = "US-CA"
|
||||
@@ -79,3 +82,54 @@ def get_tzinfo_from_proxy(proxy_location: ProxyLocation) -> ZoneInfo | None:
|
||||
return ZoneInfo("America/New_York")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class RunEngine(StrEnum):
|
||||
skyvern_v1 = "skyvern-1.0"
|
||||
skyvern_v2 = "skyvern-2.0"
|
||||
|
||||
|
||||
class TaskRunStatus(StrEnum):
|
||||
created = "created"
|
||||
queued = "queued"
|
||||
running = "running"
|
||||
timed_out = "timed_out"
|
||||
failed = "failed"
|
||||
terminated = "terminated"
|
||||
completed = "completed"
|
||||
canceled = "canceled"
|
||||
|
||||
|
||||
class TaskRunRequest(BaseModel):
|
||||
goal: str
|
||||
url: str | None = None
|
||||
title: str | None = None
|
||||
engine: RunEngine = RunEngine.skyvern_v1
|
||||
proxy_location: ProxyLocation | None = None
|
||||
data_extraction_schema: dict | list | str | None = None
|
||||
error_code_mapping: dict[str, str] | None = None
|
||||
max_steps: int | None = None
|
||||
webhook_url: str | None = None
|
||||
totp_identifier: str | None = None
|
||||
totp_url: str | None = None
|
||||
browser_session_id: str | None = None
|
||||
|
||||
|
||||
class TaskRunResponse(BaseModel):
|
||||
run_id: str
|
||||
engine: RunEngine = RunEngine.skyvern_v1
|
||||
status: TaskRunStatus
|
||||
goal: str | None = None
|
||||
url: str | None = None
|
||||
output: dict | list | str | None = None
|
||||
failure_reason: str | None = None
|
||||
webhook_url: str | None = None
|
||||
totp_identifier: str | None = None
|
||||
totp_url: str | None = None
|
||||
proxy_location: ProxyLocation | None = None
|
||||
error_code_mapping: dict[str, str] | None = None
|
||||
data_extraction_schema: dict | list | str | None = None
|
||||
title: str | None = None
|
||||
max_steps: int | None = None
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
|
||||
0
skyvern/services/__init__.py
Normal file
0
skyvern/services/__init__.py
Normal file
110
skyvern/services/task_v1_service.py
Normal file
110
skyvern/services/task_v1_service.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import hashlib
|
||||
|
||||
import structlog
|
||||
from fastapi import BackgroundTasks, HTTPException, Request
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||
from skyvern.forge.sdk.core.hashing import generate_url_hash
|
||||
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase
|
||||
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
async def generate_task(user_prompt: str, organization: Organization) -> TaskGeneration:
|
||||
hash_object = hashlib.sha256()
|
||||
hash_object.update(user_prompt.encode("utf-8"))
|
||||
user_prompt_hash = hash_object.hexdigest()
|
||||
# check if there's a same user_prompt within the past x Hours
|
||||
# in the future, we can use vector db to fetch similar prompts
|
||||
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
|
||||
user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS
|
||||
)
|
||||
if existing_task_generation:
|
||||
new_task_generation = await app.DATABASE.create_task_generation(
|
||||
organization_id=organization.organization_id,
|
||||
user_prompt=user_prompt,
|
||||
user_prompt_hash=user_prompt_hash,
|
||||
url=existing_task_generation.url,
|
||||
navigation_goal=existing_task_generation.navigation_goal,
|
||||
navigation_payload=existing_task_generation.navigation_payload,
|
||||
data_extraction_goal=existing_task_generation.data_extraction_goal,
|
||||
extracted_information_schema=existing_task_generation.extracted_information_schema,
|
||||
llm=existing_task_generation.llm,
|
||||
llm_prompt=existing_task_generation.llm_prompt,
|
||||
llm_response=existing_task_generation.llm_response,
|
||||
source_task_generation_id=existing_task_generation.task_generation_id,
|
||||
)
|
||||
return new_task_generation
|
||||
|
||||
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=user_prompt)
|
||||
try:
|
||||
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task")
|
||||
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
|
||||
|
||||
# generate a TaskGenerationModel
|
||||
task_generation = await app.DATABASE.create_task_generation(
|
||||
organization_id=organization.organization_id,
|
||||
user_prompt=user_prompt,
|
||||
user_prompt_hash=user_prompt_hash,
|
||||
url=parsed_task_generation_obj.url,
|
||||
navigation_goal=parsed_task_generation_obj.navigation_goal,
|
||||
navigation_payload=parsed_task_generation_obj.navigation_payload,
|
||||
data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
|
||||
extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
|
||||
suggested_title=parsed_task_generation_obj.suggested_title,
|
||||
llm=settings.LLM_KEY,
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response=str(llm_response),
|
||||
)
|
||||
return task_generation
|
||||
except LLMProviderError:
|
||||
LOG.error("Failed to generate task", exc_info=True)
|
||||
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
|
||||
except OperationalError:
|
||||
LOG.error("Database error when generating task", exc_info=True, user_prompt=user_prompt)
|
||||
raise HTTPException(status_code=500, detail="Failed to generate task. Please try again later.")
|
||||
|
||||
|
||||
async def run_task(
|
||||
task: TaskRequest,
|
||||
organization: Organization,
|
||||
x_max_steps_override: int | None = None,
|
||||
x_api_key: str | None = None,
|
||||
request: Request | None = None,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> Task:
|
||||
created_task = await app.agent.create_task(task, organization.organization_id)
|
||||
url_hash = generate_url_hash(task.url)
|
||||
await app.DATABASE.create_task_run(
|
||||
task_run_type=TaskRunType.task_v1,
|
||||
organization_id=organization.organization_id,
|
||||
run_id=created_task.task_id,
|
||||
title=task.title,
|
||||
url=task.url,
|
||||
url_hash=url_hash,
|
||||
)
|
||||
if x_max_steps_override:
|
||||
LOG.info(
|
||||
"Overriding max steps per run",
|
||||
max_steps_override=x_max_steps_override,
|
||||
organization_id=organization.organization_id,
|
||||
task_id=created_task.task_id,
|
||||
)
|
||||
await AsyncExecutorFactory.get_executor().execute_task(
|
||||
request=request,
|
||||
background_tasks=background_tasks,
|
||||
task_id=created_task.task_id,
|
||||
organization_id=organization.organization_id,
|
||||
max_steps_override=x_max_steps_override,
|
||||
browser_session_id=task.browser_session_id,
|
||||
api_key=x_api_key,
|
||||
)
|
||||
return created_task
|
||||
Reference in New Issue
Block a user