cancel run API (#2067)

This commit is contained in:
Shuchang Zheng
2025-04-01 23:48:39 -04:00
committed by GitHub
parent f317b71468
commit c664cfb5a9
3 changed files with 100 additions and 7 deletions

View File

@@ -1,5 +1,11 @@
from fastapi import HTTPException, status
from skyvern.exceptions import TaskNotFound, WorkflowRunNotFound
from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
from skyvern.services import task_v2_service
async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
@@ -72,3 +78,69 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
# modified_at=run.modified_at,
# )
raise ValueError(f"Invalid task run type: {run.task_run_type}")
async def cancel_task_v1(task_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
task = await app.DATABASE.get_task(task_id, organization_id=organization_id)
if not task:
raise TaskNotFound(task_id=task_id)
task = await app.agent.update_task(task, status=TaskStatus.canceled)
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=organization_id)
await app.agent.execute_task_webhook(task=task, last_step=latest_step, api_key=api_key)
async def cancel_task_v2(task_id: str, organization_id: str | None = None) -> None:
task_v2 = await app.DATABASE.get_task_v2(task_id, organization_id=organization_id)
if not task_v2:
raise TaskNotFound(task_id=task_id)
await task_v2_service.mark_task_v2_as_canceled(
task_v2_id=task_id, workflow_run_id=task_v2.workflow_run_id, organization_id=organization_id
)
async def cancel_workflow_run(
workflow_run_id: str, organization_id: str | None = None, api_key: str | None = None
) -> None:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
if not workflow_run:
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)
# get all the child workflow runs and cancel them
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
parent_workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
for child_workflow_run in child_workflow_runs:
if child_workflow_run.status not in [
WorkflowRunStatus.running,
WorkflowRunStatus.created,
WorkflowRunStatus.queued,
]:
continue
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(child_workflow_run.workflow_run_id)
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id)
await app.WORKFLOW_SERVICE.execute_workflow_webhook(workflow_run, api_key=api_key)
async def cancel_run(run_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
if not run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Run not found {run_id}",
)
if run.task_run_type == RunType.task_v1:
await cancel_task_v1(run_id, organization_id=organization_id, api_key=api_key)
elif run.task_run_type == RunType.task_v2:
await cancel_task_v2(run_id, organization_id=organization_id)
elif run.task_run_type == RunType.workflow_run:
await cancel_workflow_run(run_id, organization_id=organization_id, api_key=api_key)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid run type to cancel: {run.task_run_type}",
)