cancel run API (#2067)
This commit is contained in:
@@ -1,5 +1,11 @@
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from skyvern.exceptions import TaskNotFound, WorkflowRunNotFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
|
||||
from skyvern.services import task_v2_service
|
||||
|
||||
|
||||
async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
|
||||
@@ -72,3 +78,69 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
|
||||
# modified_at=run.modified_at,
|
||||
# )
|
||||
raise ValueError(f"Invalid task run type: {run.task_run_type}")
|
||||
|
||||
|
||||
async def cancel_task_v1(task_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
|
||||
task = await app.DATABASE.get_task(task_id, organization_id=organization_id)
|
||||
if not task:
|
||||
raise TaskNotFound(task_id=task_id)
|
||||
task = await app.agent.update_task(task, status=TaskStatus.canceled)
|
||||
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=organization_id)
|
||||
await app.agent.execute_task_webhook(task=task, last_step=latest_step, api_key=api_key)
|
||||
|
||||
|
||||
async def cancel_task_v2(task_id: str, organization_id: str | None = None) -> None:
|
||||
task_v2 = await app.DATABASE.get_task_v2(task_id, organization_id=organization_id)
|
||||
if not task_v2:
|
||||
raise TaskNotFound(task_id=task_id)
|
||||
await task_v2_service.mark_task_v2_as_canceled(
|
||||
task_v2_id=task_id, workflow_run_id=task_v2.workflow_run_id, organization_id=organization_id
|
||||
)
|
||||
|
||||
|
||||
async def cancel_workflow_run(
|
||||
workflow_run_id: str, organization_id: str | None = None, api_key: str | None = None
|
||||
) -> None:
|
||||
workflow_run = await app.DATABASE.get_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
if not workflow_run:
|
||||
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)
|
||||
|
||||
# get all the child workflow runs and cancel them
|
||||
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
|
||||
parent_workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
for child_workflow_run in child_workflow_runs:
|
||||
if child_workflow_run.status not in [
|
||||
WorkflowRunStatus.running,
|
||||
WorkflowRunStatus.created,
|
||||
WorkflowRunStatus.queued,
|
||||
]:
|
||||
continue
|
||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(child_workflow_run.workflow_run_id)
|
||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id)
|
||||
await app.WORKFLOW_SERVICE.execute_workflow_webhook(workflow_run, api_key=api_key)
|
||||
|
||||
|
||||
async def cancel_run(run_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
|
||||
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
|
||||
if not run:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Run not found {run_id}",
|
||||
)
|
||||
|
||||
if run.task_run_type == RunType.task_v1:
|
||||
await cancel_task_v1(run_id, organization_id=organization_id, api_key=api_key)
|
||||
elif run.task_run_type == RunType.task_v2:
|
||||
await cancel_task_v2(run_id, organization_id=organization_id)
|
||||
elif run.task_run_type == RunType.workflow_run:
|
||||
await cancel_workflow_run(run_id, organization_id=organization_id, api_key=api_key)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid run type to cancel: {run.task_run_type}",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user