From ea675021617dcf2404faaad77fd2f75ec40051b1 Mon Sep 17 00:00:00 2001
From: Shuchang Zheng
Date: Sun, 30 Mar 2025 21:14:52 -0700
Subject: [PATCH] shu/test discriminated run response schema (#2046)

---
Review notes (free-form area; not part of the commit message):
  * skyvern/services/run_service.py, task_v2 branch: the extraction schema is
    now read from task_v2.extracted_information_schema, the attribute the
    run_task hunk of agent_protocol.py in this same patch uses for task_v2
    objects (was task_v2.data_extraction_schema).
  * Same branch: status is wrapped in str() to match the task_v1 branch and
    both run_task responses.
  * Both are in-place replacements of "+" lines, so hunk counts and the
    diffstat are unchanged; the "index" blob id of run_service.py is stale.

 skyvern/agent/client.py                       | 17 ++--
 skyvern/forge/sdk/db/client.py                |  2 +-
 skyvern/forge/sdk/routes/agent_protocol.py    | 78 ++++++++++---------
 .../forge/sdk/services/task_run_service.py    | 54 -------------
 skyvern/schemas/runs.py                       | 48 ++++++++----
 skyvern/services/run_service.py               | 71 +++++++++++++++++
 6 files changed, 158 insertions(+), 112 deletions(-)
 delete mode 100644 skyvern/forge/sdk/services/task_run_service.py
 create mode 100644 skyvern/services/run_service.py

diff --git a/skyvern/agent/client.py b/skyvern/agent/client.py
index d1adaf71..3d98742b 100644
--- a/skyvern/agent/client.py
+++ b/skyvern/agent/client.py
@@ -1,7 +1,6 @@
 from skyvern.client.client import AsyncSkyvern
 from skyvern.config import settings
-from skyvern.forge.sdk.workflow.models.workflow import RunWorkflowResponse
-from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse
+from skyvern.schemas.runs import ProxyLocation, RunEngine, RunResponse, RunType, TaskRunResponse, WorkflowRunResponse
 
 
 class SkyvernClient:
@@ -28,7 +27,7 @@ class SkyvernClient:
         max_steps: int | None = None,
         browser_session_id: str | None = None,
         publish_workflow: bool = False,
-    ) -> RunResponse:
+    ) -> TaskRunResponse:
         task_run_obj = await self.client.agent.run_task(
             goal=prompt,
             url=url,
@@ -43,7 +42,7 @@ class SkyvernClient:
             browser_session_id=browser_session_id,
             publish_workflow=publish_workflow,
         )
-        return RunResponse.model_validate(task_run_obj)
+        return TaskRunResponse.model_validate(task_run_obj)
 
     async def run_workflow(
         self,
@@ -55,7 +54,7 @@ class SkyvernClient:
         totp_url: str | None = None,
         browser_session_id: str | None = None,
         template: bool = False,
-    ) -> RunWorkflowResponse:
+    ) -> WorkflowRunResponse:
         workflow_run_obj = await self.client.agent.run_workflow(
             workflow_id=workflow_id,
             data=workflow_input,
@@ -66,11 +65,15 @@ class SkyvernClient:
             browser_session_id=browser_session_id,
             template=template,
         )
-        return RunWorkflowResponse.model_validate(workflow_run_obj)
+        return WorkflowRunResponse.model_validate(workflow_run_obj)
 
     async def get_run(
         self,
         run_id: str,
     ) -> RunResponse:
         run_obj = await self.client.agent.get_run(run_id=run_id)
-        return RunResponse.model_validate(run_obj)
+        if run_obj.run_type in [RunType.task_v1, RunType.task_v2]:
+            return TaskRunResponse.model_validate(run_obj)
+        elif run_obj.run_type == RunType.workflow_run:
+            return WorkflowRunResponse.model_validate(run_obj)
+        raise ValueError(f"Invalid run type: {run_obj.run_type}")
diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py
index 6158fa9b..d4e5e40f 100644
--- a/skyvern/forge/sdk/db/client.py
+++ b/skyvern/forge/sdk/db/client.py
@@ -2945,7 +2945,7 @@ class AgentDB:
             task_run = (await session.scalars(query)).first()
             return Run.model_validate(task_run) if task_run else None
 
