cancel run API (#2067)
This commit is contained in:
@@ -1547,8 +1547,8 @@ class AgentDB:
|
|||||||
|
|
||||||
async def get_workflow_runs_by_parent_workflow_run_id(
|
async def get_workflow_runs_by_parent_workflow_run_id(
|
||||||
self,
|
self,
|
||||||
organization_id: str,
|
|
||||||
parent_workflow_run_id: str,
|
parent_workflow_run_id: str,
|
||||||
|
organization_id: str | None = None,
|
||||||
) -> list[WorkflowRun]:
|
) -> list[WorkflowRun]:
|
||||||
try:
|
try:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Annotated, Any
|
|||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
import yaml
|
import yaml
|
||||||
from fastapi import BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, UploadFile, status
|
from fastapi import BackgroundTasks, Depends, Header, HTTPException, Path, Query, Request, Response, UploadFile, status
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
|
|
||||||
from skyvern import analytics
|
from skyvern import analytics
|
||||||
@@ -295,8 +295,8 @@ async def cancel_workflow_run(
|
|||||||
)
|
)
|
||||||
# get all the child workflow runs and cancel them
|
# get all the child workflow runs and cancel them
|
||||||
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
|
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
|
||||||
organization_id=current_org.organization_id,
|
|
||||||
parent_workflow_run_id=workflow_run_id,
|
parent_workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=current_org.organization_id,
|
||||||
)
|
)
|
||||||
for child_workflow_run in child_workflow_runs:
|
for child_workflow_run in child_workflow_runs:
|
||||||
if child_workflow_run.status not in [
|
if child_workflow_run.status not in [
|
||||||
@@ -456,7 +456,7 @@ async def get_runs(
|
|||||||
include_in_schema=False,
|
include_in_schema=False,
|
||||||
)
|
)
|
||||||
async def get_run(
|
async def get_run(
|
||||||
run_id: str,
|
run_id: str = Path(..., description="The id of the task run or the workflow run."),
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> RunResponse:
|
) -> RunResponse:
|
||||||
run_response = await run_service.get_run_response(run_id, organization_id=current_org.organization_id)
|
run_response = await run_service.get_run_response(run_id, organization_id=current_org.organization_id)
|
||||||
@@ -1468,15 +1468,15 @@ async def run_task(
|
|||||||
url = run_request.url
|
url = run_request.url
|
||||||
data_extraction_goal = None
|
data_extraction_goal = None
|
||||||
data_extraction_schema = run_request.data_extraction_schema
|
data_extraction_schema = run_request.data_extraction_schema
|
||||||
navigation_goal = run_request.goal
|
navigation_goal = run_request.prompt
|
||||||
navigation_payload = None
|
navigation_payload = None
|
||||||
if not url:
|
if not url:
|
||||||
task_generation = await task_v1_service.generate_task(
|
task_generation = await task_v1_service.generate_task(
|
||||||
user_prompt=run_request.goal,
|
user_prompt=run_request.prompt,
|
||||||
organization=current_org,
|
organization=current_org,
|
||||||
)
|
)
|
||||||
url = task_generation.url
|
url = task_generation.url
|
||||||
navigation_goal = task_generation.navigation_goal or run_request.goal
|
navigation_goal = task_generation.navigation_goal or run_request.prompt
|
||||||
navigation_payload = task_generation.navigation_payload
|
navigation_payload = task_generation.navigation_payload
|
||||||
data_extraction_goal = task_generation.data_extraction_goal
|
data_extraction_goal = task_generation.data_extraction_goal
|
||||||
data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema
|
data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema
|
||||||
@@ -1642,3 +1642,24 @@ async def run_workflow(
|
|||||||
downloaded_files=None,
|
downloaded_files=None,
|
||||||
recording_url=None,
|
recording_url=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@base_router.post(
|
||||||
|
"/runs/{run_id}/cancel",
|
||||||
|
tags=["Agent"],
|
||||||
|
openapi_extra={
|
||||||
|
"x-fern-sdk-group-name": "agent",
|
||||||
|
"x-fern-sdk-method-name": "cancel_run",
|
||||||
|
},
|
||||||
|
description="Cancel a task or workflow run",
|
||||||
|
summary="Cancel a task or workflow run",
|
||||||
|
)
|
||||||
|
@base_router.post("/runs/{run_id}/cancel/", include_in_schema=False)
|
||||||
|
async def cancel_run(
|
||||||
|
run_id: str = Path(..., description="The id of the task run or the workflow run to cancel."),
|
||||||
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
|
x_api_key: Annotated[str | None, Header()] = None,
|
||||||
|
) -> None:
|
||||||
|
analytics.capture("skyvern-oss-agent-cancel-run")
|
||||||
|
|
||||||
|
await run_service.cancel_run(run_id, organization_id=current_org.organization_id, api_key=x_api_key)
|
||||||
|
|||||||
@@ -1,5 +1,11 @@
|
|||||||
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
from skyvern.exceptions import TaskNotFound, WorkflowRunNotFound
|
||||||
from skyvern.forge import app
|
from skyvern.forge import app
|
||||||
|
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
||||||
|
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
||||||
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
|
from skyvern.schemas.runs import RunEngine, RunResponse, RunType, TaskRunRequest, TaskRunResponse
|
||||||
|
from skyvern.services import task_v2_service
|
||||||
|
|
||||||
|
|
||||||
async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
|
async def get_run_response(run_id: str, organization_id: str | None = None) -> RunResponse | None:
|
||||||
@@ -72,3 +78,69 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
|
|||||||
# modified_at=run.modified_at,
|
# modified_at=run.modified_at,
|
||||||
# )
|
# )
|
||||||
raise ValueError(f"Invalid task run type: {run.task_run_type}")
|
raise ValueError(f"Invalid task run type: {run.task_run_type}")
|
||||||
|
|
||||||
|
|
||||||
|
async def cancel_task_v1(task_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
|
||||||
|
task = await app.DATABASE.get_task(task_id, organization_id=organization_id)
|
||||||
|
if not task:
|
||||||
|
raise TaskNotFound(task_id=task_id)
|
||||||
|
task = await app.agent.update_task(task, status=TaskStatus.canceled)
|
||||||
|
latest_step = await app.DATABASE.get_latest_step(task_id, organization_id=organization_id)
|
||||||
|
await app.agent.execute_task_webhook(task=task, last_step=latest_step, api_key=api_key)
|
||||||
|
|
||||||
|
|
||||||
|
async def cancel_task_v2(task_id: str, organization_id: str | None = None) -> None:
|
||||||
|
task_v2 = await app.DATABASE.get_task_v2(task_id, organization_id=organization_id)
|
||||||
|
if not task_v2:
|
||||||
|
raise TaskNotFound(task_id=task_id)
|
||||||
|
await task_v2_service.mark_task_v2_as_canceled(
|
||||||
|
task_v2_id=task_id, workflow_run_id=task_v2.workflow_run_id, organization_id=organization_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def cancel_workflow_run(
|
||||||
|
workflow_run_id: str, organization_id: str | None = None, api_key: str | None = None
|
||||||
|
) -> None:
|
||||||
|
workflow_run = await app.DATABASE.get_workflow_run(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
if not workflow_run:
|
||||||
|
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)
|
||||||
|
|
||||||
|
# get all the child workflow runs and cancel them
|
||||||
|
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
|
||||||
|
parent_workflow_run_id=workflow_run_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
for child_workflow_run in child_workflow_runs:
|
||||||
|
if child_workflow_run.status not in [
|
||||||
|
WorkflowRunStatus.running,
|
||||||
|
WorkflowRunStatus.created,
|
||||||
|
WorkflowRunStatus.queued,
|
||||||
|
]:
|
||||||
|
continue
|
||||||
|
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(child_workflow_run.workflow_run_id)
|
||||||
|
await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id)
|
||||||
|
await app.WORKFLOW_SERVICE.execute_workflow_webhook(workflow_run, api_key=api_key)
|
||||||
|
|
||||||
|
|
||||||
|
async def cancel_run(run_id: str, organization_id: str | None = None, api_key: str | None = None) -> None:
|
||||||
|
run = await app.DATABASE.get_run(run_id, organization_id=organization_id)
|
||||||
|
if not run:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Run not found {run_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if run.task_run_type == RunType.task_v1:
|
||||||
|
await cancel_task_v1(run_id, organization_id=organization_id, api_key=api_key)
|
||||||
|
elif run.task_run_type == RunType.task_v2:
|
||||||
|
await cancel_task_v2(run_id, organization_id=organization_id)
|
||||||
|
elif run.task_run_type == RunType.workflow_run:
|
||||||
|
await cancel_workflow_run(run_id, organization_id=organization_id, api_key=api_key)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Invalid run type to cancel: {run.task_run_type}",
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user