add pbs-based cancel endpoint for workflow runs (#2913)
This commit is contained in:
@@ -1068,6 +1068,37 @@ async def cancel_task(
|
||||
await app.agent.execute_task_webhook(task=task, last_step=latest_step, api_key=x_api_key)
|
||||
|
||||
|
||||
async def _cancel_workflow_run(workflow_run_id: str, organization_id: str, x_api_key: str | None = None) -> None:
|
||||
workflow_run = await app.DATABASE.get_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not workflow_run:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Workflow run not found {workflow_run_id}",
|
||||
)
|
||||
|
||||
# get all the child workflow runs and cancel them
|
||||
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
|
||||
parent_workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
for child_workflow_run in child_workflow_runs:
|
||||
if child_workflow_run.status not in [
|
||||
WorkflowRunStatus.running,
|
||||
WorkflowRunStatus.created,
|
||||
WorkflowRunStatus.queued,
|
||||
]:
|
||||
continue
|
||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(child_workflow_run.workflow_run_id)
|
||||
|
||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id)
|
||||
await app.WORKFLOW_SERVICE.execute_workflow_webhook(workflow_run, api_key=x_api_key)
|
||||
|
||||
|
||||
@legacy_base_router.post(
|
||||
"/workflows/runs/{workflow_run_id}/cancel",
|
||||
tags=["agent"],
|
||||
@@ -1082,30 +1113,27 @@ async def cancel_workflow_run(
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
x_api_key: Annotated[str | None, Header()] = None,
|
||||
) -> None:
|
||||
workflow_run = await app.DATABASE.get_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=current_org.organization_id,
|
||||
)
|
||||
if not workflow_run:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Workflow run not found {workflow_run_id}",
|
||||
)
|
||||
# get all the child workflow runs and cancel them
|
||||
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
|
||||
parent_workflow_run_id=workflow_run_id,
|
||||
organization_id=current_org.organization_id,
|
||||
)
|
||||
for child_workflow_run in child_workflow_runs:
|
||||
if child_workflow_run.status not in [
|
||||
WorkflowRunStatus.running,
|
||||
WorkflowRunStatus.created,
|
||||
WorkflowRunStatus.queued,
|
||||
]:
|
||||
continue
|
||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(child_workflow_run.workflow_run_id)
|
||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id)
|
||||
await app.WORKFLOW_SERVICE.execute_workflow_webhook(workflow_run, api_key=x_api_key)
|
||||
await _cancel_workflow_run(workflow_run_id, current_org.organization_id, x_api_key)
|
||||
|
||||
|
||||
@legacy_base_router.post(
|
||||
"/runs/{browser_session_id}/workflow_run/{workflow_run_id}/cancel/",
|
||||
tags=["agent"],
|
||||
openapi_extra={
|
||||
"x-fern-sdk-group-name": "agent",
|
||||
"x-fern-sdk-method-name": "cancel_workflow_run",
|
||||
},
|
||||
)
|
||||
@legacy_base_router.post("/runs/{browser_session_id}/workflow_run/{workflow_run_id}/cancel/", include_in_schema=False)
|
||||
async def cancel_persistent_browser_session_workflow_run(
|
||||
workflow_run_id: str,
|
||||
browser_session_id: str,
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
x_api_key: Annotated[str | None, Header()] = None,
|
||||
) -> None:
|
||||
await app.PERSISTENT_SESSIONS_MANAGER.release_browser_session(browser_session_id, current_org.organization_id)
|
||||
|
||||
await _cancel_workflow_run(workflow_run_id, current_org.organization_id, x_api_key)
|
||||
|
||||
|
||||
@legacy_base_router.post(
|
||||
|
||||
Reference in New Issue
Block a user