-    async def get_task_run(
+    async def get_run(
         self,
         run_id: str,
         organization_id: str | None = None,
diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py
index 7766aa7f..34a06d4e 100644
--- a/skyvern/forge/sdk/routes/agent_protocol.py
+++ b/skyvern/forge/sdk/routes/agent_protocol.py
@@ -52,7 +52,7 @@ from skyvern.forge.sdk.schemas.tasks import (
     TaskStatus,
 )
 from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunTimeline
-from skyvern.forge.sdk.services import org_auth_service, task_run_service
+from skyvern.forge.sdk.services import org_auth_service
 from skyvern.forge.sdk.workflow.exceptions import (
     FailedToCreateWorkflow,
     FailedToUpdateWorkflow,
@@ -70,8 +70,8 @@ from skyvern.forge.sdk.workflow.models.workflow import (
     WorkflowStatus,
 )
 from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
-from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest
-from skyvern.services import task_v1_service, task_v2_service
+from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
+from skyvern.services import run_service, task_v1_service, task_v2_service
 from skyvern.webeye.actions.actions import Action
 from skyvern.webeye.schemas import BrowserSessionResponse
 
@@ -441,7 +441,7 @@ async def get_runs(
     return ORJSONResponse([run.model_dump() for run in runs])
 
 
-@base_router.get(
+@official_api_router.get(
     "/runs/{run_id}",
     tags=["agent"],
     response_model=RunResponse,
@@ -450,7 +450,7 @@ async def get_runs(
         "x-fern-sdk-method-name": "get_run",
     },
 )
-@base_router.get(
+@official_api_router.get(
     "/runs/{run_id}/",
     response_model=RunResponse,
     include_in_schema=False,
@@ -459,15 +459,13 @@ async def get_run(
     run_id: str,
     current_org: Organization = Depends(org_auth_service.get_current_org),
 ) -> RunResponse:
-    task_run_response = await task_run_service.get_task_run_response(
-        run_id, organization_id=current_org.organization_id
-    )
-    if not task_run_response:
+    run_response = await run_service.get_run_response(run_id, organization_id=current_org.organization_id)
+    if not run_response:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail=f"Task run not found {run_id}",
         )
-    return task_run_response
+    return run_response
 
 
 @base_router.get(
@@ -1511,7 +1509,7 @@ async def run_task(
     run_request: TaskRunRequest,
     current_org: Organization = Depends(org_auth_service.get_current_org),
     x_api_key: Annotated[str | None, Header()] = None,
-) -> RunResponse:
+) -> TaskRunResponse:
     analytics.capture("skyvern-oss-run-task", data={"url": run_request.url})
     await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id)
 
@@ -1554,24 +1552,27 @@ async def run_task(
             background_tasks=background_tasks,
         )
         # build the task run response
-        return RunResponse(
+        return TaskRunResponse(
             run_id=task_v1_response.task_id,
-            title=task_v1_response.title,
+            run_type=RunType.task_v1,
             status=str(task_v1_response.status),
-            created_at=task_v1_response.created_at,
-            updated_at=task_v1_response.modified_at,
-            engine=RunEngine.skyvern_v1,
-            goal=task_v1_response.navigation_goal,
-            url=task_v1_response.url,
             output=task_v1_response.extracted_information,
             failure_reason=task_v1_response.failure_reason,
-            data_extraction_schema=task_v1_response.extracted_information_schema,
-            error_code_mapping=task_v1_response.error_code_mapping,
-            proxy_location=task_v1_response.proxy_location,
-            totp_identifier=task_v1_response.totp_identifier,
-            totp_url=task_v1_response.totp_verification_url,
-            webhook_url=task_v1_response.webhook_callback_url,
-            max_steps=task_v1_response.max_steps_per_run,
+            created_at=task_v1_response.created_at,
+            modified_at=task_v1_response.modified_at,
+            run_request=TaskRunRequest(
+                engine=RunEngine.skyvern_v1,
+                prompt=task_v1_response.navigation_goal,
+                url=task_v1_response.url,
+                webhook_url=task_v1_response.webhook_callback_url,
+                totp_identifier=task_v1_response.totp_identifier,
+                totp_url=task_v1_response.totp_verification_url,
+                proxy_location=task_v1_response.proxy_location,
+                max_steps=task_v1_response.max_steps_per_run,
+                data_extraction_schema=task_v1_response.extracted_information_schema,
+                error_code_mapping=task_v1_response.error_code_mapping,
+                browser_session_id=run_request.browser_session_id,
+            ),
         )
     if run_request.engine == RunEngine.skyvern_v2:
         # create task v2
@@ -1602,22 +1603,27 @@ async def run_task(
             max_steps_override=run_request.max_steps,
             browser_session_id=run_request.browser_session_id,
         )
-        return RunResponse(
+        return TaskRunResponse(
             run_id=task_v2.observer_cruise_id,
-            title=run_request.title,
+            run_type=RunType.task_v2,
             status=str(task_v2.status),
-            engine=RunEngine.skyvern_v2,
-            goal=task_v2.prompt,
-            url=task_v2.url,
             output=task_v2.output,
             failure_reason=task_v2.failure_reason,
-            webhook_url=task_v2.webhook_callback_url,
-            totp_identifier=task_v2.totp_identifier,
-            totp_url=task_v2.totp_verification_url,
-            proxy_location=task_v2.proxy_location,
-            error_code_mapping=task_v2.error_code_mapping,
-            data_extraction_schema=task_v2.extracted_information_schema,
             created_at=task_v2.created_at,
             modified_at=task_v2.modified_at,
+            run_request=TaskRunRequest(
+                engine=RunEngine.skyvern_v2,
+                prompt=task_v2.prompt,
+                url=task_v2.url,
+                webhook_url=task_v2.webhook_callback_url,
+                totp_identifier=task_v2.totp_identifier,
+                totp_url=task_v2.totp_verification_url,
+                proxy_location=task_v2.proxy_location,
+                max_steps=run_request.max_steps,
+                browser_session_id=run_request.browser_session_id,
+                error_code_mapping=task_v2.error_code_mapping,
+                data_extraction_schema=task_v2.extracted_information_schema,
+                publish_workflow=run_request.publish_workflow,
+            ),
         )
     raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")
diff --git a/skyvern/forge/sdk/services/task_run_service.py b/skyvern/forge/sdk/services/task_run_service.py
deleted file mode 100644
index e216c3d0..00000000
--- a/skyvern/forge/sdk/services/task_run_service.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from skyvern.forge import app
-from skyvern.forge.sdk.schemas.runs import Run
-from skyvern.schemas.runs import RunEngine, RunResponse, RunType
-
-
-async def get_task_run(run_id: str, organization_id: str | None = None) -> Run | None:
-    return await app.DATABASE.get_task_run(run_id, organization_id=organization_id)
-
-
-async def get_task_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
-    task_run = await get_task_run(run_id, organization_id=organization_id)
-    if not task_run:
-        return None
-
-    if task_run.task_run_type == RunType.task_v1:
-        # fetch task v1 from db and transform to task run response
-        task_v1 = await app.DATABASE.get_task(task_run.task_v1_id, organization_id=organization_id)
-        if not task_v1:
-            return None
-        return RunResponse(
-            run_id=task_run.run_id,
-            engine=RunEngine.skyvern_v1,
-            status=task_v1.status,
-            goal=task_v1.navigation_goal,
-            url=task_v1.url,
-            output=task_v1.extracted_information,
-            failure_reason=task_v1.failure_reason,
-            webhook_url=task_v1.webhook_callback_url,
-            totp_identifier=task_v1.totp_identifier,
-            totp_url=task_v1.totp_verification_url,
-            proxy_location=task_v1.proxy_location,
-            created_at=task_v1.created_at,
-            modified_at=task_v1.modified_at,
-        )
-    elif task_run.task_run_type == RunType.task_v2:
-        task_v2 = await app.DATABASE.get_task_v2(task_run.task_v2_id, organization_id=organization_id)
-        if not task_v2:
-            return None
-        return RunResponse(
-            run_id=task_run.run_id,
-            engine=RunEngine.skyvern_v2,
-            status=task_v2.status,
-            goal=task_v2.prompt,
-            url=task_v2.url,
-            output=task_v2.output,
-            failure_reason=task_v2.failure_reason,
-            webhook_url=task_v2.webhook_callback_url,
-            totp_identifier=task_v2.totp_identifier,
-            totp_url=task_v2.totp_verification_url,
-            proxy_location=task_v2.proxy_location,
-            created_at=task_v2.created_at,
-            modified_at=task_v2.modified_at,
-        )
-    raise ValueError(f"Invalid task run type: {task_run.task_run_type}")
diff --git a/skyvern/schemas/runs.py b/skyvern/schemas/runs.py
index 6906937b..1817fa3e 100644
--- a/skyvern/schemas/runs.py
+++ b/skyvern/schemas/runs.py
@@ -1,8 +1,9 @@
 from datetime import datetime
 from enum import StrEnum
+from typing import Annotated, Any, Literal, Union
 from zoneinfo import ZoneInfo
 
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, Field, field_validator
 
 from skyvern.utils.url_validators import validate_url
 
@@ -112,7 +113,7 @@ class RunStatus(StrEnum):
 
 
 class TaskRunRequest(BaseModel):
-    goal: str
+    prompt: str
     url: str | None = None
     title: str | None = None
     engine: RunEngine = RunEngine.skyvern_v2
@@ -135,21 +136,40 @@ class TaskRunRequest(BaseModel):
         return validate_url(url)
 
 
-class RunResponse(BaseModel):
+class WorkflowRunRequest(BaseModel):
+    title: str | None = None
+    parameters: dict[str, Any] | None = None
+    proxy_location: ProxyLocation | None = None
+    webhook_url: str | None = None
+    totp_url: str | None = None
+    totp_identifier: str | None = None
+    browser_session_id: str | None = None
+
+    @field_validator("webhook_url", "totp_url")
+    @classmethod
+    def validate_urls(cls, url: str | None) -> str | None:
+        if url is None:
+            return None
+        return validate_url(url)
+
+
+class BaseRunResponse(BaseModel):
     run_id: str
-    engine: RunEngine = RunEngine.skyvern_v1
     status: RunStatus
-    goal: str | None = None
-    url: str | None = None
     output: dict | list | str | None = None
     failure_reason: str | None = None
-    webhook_url: str | None = None
-    totp_identifier: str | None = None
-    totp_url: str | None = None
-    proxy_location: ProxyLocation | None = None
-    error_code_mapping: dict[str, str] | None = None
-    data_extraction_schema: dict | list | str | None = None
-    title: str | None = None
-    max_steps: int | None = None
     created_at: datetime
     modified_at: datetime
+
+
+class TaskRunResponse(BaseRunResponse):
+    run_type: Literal[RunType.task_v1, RunType.task_v2]
+    run_request: TaskRunRequest | None = None
+
+
+class WorkflowRunResponse(BaseRunResponse):
+    run_type: Literal[RunType.workflow_run]
+    run_request: WorkflowRunRequest | None = None
+
+
+RunResponse = Annotated[Union[TaskRunResponse, WorkflowRunResponse], Field(discriminator="run_type")]
diff --git a/skyvern/services/run_service.py b/skyvern/services/run_service.py
new file mode 100644
index 00000000..fb464ebd
--- /dev/null
+++ b/skyvern/services/run_service.py
@@ -0,0 +1,71 @@
+from skyvern.forge import app
+from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
+
+
+async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
+    run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
+    if not run:
+        return None
+
+    if run.task_run_type == RunType.task_v1:
+        # fetch task v1 from db and transform to task run response
+        task_v1 = await app.DATABASE.get_task(run.task_v1_id, organization_id=organization_id)
+        if not task_v1:
+            return None
+        return TaskRunResponse(
+            run_id=run.run_id,
+            run_type=run.task_run_type,
+            status=str(task_v1.status),
+            output=task_v1.extracted_information,
+            failure_reason=task_v1.failure_reason,
+            created_at=task_v1.created_at,
+            modified_at=task_v1.modified_at,
+            run_request=TaskRunRequest(
+                engine=RunEngine.skyvern_v1,
+                prompt=task_v1.navigation_goal,
+                url=task_v1.url,
+                webhook_url=task_v1.webhook_callback_url,
+                totp_identifier=task_v1.totp_identifier,
+                totp_url=task_v1.totp_verification_url,
+                proxy_location=task_v1.proxy_location,
+                max_steps=task_v1.max_steps_per_run,
+                data_extraction_schema=task_v1.extracted_information_schema,
+                error_code_mapping=task_v1.error_code_mapping,
+            ),
+        )
+    elif run.task_run_type == RunType.task_v2:
+        task_v2 = await app.DATABASE.get_task_v2(run.task_v2_id, organization_id=organization_id)
+        if not task_v2:
+            return None
+        return TaskRunResponse(
+            run_id=run.run_id,
+            run_type=run.task_run_type,
+            status=str(task_v2.status),
+            output=task_v2.output,
+            failure_reason=task_v2.failure_reason,
+            created_at=task_v2.created_at,
+            modified_at=task_v2.modified_at,
+            run_request=TaskRunRequest(
+                engine=RunEngine.skyvern_v2,
+                prompt=task_v2.prompt,
+                url=task_v2.url,
+                webhook_url=task_v2.webhook_callback_url,
+                totp_identifier=task_v2.totp_identifier,
+                totp_url=task_v2.totp_verification_url,
+                proxy_location=task_v2.proxy_location,
+                data_extraction_schema=task_v2.extracted_information_schema,
+                error_code_mapping=task_v2.error_code_mapping,
+            ),
+        )
+    elif run.task_run_type == RunType.workflow_run:
+        raise NotImplementedError("Workflow run response not implemented")
+        # return WorkflowRunResponse(
+        #     run_id=run.run_id,
+        #     run_type=run.task_run_type,
+        #     status=run.status,
+        #     output=run.output,
+        #     parameters=None,
+        #     created_at=run.created_at,
+        #     modified_at=run.modified_at,
+        # )
+    raise ValueError(f"Invalid task run type: {run.task_run_type}")