From 166cfb6366f9a59a41d9c78a14c07e266037cd48 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Mon, 24 Mar 2025 22:08:37 -0700 Subject: [PATCH] unified run_task api (#2012) --- skyvern/agent/client.py | 9 +- skyvern/forge/sdk/routes/agent_protocol.py | 172 ++++++++++-------- skyvern/forge/sdk/schemas/task_runs.py | 37 ---- .../forge/sdk/services/task_run_service.py | 3 +- skyvern/schemas/__init__.py | 0 skyvern/schemas/runs.py | 54 ++++++ skyvern/services/__init__.py | 0 skyvern/services/task_v1_service.py | 110 +++++++++++ 8 files changed, 259 insertions(+), 126 deletions(-) create mode 100644 skyvern/schemas/__init__.py create mode 100644 skyvern/services/__init__.py create mode 100644 skyvern/services/task_v1_service.py diff --git a/skyvern/agent/client.py b/skyvern/agent/client.py index 844597eb..0d3e4c9d 100644 --- a/skyvern/agent/client.py +++ b/skyvern/agent/client.py @@ -1,18 +1,11 @@ -from enum import StrEnum from typing import Any import httpx from skyvern.config import settings from skyvern.exceptions import SkyvernClientException -from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse, WorkflowRunResponse -from skyvern.schemas.runs import ProxyLocation - - -class RunEngine(StrEnum): - skyvern_v1 = "skyvern-1.0" - skyvern_v2 = "skyvern-2.0" +from skyvern.schemas.runs import ProxyLocation, RunEngine, TaskRunResponse class SkyvernClient: diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 450f6a3d..d9034917 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -1,5 +1,4 @@ import datetime -import hashlib import os import uuid from enum import Enum @@ -20,7 +19,6 @@ from fastapi import ( status, ) from fastapi.responses import ORJSONResponse -from sqlalchemy.exc import OperationalError from skyvern import analytics from skyvern.config import settings @@ -30,7 +28,6 @@ from skyvern.forge.sdk.api.aws import aws_client from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError from skyvern.forge.sdk.artifact.models import Artifact from skyvern.forge.sdk.core import skyvern_context -from skyvern.forge.sdk.core.hashing import generate_url_hash from skyvern.forge.sdk.core.permissions.permission_checker_factory import PermissionCheckerFactory from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType @@ -43,8 +40,8 @@ from skyvern.forge.sdk.schemas.organizations import ( Organization, OrganizationUpdate, ) -from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase -from skyvern.forge.sdk.schemas.task_runs import TaskRunResponse, TaskRunType +from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration +from skyvern.forge.sdk.schemas.task_runs import TaskRunType from skyvern.forge.sdk.schemas.task_v2 import TaskV2Request from skyvern.forge.sdk.schemas.tasks import ( CreateTaskResponse, @@ -74,9 +71,12 @@ from skyvern.forge.sdk.workflow.models.workflow import ( WorkflowStatus, ) from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest +from skyvern.schemas.runs import RunEngine, TaskRunRequest, TaskRunResponse, TaskRunStatus +from skyvern.services import task_v1_service from skyvern.webeye.actions.actions import Action from skyvern.webeye.schemas import BrowserSessionResponse +official_router = APIRouter() base_router = APIRouter() v2_router = APIRouter() @@ -190,31 +190,13 @@ async def run_task_v1( analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url}) await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=task.browser_session_id) - created_task = await app.agent.create_task(task, current_org.organization_id) - url_hash = generate_url_hash(task.url) - await app.DATABASE.create_task_run( - task_run_type=TaskRunType.task_v1, - organization_id=current_org.organization_id, - run_id=created_task.task_id, - title=task.title, - url=task.url, - url_hash=url_hash, - ) - if x_max_steps_override: - LOG.info( - "Overriding max steps per run", - max_steps_override=x_max_steps_override, - organization_id=current_org.organization_id, - task_id=created_task.task_id, - ) - await AsyncExecutorFactory.get_executor().execute_task( + created_task = await task_v1_service.run_task( + task=task, + organization=current_org, + x_max_steps_override=x_max_steps_override, + x_api_key=x_api_key, request=request, background_tasks=background_tasks, - task_id=created_task.task_id, - organization_id=current_org.organization_id, - max_steps_override=x_max_steps_override, - browser_session_id=task.browser_session_id, - api_key=x_api_key, ) return CreateTaskResponse(task_id=created_task.task_id) @@ -1148,59 +1130,11 @@ async def generate_task( data: GenerateTaskRequest, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> TaskGeneration: - user_prompt = data.prompt - hash_object = hashlib.sha256() - hash_object.update(user_prompt.encode("utf-8")) - user_prompt_hash = hash_object.hexdigest() - # check if there's a same user_prompt within the past x Hours - # in the future, we can use vector db to fetch similar prompts - existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash( - user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS + analytics.capture("skyvern-oss-agent-generate-task") + return await task_v1_service.generate_task( + user_prompt=data.prompt, + organization=current_org, ) - if existing_task_generation: - new_task_generation = await app.DATABASE.create_task_generation( - organization_id=current_org.organization_id, - user_prompt=data.prompt, - user_prompt_hash=user_prompt_hash, - url=existing_task_generation.url, - navigation_goal=existing_task_generation.navigation_goal, - navigation_payload=existing_task_generation.navigation_payload, - data_extraction_goal=existing_task_generation.data_extraction_goal, - extracted_information_schema=existing_task_generation.extracted_information_schema, - llm=existing_task_generation.llm, - llm_prompt=existing_task_generation.llm_prompt, - llm_response=existing_task_generation.llm_response, - source_task_generation_id=existing_task_generation.task_generation_id, - ) - return new_task_generation - - llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt) - try: - llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task") - parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response) - - # generate a TaskGenerationModel - task_generation = await app.DATABASE.create_task_generation( - organization_id=current_org.organization_id, - user_prompt=data.prompt, - user_prompt_hash=user_prompt_hash, - url=parsed_task_generation_obj.url, - navigation_goal=parsed_task_generation_obj.navigation_goal, - navigation_payload=parsed_task_generation_obj.navigation_payload, - data_extraction_goal=parsed_task_generation_obj.data_extraction_goal, - extracted_information_schema=parsed_task_generation_obj.extracted_information_schema, - suggested_title=parsed_task_generation_obj.suggested_title, - llm=settings.LLM_KEY, - llm_prompt=llm_prompt, - llm_response=str(llm_response), - ) - return task_generation - except LLMProviderError: - LOG.error("Failed to generate task", exc_info=True) - raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.") - except OperationalError: - LOG.error("Database error when generating task", exc_info=True, user_prompt=data.prompt) - raise HTTPException(status_code=500, detail="Failed to generate task. Please try again later.") @base_router.put( @@ -1561,3 +1495,81 @@ async def _flatten_workflow_run_timeline(organization_id: str, workflow_run_id: final_workflow_run_block_timeline.extend(thought_timeline) final_workflow_run_block_timeline.sort(key=lambda x: x.created_at, reverse=True) return final_workflow_run_block_timeline + + +@official_router.post("/tasks") +@official_router.post("/tasks/", include_in_schema=False) +async def run_task( + request: Request, + background_tasks: BackgroundTasks, + run_request: TaskRunRequest, + current_org: Organization = Depends(org_auth_service.get_current_org), + x_api_key: Annotated[str | None, Header()] = None, +) -> TaskRunResponse: + if run_request.engine == RunEngine.skyvern_v1: + # create task v1 + # if there's no url, call task generation first to generate the url, data schema if any + url = run_request.url + data_extraction_goal = None + data_extraction_schema = run_request.data_extraction_schema + navigation_goal = run_request.goal + navigation_payload = None + if not url: + task_generation = await task_v1_service.generate_task( + user_prompt=run_request.goal, + organization=current_org, + ) + url = task_generation.url + navigation_goal = task_generation.navigation_goal or run_request.goal + navigation_payload = task_generation.navigation_payload + data_extraction_goal = task_generation.data_extraction_goal + data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema + + task_v1_request = TaskRequest( + title=run_request.title, + url=url, + navigation_goal=navigation_goal, + navigation_payload=navigation_payload, + data_extraction_goal=data_extraction_goal, + extracted_information_schema=data_extraction_schema, + error_code_mapping=run_request.error_code_mapping, + proxy_location=run_request.proxy_location, + browser_session_id=run_request.browser_session_id, + ) + task_v1_response = await task_v1_service.run_task( + task=task_v1_request, + organization=current_org, + x_max_steps_override=run_request.max_steps, + x_api_key=x_api_key, + request=request, + background_tasks=background_tasks, + ) + # build the task run response + return TaskRunResponse( + run_id=task_v1_response.task_id, + title=task_v1_response.title, + status=str(task_v1_response.status), + created_at=task_v1_response.created_at, + updated_at=task_v1_response.modified_at, + engine=RunEngine.skyvern_v1, + goal=task_v1_response.navigation_goal, + url=task_v1_response.url, + output=task_v1_response.extracted_information, + failure_reason=task_v1_response.failure_reason, + data_extraction_schema=task_v1_response.extracted_information_schema, + error_code_mapping=task_v1_response.error_code_mapping, + proxy_location=task_v1_response.proxy_location, + totp_identifier=task_v1_response.totp_identifier, + totp_url=task_v1_response.totp_verification_url, + webhook_url=task_v1_response.webhook_callback_url, + max_steps=task_v1_response.max_steps_per_run, + ) + if run_request.engine == RunEngine.skyvern_v2: + # create task v2 + raise NotImplementedError("Skyvern v2 is not implemented") + return TaskRunResponse( + run_id="run_id", + status=TaskRunStatus.queued, + created_at=datetime.datetime.now(datetime.UTC), + updated_at=datetime.datetime.now(datetime.UTC), + ) diff --git a/skyvern/forge/sdk/schemas/task_runs.py b/skyvern/forge/sdk/schemas/task_runs.py index 8273f87e..bbb344af 100644 --- a/skyvern/forge/sdk/schemas/task_runs.py +++ b/skyvern/forge/sdk/schemas/task_runs.py @@ -3,24 +3,6 @@ from enum import StrEnum from pydantic import BaseModel, ConfigDict -from skyvern.schemas.runs import ProxyLocation - - -class TaskRunStatus(StrEnum): - created = "created" - queued = "queued" - running = "running" - timed_out = "timed_out" - failed = "failed" - terminated = "terminated" - completed = "completed" - canceled = "canceled" - - -class RunEngine(StrEnum): - skyvern_v1 = "skyvern-1.0" - skyvern_v2 = "skyvern-2.0" - class TaskRunType(StrEnum): task_v1 = "task_v1" @@ -40,22 +22,3 @@ class TaskRun(BaseModel): cached: bool = False created_at: datetime modified_at: datetime - - -class TaskRunResponse(BaseModel): - run_id: str - engine: RunEngine = RunEngine.skyvern_v1 - status: TaskRunStatus - goal: str | None = None - url: str | None = None - output: dict | list | str | None = None - failure_reason: str | None = None - webhook_url: str | None = None - totp_identifier: str | None = None - totp_url: str | None = None - proxy_location: ProxyLocation | None = None - error_code_mapping: dict[str, str] | None = None - title: str | None = None - max_steps: int | None = None - created_at: datetime - modified_at: datetime diff --git a/skyvern/forge/sdk/services/task_run_service.py b/skyvern/forge/sdk/services/task_run_service.py index 64599a2d..71994314 100644 --- a/skyvern/forge/sdk/services/task_run_service.py +++ b/skyvern/forge/sdk/services/task_run_service.py @@ -1,5 +1,6 @@ from skyvern.forge import app -from skyvern.forge.sdk.schemas.task_runs import RunEngine, TaskRun, TaskRunResponse, TaskRunType +from skyvern.forge.sdk.schemas.task_runs import TaskRun, TaskRunType +from skyvern.schemas.runs import RunEngine, TaskRunResponse async def get_task_run(run_id: str, organization_id: str | None = None) -> TaskRun | None: diff --git a/skyvern/schemas/__init__.py b/skyvern/schemas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/skyvern/schemas/runs.py b/skyvern/schemas/runs.py index 3bb718fb..486b3dbf 100644 --- a/skyvern/schemas/runs.py +++ b/skyvern/schemas/runs.py @@ -1,6 +1,9 @@ +from datetime import datetime from enum import StrEnum from zoneinfo import ZoneInfo +from pydantic import BaseModel + class ProxyLocation(StrEnum): US_CA = "US-CA" @@ -79,3 +82,54 @@ def get_tzinfo_from_proxy(proxy_location: ProxyLocation) -> ZoneInfo | None: return ZoneInfo("America/New_York") return None + + +class RunEngine(StrEnum): + skyvern_v1 = "skyvern-1.0" + skyvern_v2 = "skyvern-2.0" + + +class TaskRunStatus(StrEnum): + created = "created" + queued = "queued" + running = "running" + timed_out = "timed_out" + failed = "failed" + terminated = "terminated" + completed = "completed" + canceled = "canceled" + + +class TaskRunRequest(BaseModel): + goal: str + url: str | None = None + title: str | None = None + engine: RunEngine = RunEngine.skyvern_v1 + proxy_location: ProxyLocation | None = None + data_extraction_schema: dict | list | str | None = None + error_code_mapping: dict[str, str] | None = None + max_steps: int | None = None + webhook_url: str | None = None + totp_identifier: str | None = None + totp_url: str | None = None + browser_session_id: str | None = None + + +class TaskRunResponse(BaseModel): + run_id: str + engine: RunEngine = RunEngine.skyvern_v1 + status: TaskRunStatus + goal: str | None = None + url: str | None = None + output: dict | list | str | None = None + failure_reason: str | None = None + webhook_url: str | None = None + totp_identifier: str | None = None + totp_url: str | None = None + proxy_location: ProxyLocation | None = None + error_code_mapping: dict[str, str] | None = None + data_extraction_schema: dict | list | str | None = None + title: str | None = None + max_steps: int | None = None + created_at: datetime + modified_at: datetime diff --git a/skyvern/services/__init__.py b/skyvern/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/skyvern/services/task_v1_service.py b/skyvern/services/task_v1_service.py new file mode 100644 index 00000000..3c6e1c66 --- /dev/null +++ b/skyvern/services/task_v1_service.py @@ -0,0 +1,110 @@ +import hashlib + +import structlog +from fastapi import BackgroundTasks, HTTPException, Request +from sqlalchemy.exc import OperationalError + +from skyvern.config import settings +from skyvern.forge import app +from skyvern.forge.prompts import prompt_engine +from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError +from skyvern.forge.sdk.core.hashing import generate_url_hash +from skyvern.forge.sdk.executor.factory import AsyncExecutorFactory +from skyvern.forge.sdk.schemas.organizations import Organization +from skyvern.forge.sdk.schemas.task_generations import TaskGeneration, TaskGenerationBase +from skyvern.forge.sdk.schemas.task_runs import TaskRunType +from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest + +LOG = structlog.get_logger() + + +async def generate_task(user_prompt: str, organization: Organization) -> TaskGeneration: + hash_object = hashlib.sha256() + hash_object.update(user_prompt.encode("utf-8")) + user_prompt_hash = hash_object.hexdigest() + # check if there's a same user_prompt within the past x Hours + # in the future, we can use vector db to fetch similar prompts + existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash( + user_prompt_hash=user_prompt_hash, query_window_hours=settings.PROMPT_CACHE_WINDOW_HOURS + ) + if existing_task_generation: + new_task_generation = await app.DATABASE.create_task_generation( + organization_id=organization.organization_id, + user_prompt=user_prompt, + user_prompt_hash=user_prompt_hash, + url=existing_task_generation.url, + navigation_goal=existing_task_generation.navigation_goal, + navigation_payload=existing_task_generation.navigation_payload, + data_extraction_goal=existing_task_generation.data_extraction_goal, + extracted_information_schema=existing_task_generation.extracted_information_schema, + llm=existing_task_generation.llm, + llm_prompt=existing_task_generation.llm_prompt, + llm_response=existing_task_generation.llm_response, + source_task_generation_id=existing_task_generation.task_generation_id, + ) + return new_task_generation + + llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=user_prompt) + try: + llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task") + parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response) + + # generate a TaskGenerationModel + task_generation = await app.DATABASE.create_task_generation( + organization_id=organization.organization_id, + user_prompt=user_prompt, + user_prompt_hash=user_prompt_hash, + url=parsed_task_generation_obj.url, + navigation_goal=parsed_task_generation_obj.navigation_goal, + navigation_payload=parsed_task_generation_obj.navigation_payload, + data_extraction_goal=parsed_task_generation_obj.data_extraction_goal, + extracted_information_schema=parsed_task_generation_obj.extracted_information_schema, + suggested_title=parsed_task_generation_obj.suggested_title, + llm=settings.LLM_KEY, + llm_prompt=llm_prompt, + llm_response=str(llm_response), + ) + return task_generation + except LLMProviderError: + LOG.error("Failed to generate task", exc_info=True) + raise HTTPException(status_code=400, detail="Failed to generate task. Please try again later.") + except OperationalError: + LOG.error("Database error when generating task", exc_info=True, user_prompt=user_prompt) + raise HTTPException(status_code=500, detail="Failed to generate task. Please try again later.") + + +async def run_task( + task: TaskRequest, + organization: Organization, + x_max_steps_override: int | None = None, + x_api_key: str | None = None, + request: Request | None = None, + background_tasks: BackgroundTasks | None = None, +) -> Task: + created_task = await app.agent.create_task(task, organization.organization_id) + url_hash = generate_url_hash(task.url) + await app.DATABASE.create_task_run( + task_run_type=TaskRunType.task_v1, + organization_id=organization.organization_id, + run_id=created_task.task_id, + title=task.title, + url=task.url, + url_hash=url_hash, + ) + if x_max_steps_override: + LOG.info( + "Overriding max steps per run", + max_steps_override=x_max_steps_override, + organization_id=organization.organization_id, + task_id=created_task.task_id, + ) + await AsyncExecutorFactory.get_executor().execute_task( + request=request, + background_tasks=background_tasks, + task_id=created_task.task_id, + organization_id=organization.organization_id, + max_steps_override=x_max_steps_override, + browser_session_id=task.browser_session_id, + api_key=x_api_key, + ) + return created_task