cancel run API (#2067)
This commit is contained in:
@@ -1547,8 +1547,8 @@ class AgentDB:
|
||||
|
||||
async def get_workflow_runs_by_parent_workflow_run_id(
|
||||
self,
|
||||
organization_id: str,
|
||||
parent_workflow_run_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> list[WorkflowRun]:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Annotated, Any
|
||||
|
||||
import structlog
|
||||
import yaml
|
||||
from fastapi import BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, UploadFile, status
|
||||
from fastapi import BackgroundTasks, Depends, Header, HTTPException, Path, Query, Request, Response, UploadFile, status
|
||||
from fastapi.responses import ORJSONResponse
|
||||
|
||||
from skyvern import analytics
|
||||
@@ -295,8 +295,8 @@ async def cancel_workflow_run(
|
||||
)
|
||||
# get all the child workflow runs and cancel them
|
||||
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
|
||||
organization_id=current_org.organization_id,
|
||||
parent_workflow_run_id=workflow_run_id,
|
||||
organization_id=current_org.organization_id,
|
||||
)
|
||||
for child_workflow_run in child_workflow_runs:
|
||||
if child_workflow_run.status not in [
|
||||
@@ -456,7 +456,7 @@ async def get_runs(
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def get_run(
|
||||
run_id: str,
|
||||
run_id: str = Path(..., description="The id of the task run or the workflow run."),
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> RunResponse:
|
||||
run_response = await run_service.get_run_response(run_id, organization_id=current_org.organization_id)
|
||||
@@ -1468,15 +1468,15 @@ async def run_task(
|
||||
url = run_request.url
|
||||
data_extraction_goal = None
|
||||
data_extraction_schema = run_request.data_extraction_schema
|
||||
navigation_goal = run_request.goal
|
||||
navigation_goal = run_request.prompt
|
||||
navigation_payload = None
|
||||
if not url:
|
||||
task_generation = await task_v1_service.generate_task(
|
||||
user_prompt=run_request.goal,
|
||||
user_prompt=run_request.prompt,
|
||||
organization=current_org,
|
||||
)
|
||||
url = task_generation.url
|
||||
navigation_goal = task_generation.navigation_goal or run_request.goal
|
||||
navigation_goal = task_generation.navigation_goal or run_request.prompt
|
||||
navigation_payload = task_generation.navigation_payload
|
||||
data_extraction_goal = task_generation.data_extraction_goal
|
||||
data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema
|
||||
@@ -1642,3 +1642,24 @@ async def run_workflow(
|
||||
downloaded_files=None,
|
||||
recording_url=None,
|
||||
)
|
||||
|
||||
|
||||
@base_router.post(
|
||||
"/runs/{run_id}/cancel",
|
||||
tags=["Agent"],
|
||||
openapi_extra={
|
||||
"x-fern-sdk-group-name": "agent",
|
||||
"x-fern-sdk-method-name": "cancel_run",
|
||||
},
|
||||
description="Cancel a task or workflow run",
|
||||
summary="Cancel a task or workflow run",
|
||||
)
|
||||
@base_router.post("/runs/{run_id}/cancel/", include_in_schema=False)
|
||||
async def cancel_run(
|
||||
run_id: str = Path(..., description="The id of the task run or the workflow run to cancel."),
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
x_api_key: Annotated[str | None, Header()] = None,
|
||||
) -> None:
|
||||
analytics.capture("skyvern-oss-agent-cancel-run")
|
||||
|
||||
await run_service.cancel_run(run_id, organization_id=current_org.organization_id, api_key=x_api_key)
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from skyvern.exceptions import TaskNotFound, WorkflowRunNotFound
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
|
||||
from skyvern.services import task_v2_service
|
||||
|
||||
|
||||
async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
|
||||
@@ -72,3 +78,69 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
|
||||
# modified_at=run.modified_at,
|
||||
# )
|
||||
raise ValueError(f"Invalid task run type: {run.task_run_type}")
|
||||
|
||||
|
||||
async def cancel_task_v1(task_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
|
||||
task = await app.DATABASE.get_task(task_id, organization_id=organization_id)
|
||||
if not task:
|
||||
raise TaskNotFound(task_id=task_id)
|
||||
task = await app.agent.update_task(task, status=TaskStatus.canceled)
|
||||
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=organization_id)
|
||||
await app.agent.execute_task_webhook(task=task, last_step=latest_step, api_key=api_key)
|
||||
|
||||
|
||||
async def cancel_task_v2(task_id: str, organization_id: str | None = None) -> None:
|
||||
task_v2 = await app.DATABASE.get_task_v2(task_id, organization_id=organization_id)
|
||||
if not task_v2:
|
||||
raise TaskNotFound(task_id=task_id)
|
||||
await task_v2_service.mark_task_v2_as_canceled(
|
||||
task_v2_id=task_id, workflow_run_id=task_v2.workflow_run_id, organization_id=organization_id
|
||||
)
|
||||
|
||||
|
||||
async def cancel_workflow_run(
|
||||
workflow_run_id: str, organization_id: str | None = None, api_key: str | None = None
|
||||
) -> None:
|
||||
workflow_run = await app.DATABASE.get_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
if not workflow_run:
|
||||
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)
|
||||
|
||||
# get all the child workflow runs and cancel them
|
||||
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
|
||||
parent_workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
for child_workflow_run in child_workflow_runs:
|
||||
if child_workflow_run.status not in [
|
||||
WorkflowRunStatus.running,
|
||||
WorkflowRunStatus.created,
|
||||
WorkflowRunStatus.queued,
|
||||
]:
|
||||
continue
|
||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(child_workflow_run.workflow_run_id)
|
||||
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id)
|
||||
await app.WORKFLOW_SERVICE.execute_workflow_webhook(workflow_run, api_key=api_key)
|
||||
|
||||
|
||||
async def cancel_run(run_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
|
||||
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
|
||||
if not run:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Run not found {run_id}",
|
||||
)
|
||||
|
||||
if run.task_run_type == RunType.task_v1:
|
||||
await cancel_task_v1(run_id, organization_id=organization_id, api_key=api_key)
|
||||
elif run.task_run_type == RunType.task_v2:
|
||||
await cancel_task_v2(run_id, organization_id=organization_id)
|
||||
elif run.task_run_type == RunType.workflow_run:
|
||||
await cancel_workflow_run(run_id, organization_id=organization_id, api_key=api_key)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid run type to cancel: {run.task_run_type}",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user