unified run_task api (#2012)

This commit is contained in:
Shuchang Zheng
2025-03-24 22:08:37 -07:00
committed by GitHub
parent 19c7c56af7
commit 166cfb6366
8 changed files with 259 additions and 126 deletions

View File

@@ -1,18 +1,11 @@
from enum import StrEnum
from typing import Any
import httpx
from skyvern.config import settings
from skyvern.exceptions import SkyvernClientException
from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse
from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse
from skyvern.schemas.runs import ProxyLocation
class RunEngine(StrEnum):
skyvern_v1 = "skyvern-1.0"
skyvern_v2 = "skyvern-2.0"
from skyvern.schemas.runs import ProxyLocation, RunEngine, TaskRunResponse
class SkyvernClient:

View File

@@ -1,5 +1,4 @@
import datetime
import hashlib
import os
import uuid
from enum import Enum
@@ -20,7 +19,6 @@ from fastapi import (
status,
)
from fastapi.responses import ORJSONResponse
from sqlalchemy.exc import OperationalError
from skyvern import analytics
from skyvern.config import settings
@@ -30,7 +28,6 @@ from skyvern.forge.sdk.api.aws import aws_client
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
from skyvern.forge.sdk.artifact.models import Artifact
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.hashing import generate_url_hash
from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory
from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
@@ -43,8 +40,8 @@ from skyvern.forge.sdk.schemas.organizations import (
Organization,
OrganizationUpdate,
)
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase
from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse, TaskRunType
from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request
from skyvern.forge.sdk.schemas.tasks import (
CreateTaskResponse,
@@ -74,9 +71,12 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowStatus,
)
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
from skyvern.schemas.runs import RunEngine, TaskRunRequest, TaskRunResponse, TaskRunStatus
from skyvern.services import task_v1_service
from skyvern.webeye.actions.actions import Action
from skyvern.webeye.schemas import BrowserSessionResponse
official_router = APIRouter()
base_router = APIRouter()
v2_router = APIRouter()
@@ -190,31 +190,13 @@ async def run_task_v1(
analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url})
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=task.browser_session_id)
created_task = await app.agent.create_task(task, current_org.organization_id)
url_hash = generate_url_hash(task.url)
await app.DATABASE.create_task_run(
task_run_type=TaskRunType.task_v1,
organization_id=current_org.organization_id,
run_id=created_task.task_id,
title=task.title,
url=task.url,
url_hash=url_hash,
)
if x_max_steps_override:
LOG.info(
"Overriding max steps per run",
max_steps_override=x_max_steps_override,
organization_id=current_org.organization_id,
task_id=created_task.task_id,
)
await AsyncExecutorFactory.get_executor().execute_task(
created_task = await task_v1_service.run_task(
task=task,
organization=current_org,
x_max_steps_override=x_max_steps_override,
x_api_key=x_api_key,
request=request,
background_tasks=background_tasks,
task_id=created_task.task_id,
organization_id=current_org.organization_id,
max_steps_override=x_max_steps_override,
browser_session_id=task.browser_session_id,
api_key=x_api_key,
)
return CreateTaskResponse(task_id=created_task.task_id)
@@ -1148,59 +1130,11 @@ async def generate_task(
data: GenerateTaskRequest,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> TaskGeneration:
user_prompt = data.prompt
hash_object = hashlib.sha256()
hash_object.update(user_prompt.encode("utf-8"))
user_prompt_hash = hash_object.hexdigest()
# check if there's a same user_prompt within the past x Hours
# in the future, we can use vector db to fetch similar prompts
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS
analytics.capture("skyvern-oss-agent-generate-task")
return await task_v1_service.generate_task(
user_prompt=data.prompt,
organization=current_org,
)
if existing_task_generation:
new_task_generation = await app.DATABASE.create_task_generation(
organization_id=current_org.organization_id,
user_prompt=data.prompt,
user_prompt_hash=user_prompt_hash,
url=existing_task_generation.url,
navigation_goal=existing_task_generation.navigation_goal,
navigation_payload=existing_task_generation.navigation_payload,
data_extraction_goal=existing_task_generation.data_extraction_goal,
extracted_information_schema=existing_task_generation.extracted_information_schema,
llm=existing_task_generation.llm,
llm_prompt=existing_task_generation.llm_prompt,
llm_response=existing_task_generation.llm_response,
source_task_generation_id=existing_task_generation.task_generation_id,
)
return new_task_generation
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
try:
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task")
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
# generate a TaskGenerationModel
task_generation = await app.DATABASE.create_task_generation(
organization_id=current_org.organization_id,
user_prompt=data.prompt,
user_prompt_hash=user_prompt_hash,
url=parsed_task_generation_obj.url,
navigation_goal=parsed_task_generation_obj.navigation_goal,
navigation_payload=parsed_task_generation_obj.navigation_payload,
data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
suggested_title=parsed_task_generation_obj.suggested_title,
llm=settings.LLM_KEY,
llm_prompt=llm_prompt,
llm_response=str(llm_response),
)
return task_generation
except LLMProviderError:
LOG.error("Failed to generate task", exc_info=True)
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
except OperationalError:
LOG.error("Database error when generating task", exc_info=True, user_prompt=data.prompt)
raise HTTPException(status_code=500, detail="Failed to generate task. Please try again later.")
@base_router.put(
@@ -1561,3 +1495,81 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id:
final_workflow_run_block_timeline.extend(thought_timeline)
final_workflow_run_block_timeline.sort(key=lambda x: x.created_at, reverse=True)
return final_workflow_run_block_timeline
@official_router.post("/tasks")
@official_router.post("/tasks/", include_in_schema=False)
async def run_task(
request: Request,
background_tasks: BackgroundTasks,
run_request: TaskRunRequest,
current_org: Organization = Depends(org_auth_service.get_current_org),
x_api_key: Annotated[str | None, Header()] = None,
) -> TaskRunResponse:
if run_request.engine == RunEngine.skyvern_v1:
# create task v1
# if there's no url, call task generation first to generate the url, data schema if any
url = run_request.url
data_extraction_goal = None
data_extraction_schema = run_request.data_extraction_schema
navigation_goal = run_request.goal
navigation_payload = None
if not url:
task_generation = await task_v1_service.generate_task(
user_prompt=run_request.goal,
organization=current_org,
)
url = task_generation.url
navigation_goal = task_generation.navigation_goal or run_request.goal
navigation_payload = task_generation.navigation_payload
data_extraction_goal = task_generation.data_extraction_goal
data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema
task_v1_request = TaskRequest(
title=run_request.title,
url=url,
navigation_goal=navigation_goal,
navigation_payload=navigation_payload,
data_extraction_goal=data_extraction_goal,
extracted_information_schema=data_extraction_schema,
error_code_mapping=run_request.error_code_mapping,
proxy_location=run_request.proxy_location,
browser_session_id=run_request.browser_session_id,
)
task_v1_response = await task_v1_service.run_task(
task=task_v1_request,
organization=current_org,
x_max_steps_override=run_request.max_steps,
x_api_key=x_api_key,
request=request,
background_tasks=background_tasks,
)
# build the task run response
return TaskRunResponse(
run_id=task_v1_response.task_id,
title=task_v1_response.title,
status=str(task_v1_response.status),
created_at=task_v1_response.created_at,
updated_at=task_v1_response.modified_at,
engine=RunEngine.skyvern_v1,
goal=task_v1_response.navigation_goal,
url=task_v1_response.url,
output=task_v1_response.extracted_information,
failure_reason=task_v1_response.failure_reason,
data_extraction_schema=task_v1_response.extracted_information_schema,
error_code_mapping=task_v1_response.error_code_mapping,
proxy_location=task_v1_response.proxy_location,
totp_identifier=task_v1_response.totp_identifier,
totp_url=task_v1_response.totp_verification_url,
webhook_url=task_v1_response.webhook_callback_url,
max_steps=task_v1_response.max_steps_per_run,
)
if run_request.engine == RunEngine.skyvern_v2:
# create task v2
raise NotImplementedError("Skyvern v2 is not implemented")
return TaskRunResponse(
run_id="run_id",
status=TaskRunStatus.queued,
created_at=datetime.datetime.now(datetime.UTC),
updated_at=datetime.datetime.now(datetime.UTC),
)

View File

@@ -3,24 +3,6 @@ from enum import StrEnum
from pydantic import BaseModel, ConfigDict
from skyvern.schemas.runs import ProxyLocation
class TaskRunStatus(StrEnum):
created = "created"
queued = "queued"
running = "running"
timed_out = "timed_out"
failed = "failed"
terminated = "terminated"
completed = "completed"
canceled = "canceled"
class RunEngine(StrEnum):
skyvern_v1 = "skyvern-1.0"
skyvern_v2 = "skyvern-2.0"
class TaskRunType(StrEnum):
task_v1 = "task_v1"
@@ -40,22 +22,3 @@ class TaskRun(BaseModel):
cached: bool = False
created_at: datetime
modified_at: datetime
class TaskRunResponse(BaseModel):
run_id: str
engine: RunEngine = RunEngine.skyvern_v1
status: TaskRunStatus
goal: str | None = None
url: str | None = None
output: dict | list | str | None = None
failure_reason: str | None = None
webhook_url: str | None = None
totp_identifier: str | None = None
totp_url: str | None = None
proxy_location: ProxyLocation | None = None
error_code_mapping: dict[str, str] | None = None
title: str | None = None
max_steps: int | None = None
created_at: datetime
modified_at: datetime

View File

@@ -1,5 +1,6 @@
from skyvern.forge import app
from skyvern.forge.sdk.schemas.task_runs import RunEngine, TaskRun, TaskRunResponse, TaskRunType
from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType
from skyvern.schemas.runs import RunEngine, TaskRunResponse
async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None:

View File

View File

@@ -1,6 +1,9 @@
from datetime import datetime
from enum import StrEnum
from zoneinfo import ZoneInfo
from pydantic import BaseModel
class ProxyLocation(StrEnum):
US_CA = "US-CA"
@@ -79,3 +82,54 @@ def get_tzinfo_from_proxy(proxy_location: ProxyLocation) -> ZoneInfo | None:
return ZoneInfo("America/New_York")
return None
class RunEngine(StrEnum):
skyvern_v1 = "skyvern-1.0"
skyvern_v2 = "skyvern-2.0"
class TaskRunStatus(StrEnum):
created = "created"
queued = "queued"
running = "running"
timed_out = "timed_out"
failed = "failed"
terminated = "terminated"
completed = "completed"
canceled = "canceled"
class TaskRunRequest(BaseModel):
goal: str
url: str | None = None
title: str | None = None
engine: RunEngine = RunEngine.skyvern_v1
proxy_location: ProxyLocation | None = None
data_extraction_schema: dict | list | str | None = None
error_code_mapping: dict[str, str] | None = None
max_steps: int | None = None
webhook_url: str | None = None
totp_identifier: str | None = None
totp_url: str | None = None
browser_session_id: str | None = None
class TaskRunResponse(BaseModel):
run_id: str
engine: RunEngine = RunEngine.skyvern_v1
status: TaskRunStatus
goal: str | None = None
url: str | None = None
output: dict | list | str | None = None
failure_reason: str | None = None
webhook_url: str | None = None
totp_identifier: str | None = None
totp_url: str | None = None
proxy_location: ProxyLocation | None = None
error_code_mapping: dict[str, str] | None = None
data_extraction_schema: dict | list | str | None = None
title: str | None = None
max_steps: int | None = None
created_at: datetime
modified_at: datetime

View File

View File

@@ -0,0 +1,110 @@
import hashlib
import structlog
from fastapi import BackgroundTasks, HTTPException, Request
from sqlalchemy.exc import OperationalError
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
from skyvern.forge.sdk.core.hashing import generate_url_hash
from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest
LOG = structlog.get_logger()
async def generate_task(user_prompt: str, organization: Organization) -> TaskGeneration:
hash_object = hashlib.sha256()
hash_object.update(user_prompt.encode("utf-8"))
user_prompt_hash = hash_object.hexdigest()
# check if there's a same user_prompt within the past x Hours
# in the future, we can use vector db to fetch similar prompts
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS
)
if existing_task_generation:
new_task_generation = await app.DATABASE.create_task_generation(
organization_id=organization.organization_id,
user_prompt=user_prompt,
user_prompt_hash=user_prompt_hash,
url=existing_task_generation.url,
navigation_goal=existing_task_generation.navigation_goal,
navigation_payload=existing_task_generation.navigation_payload,
data_extraction_goal=existing_task_generation.data_extraction_goal,
extracted_information_schema=existing_task_generation.extracted_information_schema,
llm=existing_task_generation.llm,
llm_prompt=existing_task_generation.llm_prompt,
llm_response=existing_task_generation.llm_response,
source_task_generation_id=existing_task_generation.task_generation_id,
)
return new_task_generation
llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=user_prompt)
try:
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task")
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)
# generate a TaskGenerationModel
task_generation = await app.DATABASE.create_task_generation(
organization_id=organization.organization_id,
user_prompt=user_prompt,
user_prompt_hash=user_prompt_hash,
url=parsed_task_generation_obj.url,
navigation_goal=parsed_task_generation_obj.navigation_goal,
navigation_payload=parsed_task_generation_obj.navigation_payload,
data_extraction_goal=parsed_task_generation_obj.data_extraction_goal,
extracted_information_schema=parsed_task_generation_obj.extracted_information_schema,
suggested_title=parsed_task_generation_obj.suggested_title,
llm=settings.LLM_KEY,
llm_prompt=llm_prompt,
llm_response=str(llm_response),
)
return task_generation
except LLMProviderError:
LOG.error("Failed to generate task", exc_info=True)
raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.")
except OperationalError:
LOG.error("Database error when generating task", exc_info=True, user_prompt=user_prompt)
raise HTTPException(status_code=500, detail="Failed to generate task. Please try again later.")
async def run_task(
task: TaskRequest,
organization: Organization,
x_max_steps_override: int | None = None,
x_api_key: str | None = None,
request: Request | None = None,
background_tasks: BackgroundTasks | None = None,
) -> Task:
created_task = await app.agent.create_task(task, organization.organization_id)
url_hash = generate_url_hash(task.url)
await app.DATABASE.create_task_run(
task_run_type=TaskRunType.task_v1,
organization_id=organization.organization_id,
run_id=created_task.task_id,
title=task.title,
url=task.url,
url_hash=url_hash,
)
if x_max_steps_override:
LOG.info(
"Overriding max steps per run",
max_steps_override=x_max_steps_override,
organization_id=organization.organization_id,
task_id=created_task.task_id,
)
await AsyncExecutorFactory.get_executor().execute_task(
request=request,
background_tasks=background_tasks,
task_id=created_task.task_id,
organization_id=organization.organization_id,
max_steps_override=x_max_steps_override,
browser_session_id=task.browser_session_id,
api_key=x_api_key,
)
return created_task