From c664cfb5a9e70c87946d08141a0cdc0941313d73 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Tue, 1 Apr 2025 23:48:39 -0400 Subject: [PATCH] cancel run API (#2067) --- skyvern/forge/sdk/db/client.py | 2 +- skyvern/forge/sdk/routes/agent_protocol.py | 33 ++++++++-- skyvern/services/run_service.py | 72 ++++++++++++++++++++++ 3 files changed, 100 insertions(+), 7 deletions(-) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index d4e5e40f..91c3123a 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1547,8 +1547,8 @@ class AgentDB: async def get_workflow_runs_by_parent_workflow_run_id( self, - organization_id: str, parent_workflow_run_id: str, + organization_id: str | None = None, ) -> list[WorkflowRun]: try: async with self.Session() as session: diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 9cee3a5a..c3aecc94 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -6,7 +6,7 @@ from typing import Annotated, Any import structlog import yaml -from fastapi import BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, UploadFile, status +from fastapi import BackgroundTasks, Depends, Header, HTTPException, Path, Query, Request, Response, UploadFile, status from fastapi.responses import ORJSONResponse from skyvern import analytics @@ -295,8 +295,8 @@ async def cancel_workflow_run( ) # get all the child workflow runs and cancel them child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id( - organization_id=current_org.organization_id, parent_workflow_run_id=workflow_run_id, + organization_id=current_org.organization_id, ) for child_workflow_run in child_workflow_runs: if child_workflow_run.status not in [ @@ -456,7 +456,7 @@ async def get_runs( include_in_schema=False, ) async def get_run( - run_id: str, + run_id: str = Path(..., description="The id of the task run or the workflow run."), current_org: Organization = Depends(org_auth_service.get_current_org), ) -> RunResponse: run_response = await run_service.get_run_response(run_id, organization_id=current_org.organization_id) @@ -1468,15 +1468,15 @@ async def run_task( url = run_request.url data_extraction_goal = None data_extraction_schema = run_request.data_extraction_schema - navigation_goal = run_request.goal + navigation_goal = run_request.prompt navigation_payload = None if not url: task_generation = await task_v1_service.generate_task( - user_prompt=run_request.goal, + user_prompt=run_request.prompt, organization=current_org, ) url = task_generation.url - navigation_goal = task_generation.navigation_goal or run_request.goal + navigation_goal = task_generation.navigation_goal or run_request.prompt navigation_payload = task_generation.navigation_payload data_extraction_goal = task_generation.data_extraction_goal data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema @@ -1642,3 +1642,24 @@ async def run_workflow( downloaded_files=None, recording_url=None, ) + + +@base_router.post( + "/runs/{run_id}/cancel", + tags=["Agent"], + openapi_extra={ + "x-fern-sdk-group-name": "agent", + "x-fern-sdk-method-name": "cancel_run", + }, + description="Cancel a task or workflow run", + summary="Cancel a task or workflow run", +) +@base_router.post("/runs/{run_id}/cancel/", include_in_schema=False) +async def cancel_run( + run_id: str = Path(..., description="The id of the task run or the workflow run to cancel."), + current_org: Organization = Depends(org_auth_service.get_current_org), + x_api_key: Annotated[str | None, Header()] = None, +) -> None: + analytics.capture("skyvern-oss-agent-cancel-run") + + await run_service.cancel_run(run_id, organization_id=current_org.organization_id, api_key=x_api_key) diff --git a/skyvern/services/run_service.py b/skyvern/services/run_service.py index e8061fa5..bcbd4002 100644 --- a/skyvern/services/run_service.py +++ b/skyvern/services/run_service.py @@ -1,5 +1,11 @@ +from fastapi import HTTPException, status + +from skyvern.exceptions import TaskNotFound, WorkflowRunNotFound from skyvern.forge import app +from skyvern.forge.sdk.schemas.tasks import TaskStatus +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse +from skyvern.services import task_v2_service async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None: @@ -72,3 +78,69 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R # modified_at=run.modified_at, # ) raise ValueError(f"Invalid task run type: {run.task_run_type}") + + +async def cancel_task_v1(task_id: str, organization_id: str | None = None, api_key: str | None = None) -> None: + task = await app.DATABASE.get_task(task_id, organization_id=organization_id) + if not task: + raise TaskNotFound(task_id=task_id) + task = await app.agent.update_task(task, status=TaskStatus.canceled) + latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=organization_id) + await app.agent.execute_task_webhook(task=task, last_step=latest_step, api_key=api_key) + + +async def cancel_task_v2(task_id: str, organization_id: str | None = None) -> None: + task_v2 = await app.DATABASE.get_task_v2(task_id, organization_id=organization_id) + if not task_v2: + raise TaskNotFound(task_id=task_id) + await task_v2_service.mark_task_v2_as_canceled( + task_v2_id=task_id, workflow_run_id=task_v2.workflow_run_id, organization_id=organization_id + ) + + +async def cancel_workflow_run( + workflow_run_id: str, organization_id: str | None = None, api_key: str | None = None +) -> None: + workflow_run = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + if not workflow_run: + raise WorkflowRunNotFound(workflow_run_id=workflow_run_id) + + # get all the child workflow runs and cancel them + child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id( + parent_workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + for child_workflow_run in child_workflow_runs: + if child_workflow_run.status not in [ + WorkflowRunStatus.running, + WorkflowRunStatus.created, + WorkflowRunStatus.queued, + ]: + continue + await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(child_workflow_run.workflow_run_id) + await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id) + await app.WORKFLOW_SERVICE.execute_workflow_webhook(workflow_run, api_key=api_key) + + +async def cancel_run(run_id: str, organization_id: str | None = None, api_key: str | None = None) -> None: + run = await app.DATABASE.get_run(run_id, organization_id=organization_id) + if not run: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Run not found {run_id}", + ) + + if run.task_run_type == RunType.task_v1: + await cancel_task_v1(run_id, organization_id=organization_id, api_key=api_key) + elif run.task_run_type == RunType.task_v2: + await cancel_task_v2(run_id, organization_id=organization_id) + elif run.task_run_type == RunType.workflow_run: + await cancel_workflow_run(run_id, organization_id=organization_id, api_key=api_key) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid run type to cancel: {run.task_run_type}", + )