cancel run API (#2067)
This commit is contained in:
@@ -1547,8 +1547,8 @@ class AgentDB:
|
||||
|
||||
async def get_workflow_runs_by_parent_workflow_run_id(
|
||||
self,
|
||||
organization_id: str,
|
||||
parent_workflow_run_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> list[WorkflowRun]:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Annotated, Any
|
||||
|
||||
import structlog
|
||||
import yaml
|
||||
from fastapi import BackgroundTasks, Depends, Header, HTTPException, Query, Request, Response, UploadFile, status
|
||||
from fastapi import BackgroundTasks, Depends, Header, HTTPException, Path, Query, Request, Response, UploadFile, status
|
||||
from fastapi.responses import ORJSONResponse
|
||||
|
||||
from skyvern import analytics
|
||||
@@ -295,8 +295,8 @@ async def cancel_workflow_run(
|
||||
)
|
||||
# get all the child workflow runs and cancel them
|
||||
child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id(
|
||||
organization_id=current_org.organization_id,
|
||||
parent_workflow_run_id=workflow_run_id,
|
||||
organization_id=current_org.organization_id,
|
||||
)
|
||||
for child_workflow_run in child_workflow_runs:
|
||||
if child_workflow_run.status not in [
|
||||
@@ -456,7 +456,7 @@ async def get_runs(
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def get_run(
|
||||
run_id: str,
|
||||
run_id: str = Path(..., description="The id of the task run or the workflow run."),
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
) -> RunResponse:
|
||||
run_response = await run_service.get_run_response(run_id, organization_id=current_org.organization_id)
|
||||
@@ -1468,15 +1468,15 @@ async def run_task(
|
||||
url = run_request.url
|
||||
data_extraction_goal = None
|
||||
data_extraction_schema = run_request.data_extraction_schema
|
||||
navigation_goal = run_request.goal
|
||||
navigation_goal = run_request.prompt
|
||||
navigation_payload = None
|
||||
if not url:
|
||||
task_generation = await task_v1_service.generate_task(
|
||||
user_prompt=run_request.goal,
|
||||
user_prompt=run_request.prompt,
|
||||
organization=current_org,
|
||||
)
|
||||
url = task_generation.url
|
||||
navigation_goal = task_generation.navigation_goal or run_request.goal
|
||||
navigation_goal = task_generation.navigation_goal or run_request.prompt
|
||||
navigation_payload = task_generation.navigation_payload
|
||||
data_extraction_goal = task_generation.data_extraction_goal
|
||||
data_extraction_schema = data_extraction_schema or task_generation.extracted_information_schema
|
||||
@@ -1642,3 +1642,24 @@ async def run_workflow(
|
||||
downloaded_files=None,
|
||||
recording_url=None,
|
||||
)
|
||||
|
||||
|
||||
@base_router.post(
|
||||
"/runs/{run_id}/cancel",
|
||||
tags=["Agent"],
|
||||
openapi_extra={
|
||||
"x-fern-sdk-group-name": "agent",
|
||||
"x-fern-sdk-method-name": "cancel_run",
|
||||
},
|
||||
description="Cancel a task or workflow run",
|
||||
summary="Cancel a task or workflow run",
|
||||
)
|
||||
@base_router.post("/runs/{run_id}/cancel/", include_in_schema=False)
|
||||
async def cancel_run(
|
||||
run_id: str = Path(..., description="The id of the task run or the workflow run to cancel."),
|
||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||
x_api_key: Annotated[str | None, Header()] = None,
|
||||
) -> None:
|
||||
analytics.capture("skyvern-oss-agent-cancel-run")
|
||||
|
||||
await run_service.cancel_run(run_id, organization_id=current_org.organization_id, api_key=x_api_key)
|
||||
|
||||
Reference in New Issue
Block a user