cancel run API (#2067)

This commit is contained in:
Shuchang Zheng
2025-04-01 23:48:39 -04:00
committed by GitHub
parent f317b71468
commit c664cfb5a9
3 changed files with 100 additions and 7 deletions

View File

@@ -1547,8 +1547,8 @@ class AgentDB:
async def get_workflow_runs_by_parent_workflow_run_id(
self,
organization_id: str,
parent_workflow_run_id: str,
organization_id: str | None = None,
) -> list[WorkflowRun]:
try:
async with self.Session() as session:

View File

@@ -6,7 +6,7 @@ from typing import Annotated, Any
import structlog
import yaml
from fastapi import BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, UploadFile, status
from fastapi import BackgroundTasks, Depends, Header, HTTPException, Path, Query, Request, Response, UploadFile, status
from fastapi.responses import ORJSONResponse
from skyvern import analytics
@@ -295,8 +295,8 @@ async def cancel_workflow_run(
)
# get all the child workflow runs and cancel them
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
organization_id=current_org.organization_id,
parent_workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id,
)
for child_workflow_run in child_workflow_runs:
if child_workflow_run.status not in [
@@ -456,7 +456,7 @@ async def get_runs(
include_in_schema=False,
)
async def get_run(
run_id: str,
run_id: str = Path(..., description="The id of the task run or the workflow run."),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> RunResponse:
run_response = await run_service.get_run_response(run_id, organization_id=current_org.organization_id)
@@ -1468,15 +1468,15 @@ async def run_task(
url = run_request.url
data_extraction_goal = None
data_extraction_schema = run_request.data_extraction_schema
navigation_goal = run_request.goal
navigation_goal = run_request.prompt
navigation_payload = None
if not url:
task_generation = await task_v1_service.generate_task(
user_prompt=run_request.goal,
user_prompt=run_request.prompt,
organization=current_org,
)
url = task_generation.url
navigation_goal = task_generation.navigation_goal or run_request.goal
navigation_goal = task_generation.navigation_goal or run_request.prompt
navigation_payload = task_generation.navigation_payload
data_extraction_goal = task_generation.data_extraction_goal
data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema
@@ -1642,3 +1642,24 @@ async def run_workflow(
downloaded_files=None,
recording_url=None,
)
@base_router.post(
"/runs/{run_id}/cancel",
tags=["Agent"],
openapi_extra={
"x-fern-sdk-group-name": "agent",
"x-fern-sdk-method-name": "cancel_run",
},
description="Cancel a task or workflow run",
summary="Cancel a task or workflow run",
)
@base_router.post("/runs/{run_id}/cancel/", include_in_schema=False)
async def cancel_run(
run_id: str = Path(..., description="The id of the task run or the workflow run to cancel."),
current_org: Organization = Depends(org_auth_service.get_current_org),
x_api_key: Annotated[str | None, Header()] = None,
) -> None:
analytics.capture("skyvern-oss-agent-cancel-run")
await run_service.cancel_run(run_id, organization_id=current_org.organization_id, api_key=x_api_key)

View File

@@ -1,5 +1,11 @@
from fastapi import HTTPException, status
from skyvern.exceptions import TaskNotFound, WorkflowRunNotFound
from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
from skyvern.services import task_v2_service
async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
@@ -72,3 +78,69 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
# modified_at=run.modified_at,
# )
raise ValueError(f"Invalid task run type: {run.task_run_type}")
async def cancel_task_v1(task_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
task = await app.DATABASE.get_task(task_id, organization_id=organization_id)
if not task:
raise TaskNotFound(task_id=task_id)
task = await app.agent.update_task(task, status=TaskStatus.canceled)
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=organization_id)
await app.agent.execute_task_webhook(task=task, last_step=latest_step, api_key=api_key)
async def cancel_task_v2(task_id: str, organization_id: str | None = None) -> None:
task_v2 = await app.DATABASE.get_task_v2(task_id, organization_id=organization_id)
if not task_v2:
raise TaskNotFound(task_id=task_id)
await task_v2_service.mark_task_v2_as_canceled(
task_v2_id=task_id, workflow_run_id=task_v2.workflow_run_id, organization_id=organization_id
)
async def cancel_workflow_run(
workflow_run_id: str, organization_id: str | None = None, api_key: str | None = None
) -> None:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
if not workflow_run:
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)
# get all the child workflow runs and cancel them
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
parent_workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
for child_workflow_run in child_workflow_runs:
if child_workflow_run.status not in [
WorkflowRunStatus.running,
WorkflowRunStatus.created,
WorkflowRunStatus.queued,
]:
continue
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(child_workflow_run.workflow_run_id)
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id)
await app.WORKFLOW_SERVICE.execute_workflow_webhook(workflow_run, api_key=api_key)
async def cancel_run(run_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
if not run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Run not found {run_id}",
)
if run.task_run_type == RunType.task_v1:
await cancel_task_v1(run_id, organization_id=organization_id, api_key=api_key)
elif run.task_run_type == RunType.task_v2:
await cancel_task_v2(run_id, organization_id=organization_id)
elif run.task_run_type == RunType.workflow_run:
await cancel_workflow_run(run_id, organization_id=organization_id, api_key=api_key)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid run type to cancel: {run.task_run_type}",
)