unified run_task api (#2012)

This commit is contained in:
Shuchang Zheng
2025-03-24 22:08:37 -07:00
committed by GitHub
parent 19c7c56af7
commit 166cfb6366
8 changed files with 259 additions and 126 deletions

View File

@@ -1,18 +1,11 @@
from enum import StrEnum
from typing import Any from typing import Any
import httpx import httpx
from skyvern.config import settings from skyvern.config import settings
from skyvern.exceptions import SkyvernClientException from skyvern.exceptions import SkyvernClientException
from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse
from skyvern.schemas.runs import ProxyLocation from skyvern.schemas.runs import ProxyLocation, RunEngine, TaskRunResponse
class RunEngine(StrEnum):
skyvern_v1 = "skyvern-1.0"
skyvern_v2 = "skyvern-2.0"
class SkyvernClient: class SkyvernClient:

View File

@@ -1,5 +1,4 @@
import datetime import datetime
import hashlib
import os import os
import uuid import uuid
from enum import Enum from enum import Enum
@@ -20,7 +19,6 @@ from fastapi import (
status, status,
) )
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from sqlalchemy.exc import OperationalError
from skyvern import analytics from skyvern import analytics
from skyvern.config import settings from skyvern.config import settings
@@ -30,7 +28,6 @@ from skyvern.forge.sdk.api.aws import aws_client
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
from skyvern.forge.sdk.artifact.models import Artifact from skyvern.forge.sdk.artifact.models import Artifact
from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.hashing import generate_url_hash
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
@@ -43,8 +40,8 @@ from skyvern.forge.sdk.schemas.organizations import (
Organization, Organization,
OrganizationUpdate, OrganizationUpdate,
) )
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration
from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse, TaskRunType from skyvern.forge.sdk.schemas.task_runs import TaskRunType
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
from skyvern.forge.sdk.schemas.tasks import ( from skyvern.forge.sdk.schemas.tasks import (
CreateTaskResponse, CreateTaskResponse,
@@ -74,9 +71,12 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowStatus, WorkflowStatus,
) )
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
from skyvern.schemas.runs import RunEngine, TaskRunRequest, TaskRunResponse, TaskRunStatus
from skyvern.services import task_v1_service
from skyvern.webeye.actions.actions import Action from skyvern.webeye.actions.actions import Action
from skyvern.webeye.schemas import BrowserSessionResponse from skyvern.webeye.schemas import BrowserSessionResponse
official_router = APIRouter()
base_router = APIRouter() base_router = APIRouter()
v2_router = APIRouter() v2_router = APIRouter()
@@ -190,31 +190,13 @@ async def run_task_v1(
analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url}) analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url})
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=task.browser_session_id) await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=task.browser_session_id)
created_task = await app.agent.create_task(task, current_org.organization_id) created_task = await task_v1_service.run_task(
url_hash = generate_url_hash(task.url) task=task,
await app.DATABASE.create_task_run( organization=current_org,
task_run_type=TaskRunType.task_v1, x_max_steps_override=x_max_steps_override,
organization_id=current_org.organization_id, x_api_key=x_api_key,
run_id=created_task.task_id,
title=task.title,
url=task.url,
url_hash=url_hash,
)
if x_max_steps_override:
LOG.info(
"Overriding max steps per run",
max_steps_override=x_max_steps_override,
organization_id=current_org.organization_id,
task_id=created_task.task_id,
)
await AsyncExecutorFactory.get_executor().execute_task(
request=request, request=request,
background_tasks=background_tasks, background_tasks=background_tasks,
task_id=created_task.task_id,
organization_id=current_org.organization_id,
max_steps_override=x_max_steps_override,
browser_session_id=task.browser_session_id,
api_key=x_api_key,
) )
return CreateTaskResponse(task_id=created_task.task_id) return CreateTaskResponse(task_id=created_task.task_id)
@@ -1148,59 +1130,11 @@ async def generate_task(
data: GenerateTaskRequest, data: GenerateTaskRequest,
current_org: Organization = Depends(org_auth_service.get_current_org), current_org: Organization = Depends(org_auth_service.get_current_org),
) -> TaskGeneration: ) -> TaskGeneration:
user_prompt = data.prompt analytics.capture("skyvern-oss-agent-generate-task")
hash_object = hashlib.sha256() return await task_v1_service.generate_task(
hash_object.update(user_prompt.encode("utf-8")) user_prompt=data.prompt,
user_prompt_hash = hash_object.hexdigest() organization=current_org,
# check if there's a same user_prompt within the past x Hours
# in the future, we can use vector db to fetch similar prompts
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS
) )
if existing_task_generation:
new_task_generation = await app.DATABASE.create_task_generation(
organization_id=current_org.organization_id,
user_prompt=data.prompt,
user_prompt_hash=user_prompt_hash,
url=existing_task_generation.url,
navigation_goal=existing_task_generation.navigation_goal,
navigation_payload=existing_task_generation.navigation_payload,
data_extraction_goal=existing_task_generation.data_extraction_goal,
extracted_information_schema=existing_task_generation.extracted_information_schema,
llm=existing_task_generation.llm,
llm_prompt=existing_task_generation.llm_prompt,
llm_response=existing_task_generation.llm_response,
source_task_generation_id=existing_task_generation.task_generation_id,
)
return new_task_generation
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
try:
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task")
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
# generate a TaskGenerationModel
task_generation = await app.DATABASE.create_task_generation(
organization_id=current_org.organization_id,
user_prompt=data.prompt,
user_prompt_hash=user_prompt_hash,
url=parsed_task_generation_obj.url,
navigation_goal=parsed_task_generation_obj.navigation_goal,
navigation_payload=parsed_task_generation_obj.navigation_payload,
data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
suggested_title=parsed_task_generation_obj.suggested_title,
llm=settings.LLM_KEY,
llm_prompt=llm_prompt,
llm_response=str(llm_response),
)
return task_generation
except LLMProviderError:
LOG.error("Failed to generate task", exc_info=True)
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
except OperationalError:
LOG.error("Database error when generating task", exc_info=True, user_prompt=data.prompt)
raise HTTPException(status_code=500, detail="Failed to generate task. Please try again later.")
@base_router.put( @base_router.put(
@@ -1561,3 +1495,81 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
final_workflow_run_block_timeline.extend(thought_timeline) final_workflow_run_block_timeline.extend(thought_timeline)
final_workflow_run_block_timeline.sort(key=lambda x: x.created_at, reverse=True) final_workflow_run_block_timeline.sort(key=lambda x: x.created_at, reverse=True)
return final_workflow_run_block_timeline return final_workflow_run_block_timeline
@official_router.post("/tasks")
@official_router.post("/tasks/", include_in_schema=False)
async def run_task(
request: Request,
background_tasks: BackgroundTasks,
run_request: TaskRunRequest,
current_org: Organization = Depends(org_auth_service.get_current_org),
x_api_key: Annotated[str | None, Header()] = None,
) -> TaskRunResponse:
if run_request.engine == RunEngine.skyvern_v1:
# create task v1
# if there's no url, call task generation first to generate the url, data schema if any
url = run_request.url
data_extraction_goal = None
data_extraction_schema = run_request.data_extraction_schema
navigation_goal = run_request.goal
navigation_payload = None
if not url:
task_generation = await task_v1_service.generate_task(
user_prompt=run_request.goal,
organization=current_org,
)
url = task_generation.url
navigation_goal = task_generation.navigation_goal or run_request.goal
navigation_payload = task_generation.navigation_payload
data_extraction_goal = task_generation.data_extraction_goal
data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema
task_v1_request = TaskRequest(
title=run_request.title,
url=url,
navigation_goal=navigation_goal,
navigation_payload=navigation_payload,
data_extraction_goal=data_extraction_goal,
extracted_information_schema=data_extraction_schema,
error_code_mapping=run_request.error_code_mapping,
proxy_location=run_request.proxy_location,
browser_session_id=run_request.browser_session_id,
)
task_v1_response = await task_v1_service.run_task(
task=task_v1_request,
organization=current_org,
x_max_steps_override=run_request.max_steps,
x_api_key=x_api_key,
request=request,
background_tasks=background_tasks,
)
# build the task run response
return TaskRunResponse(
run_id=task_v1_response.task_id,
title=task_v1_response.title,
status=str(task_v1_response.status),
created_at=task_v1_response.created_at,
updated_at=task_v1_response.modified_at,
engine=RunEngine.skyvern_v1,
goal=task_v1_response.navigation_goal,
url=task_v1_response.url,
output=task_v1_response.extracted_information,
failure_reason=task_v1_response.failure_reason,
data_extraction_schema=task_v1_response.extracted_information_schema,
error_code_mapping=task_v1_response.error_code_mapping,
proxy_location=task_v1_response.proxy_location,
totp_identifier=task_v1_response.totp_identifier,
totp_url=task_v1_response.totp_verification_url,
webhook_url=task_v1_response.webhook_callback_url,
max_steps=task_v1_response.max_steps_per_run,
)
if run_request.engine == RunEngine.skyvern_v2:
# create task v2
raise NotImplementedError("Skyvern v2 is not implemented")
return TaskRunResponse(
run_id="run_id",
status=TaskRunStatus.queued,
created_at=datetime.datetime.now(datetime.UTC),
updated_at=datetime.datetime.now(datetime.UTC),
)

View File

@@ -3,24 +3,6 @@ from enum import StrEnum
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from skyvern.schemas.runs import ProxyLocation
class TaskRunStatus(StrEnum):
created = "created"
queued = "queued"
running = "running"
timed_out = "timed_out"
failed = "failed"
terminated = "terminated"
completed = "completed"
canceled = "canceled"
class RunEngine(StrEnum):
skyvern_v1 = "skyvern-1.0"
skyvern_v2 = "skyvern-2.0"
class TaskRunType(StrEnum): class TaskRunType(StrEnum):
task_v1 = "task_v1" task_v1 = "task_v1"
@@ -40,22 +22,3 @@ class TaskRun(BaseModel):
cached: bool = False cached: bool = False
created_at: datetime created_at: datetime
modified_at: datetime modified_at: datetime
class TaskRunResponse(BaseModel):
run_id: str
engine: RunEngine = RunEngine.skyvern_v1
status: TaskRunStatus
goal: str | None = None
url: str | None = None
output: dict | list | str | None = None
failure_reason: str | None = None
webhook_url: str | None = None
totp_identifier: str | None = None
totp_url: str | None = None
proxy_location: ProxyLocation | None = None
error_code_mapping: dict[str, str] | None = None
title: str | None = None
max_steps: int | None = None
created_at: datetime
modified_at: datetime

View File

@@ -1,5 +1,6 @@
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.sdk.schemas.task_runs import RunEngine, TaskRun, TaskRunResponse, TaskRunType from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType
from skyvern.schemas.runs import RunEngine, TaskRunResponse
async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None: async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None:

View File

View File

@@ -1,6 +1,9 @@
from datetime import datetime
from enum import StrEnum from enum import StrEnum
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from pydantic import BaseModel
class ProxyLocation(StrEnum): class ProxyLocation(StrEnum):
US_CA = "US-CA" US_CA = "US-CA"
@@ -79,3 +82,54 @@ def get_tzinfo_from_proxy(proxy_location: ProxyLocation) -> ZoneInfo | None:
return ZoneInfo("America/New_York") return ZoneInfo("America/New_York")
return None return None
class RunEngine(StrEnum):
skyvern_v1 = "skyvern-1.0"
skyvern_v2 = "skyvern-2.0"
class TaskRunStatus(StrEnum):
created = "created"
queued = "queued"
running = "running"
timed_out = "timed_out"
failed = "failed"
terminated = "terminated"
completed = "completed"
canceled = "canceled"
class TaskRunRequest(BaseModel):
goal: str
url: str | None = None
title: str | None = None
engine: RunEngine = RunEngine.skyvern_v1
proxy_location: ProxyLocation | None = None
data_extraction_schema: dict | list | str | None = None
error_code_mapping: dict[str, str] | None = None
max_steps: int | None = None
webhook_url: str | None = None
totp_identifier: str | None = None
totp_url: str | None = None
browser_session_id: str | None = None
class TaskRunResponse(BaseModel):
run_id: str
engine: RunEngine = RunEngine.skyvern_v1
status: TaskRunStatus
goal: str | None = None
url: str | None = None
output: dict | list | str | None = None
failure_reason: str | None = None
webhook_url: str | None = None
totp_identifier: str | None = None
totp_url: str | None = None
proxy_location: ProxyLocation | None = None
error_code_mapping: dict[str, str] | None = None
data_extraction_schema: dict | list | str | None = None
title: str | None = None
max_steps: int | None = None
created_at: datetime
modified_at: datetime

View File

View File

@@ -0,0 +1,110 @@
import hashlib
import structlog
from fastapi import BackgroundTasks, HTTPException, Request
from sqlalchemy.exc import OperationalError
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
from skyvern.forge.sdk.core.hashing import generate_url_hash
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest
LOG = structlog.get_logger()
async def generate_task(user_prompt: str, organization: Organization) -> TaskGeneration:
hash_object = hashlib.sha256()
hash_object.update(user_prompt.encode("utf-8"))
user_prompt_hash = hash_object.hexdigest()
# check if there's a same user_prompt within the past x Hours
# in the future, we can use vector db to fetch similar prompts
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS
)
if existing_task_generation:
new_task_generation = await app.DATABASE.create_task_generation(
organization_id=organization.organization_id,
user_prompt=user_prompt,
user_prompt_hash=user_prompt_hash,
url=existing_task_generation.url,
navigation_goal=existing_task_generation.navigation_goal,
navigation_payload=existing_task_generation.navigation_payload,
data_extraction_goal=existing_task_generation.data_extraction_goal,
extracted_information_schema=existing_task_generation.extracted_information_schema,
llm=existing_task_generation.llm,
llm_prompt=existing_task_generation.llm_prompt,
llm_response=existing_task_generation.llm_response,
source_task_generation_id=existing_task_generation.task_generation_id,
)
return new_task_generation
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=user_prompt)
try:
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task")
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
# generate a TaskGenerationModel
task_generation = await app.DATABASE.create_task_generation(
organization_id=organization.organization_id,
user_prompt=user_prompt,
user_prompt_hash=user_prompt_hash,
url=parsed_task_generation_obj.url,
navigation_goal=parsed_task_generation_obj.navigation_goal,
navigation_payload=parsed_task_generation_obj.navigation_payload,
data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
suggested_title=parsed_task_generation_obj.suggested_title,
llm=settings.LLM_KEY,
llm_prompt=llm_prompt,
llm_response=str(llm_response),
)
return task_generation
except LLMProviderError:
LOG.error("Failed to generate task", exc_info=True)
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
except OperationalError:
LOG.error("Database error when generating task", exc_info=True, user_prompt=user_prompt)
raise HTTPException(status_code=500, detail="Failed to generate task. Please try again later.")
async def run_task(
task: TaskRequest,
organization: Organization,
x_max_steps_override: int | None = None,
x_api_key: str | None = None,
request: Request | None = None,
background_tasks: BackgroundTasks | None = None,
) -> Task:
created_task = await app.agent.create_task(task, organization.organization_id)
url_hash = generate_url_hash(task.url)
await app.DATABASE.create_task_run(
task_run_type=TaskRunType.task_v1,
organization_id=organization.organization_id,
run_id=created_task.task_id,
title=task.title,
url=task.url,
url_hash=url_hash,
)
if x_max_steps_override:
LOG.info(
"Overriding max steps per run",
max_steps_override=x_max_steps_override,
organization_id=organization.organization_id,
task_id=created_task.task_id,
)
await AsyncExecutorFactory.get_executor().execute_task(
request=request,
background_tasks=background_tasks,
task_id=created_task.task_id,
organization_id=organization.organization_id,
max_steps_override=x_max_steps_override,
browser_session_id=task.browser_session_id,
api_key=x_api_key,
)
return created_task