cancel run API (#2067)

This commit is contained in:
Shuchang Zheng
2025-04-01 23:48:39 -04:00
committed by GitHub
parent f317b71468
commit c664cfb5a9
3 changed files with 100 additions and 7 deletions

View File

@@ -1547,8 +1547,8 @@ class AgentDB:
async def get_workflow_runs_by_parent_workflow_run_id(
self,
organization_id: str,
parent_workflow_run_id: str,
organization_id: str | None = None,
) -> list[WorkflowRun]:
try:
async with self.Session() as session:

View File

@@ -6,7 +6,7 @@ from typing import Annotated, Any
import structlog
import yaml
from fastapi import BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, UploadFile, status
from fastapi import BackgroundTasks, Depends, Header, HTTPException, Path, Query, Request, Response, UploadFile, status
from fastapi.responses import ORJSONResponse
from skyvern import analytics
@@ -295,8 +295,8 @@ async def cancel_workflow_run(
)
# get all the child workflow runs and cancel them
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
organization_id=current_org.organization_id,
parent_workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id,
)
for child_workflow_run in child_workflow_runs:
if child_workflow_run.status not in [
@@ -456,7 +456,7 @@ async def get_runs(
include_in_schema=False,
)
async def get_run(
run_id: str,
run_id: str = Path(..., description="The id of the task run or the workflow run."),
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> RunResponse:
run_response = await run_service.get_run_response(run_id, organization_id=current_org.organization_id)
@@ -1468,15 +1468,15 @@ async def run_task(
url = run_request.url
data_extraction_goal = None
data_extraction_schema = run_request.data_extraction_schema
navigation_goal = run_request.goal
navigation_goal = run_request.prompt
navigation_payload = None
if not url:
task_generation = await task_v1_service.generate_task(
user_prompt=run_request.goal,
user_prompt=run_request.prompt,
organization=current_org,
)
url = task_generation.url
navigation_goal = task_generation.navigation_goal or run_request.goal
navigation_goal = task_generation.navigation_goal or run_request.prompt
navigation_payload = task_generation.navigation_payload
data_extraction_goal = task_generation.data_extraction_goal
data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema
@@ -1642,3 +1642,24 @@ async def run_workflow(
downloaded_files=None,
recording_url=None,
)
@base_router.post(
"/runs/{run_id}/cancel",
tags=["Agent"],
openapi_extra={
"x-fern-sdk-group-name": "agent",
"x-fern-sdk-method-name": "cancel_run",
},
description="Cancel a task or workflow run",
summary="Cancel a task or workflow run",
)
@base_router.post("/runs/{run_id}/cancel/", include_in_schema=False)
async def cancel_run(
run_id: str = Path(..., description="The id of the task run or the workflow run to cancel."),
current_org: Organization = Depends(org_auth_service.get_current_org),
x_api_key: Annotated[str | None, Header()] = None,
) -> None:
analytics.capture("skyvern-oss-agent-cancel-run")
await run_service.cancel_run(run_id, organization_id=current_org.organization_id, api_key=x_api_key)