From 28d37545bcf197fb1a27a831f45e115198c08a58 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Thu, 14 Nov 2024 01:32:53 -0800 Subject: [PATCH] Implement cancel workflow run endpoint (#1188) --- skyvern/forge/agent.py | 23 ++++++++++- skyvern/forge/sdk/routes/agent_protocol.py | 9 +++++ skyvern/forge/sdk/workflow/models/block.py | 1 - skyvern/forge/sdk/workflow/service.py | 47 +++++++++++++++++++++- 4 files changed, 76 insertions(+), 4 deletions(-) diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 0d5912ad..924cd6b4 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -44,7 +44,7 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskStatus from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext from skyvern.forge.sdk.workflow.models.block import TaskBlock -from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun +from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus from skyvern.webeye.actions.actions import ( Action, ActionType, @@ -220,10 +220,29 @@ class ForgeAgent: task: Task, step: Step, api_key: str | None = None, - workflow_run: WorkflowRun | None = None, close_browser_on_completion: bool = True, task_block: TaskBlock | None = None, ) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]: + workflow_run: WorkflowRun | None = None + if task.workflow_run_id: + workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=task.workflow_run_id) + if workflow_run and workflow_run.status == WorkflowRunStatus.canceled: + LOG.info( + "Workflow run is canceled, stopping execution inside task", + workflow_run_id=workflow_run.workflow_run_id, + step_id=step.step_id, + ) + step = await self.update_step( + step, + status=StepStatus.canceled, + is_last=True, + ) + task = await self.update_task( + task, + status=TaskStatus.canceled, + ) + return step, None, None + refreshed_task = await app.DATABASE.get_task(task_id=task.task_id, organization_id=organization.organization_id) if refreshed_task: task = refreshed_task diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 366c7791..d2b42b65 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -345,6 +345,15 @@ async def cancel_task( await app.agent.update_task(task_obj, status=TaskStatus.canceled) +@base_router.post("/workflows/runs/{workflow_run_id}/cancel") +@base_router.post("/workflows/runs/{workflow_run_id}/cancel/", include_in_schema=False) +async def cancel_workflow_run( + workflow_run_id: str, + current_org: Organization = Depends(org_auth_service.get_current_org), +) -> None: + await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id) + + @base_router.post( "/tasks/{task_id}/retry_webhook", tags=["agent"], diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 8c326ba7..90a3a39a 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -335,7 +335,6 @@ class TaskBlock(Block): organization=organization, task=task, step=step, - workflow_run=workflow_run, task_block=self, ) except Exception as e: diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 07ef3f79..bfbc2c39 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -188,6 +188,24 @@ class WorkflowService: for block_idx, block in enumerate(blocks): is_last_block = block_idx + 1 == blocks_cnt try: + refreshed_workflow_run = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run.workflow_run_id + ) + if refreshed_workflow_run and refreshed_workflow_run.status == WorkflowRunStatus.canceled: + LOG.info( + "Workflow run is canceled, stopping execution inside workflow execution loop", + workflow_run_id=workflow_run.workflow_run_id, + block_idx=block_idx, + block_type=block.block_type, + block_label=block.label, + ) + await self.clean_up_workflow( + workflow=workflow, + workflow_run=workflow_run, + api_key=api_key, + need_call_webhook=False, + ) + return workflow_run parameters = block.get_all_parameters(workflow_run_id) await app.WORKFLOW_CONTEXT_MANAGER.register_block_parameters_for_workflow_run( workflow_run_id, parameters, organization @@ -197,6 +215,8 @@ class WorkflowService: block_type=block.block_type, workflow_run_id=workflow_run.workflow_run_id, block_idx=block_idx, + block_type_var=block.block_type, + block_label=block.label, ) block_result = await block.execute_safe(workflow_run_id=workflow_run_id) if block_result.status == BlockStatus.canceled: @@ -206,6 +226,8 @@ class WorkflowService: workflow_run_id=workflow_run.workflow_run_id, block_idx=block_idx, block_result=block_result, + block_type_var=block.block_type, + block_label=block.label, ) await self.mark_workflow_run_as_canceled(workflow_run_id=workflow_run.workflow_run_id) # We're not sending a webhook here because the workflow run is manually marked as canceled. @@ -223,6 +245,8 @@ class WorkflowService: workflow_run_id=workflow_run.workflow_run_id, block_idx=block_idx, block_result=block_result, + block_type_var=block.block_type, + block_label=block.label, ) if block.continue_on_failure and not is_last_block: LOG.warning( @@ -232,6 +256,8 @@ class WorkflowService: block_idx=block_idx, block_result=block_result, continue_on_failure=block.continue_on_failure, + block_type_var=block.block_type, + block_label=block.label, ) else: await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id) @@ -248,6 +274,8 @@ class WorkflowService: workflow_run_id=workflow_run.workflow_run_id, block_idx=block_idx, block_result=block_result, + block_type_var=block.block_type, + block_label=block.label, ) if block.continue_on_failure and not is_last_block: LOG.warning( @@ -257,6 +285,8 @@ class WorkflowService: block_idx=block_idx, block_result=block_result, continue_on_failure=block.continue_on_failure, + block_type_var=block.block_type, + block_label=block.label, ) else: await self.mark_workflow_run_as_terminated(workflow_run_id=workflow_run.workflow_run_id) @@ -270,12 +300,27 @@ class WorkflowService: LOG.exception( f"Error while executing workflow run {workflow_run.workflow_run_id}", workflow_run_id=workflow_run.workflow_run_id, + block_idx=block_idx, + block_type=block.block_type, + block_label=block.label, ) await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id) await self.clean_up_workflow(workflow=workflow, workflow_run=workflow_run, api_key=api_key) return workflow_run - await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id) + refreshed_workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run.workflow_run_id) + if refreshed_workflow_run and refreshed_workflow_run.status not in ( + WorkflowRunStatus.canceled, + WorkflowRunStatus.failed, + WorkflowRunStatus.terminated, + ): + await self.mark_workflow_run_as_completed(workflow_run_id=workflow_run.workflow_run_id) + else: + LOG.info( + "Workflow run is already canceled, failed, or terminated, not marking as completed", + workflow_run_id=workflow_run.workflow_run_id, + workflow_run_status=refreshed_workflow_run.status if refreshed_workflow_run else None, + ) await self.clean_up_workflow(workflow=workflow, workflow_run=workflow_run, api_key=api_key) return workflow_run