From 4df0daa2ea19eea1ca2492a6286a0b291c669908 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Tue, 18 Feb 2025 23:21:17 +0800 Subject: [PATCH] handle workflow run cancel for child workflow runs / task v2 + observer cancel handling (#1776) --- skyvern/forge/sdk/db/client.py | 18 +++++ skyvern/forge/sdk/routes/agent_protocol.py | 13 +++ .../forge/sdk/services/observer_service.py | 80 ++++++++++++++----- 3 files changed, 93 insertions(+), 18 deletions(-) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index ec9b4037..224ba1cb 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1511,6 +1511,24 @@ class AgentDB: LOG.error("SQLAlchemyError", exc_info=True) raise + async def get_workflow_runs_by_parent_workflow_run_id( + self, + organization_id: str, + parent_workflow_run_id: str, + ) -> list[WorkflowRun]: + try: + async with self.Session() as session: + query = ( + select(WorkflowRunModel) + .filter(WorkflowRunModel.organization_id == organization_id) + .filter(WorkflowRunModel.parent_workflow_run_id == parent_workflow_run_id) + ) + workflow_runs = (await session.scalars(query)).all() + return [convert_to_workflow_run(run) for run in workflow_runs] + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + async def create_workflow_parameter( self, workflow_id: str, diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index b5a5dbf4..dff1c66b 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -339,6 +339,19 @@ async def cancel_workflow_run( status_code=status.HTTP_404_NOT_FOUND, detail=f"Workflow run not found {workflow_run_id}", ) + # get all the child workflow runs and cancel them + child_workflow_runs = await app.DATABASE.get_workflow_runs_by_parent_workflow_run_id( + organization_id=current_org.organization_id, + parent_workflow_run_id=workflow_run_id, + ) + for 
child_workflow_run in child_workflow_runs: + if child_workflow_run.status not in [ + WorkflowRunStatus.running, + WorkflowRunStatus.created, + WorkflowRunStatus.queued, + ]: + continue + await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(child_workflow_run.workflow_run_id) await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id) await app.WORKFLOW_SERVICE.execute_workflow_webhook(workflow_run, api_key=x_api_key) diff --git a/skyvern/forge/sdk/services/observer_service.py b/skyvern/forge/sdk/services/observer_service.py index 5fdf064c..8f8cdf32 100644 --- a/skyvern/forge/sdk/services/observer_service.py +++ b/skyvern/forge/sdk/services/observer_service.py @@ -163,9 +163,11 @@ async def initialize_observer_task( except Exception: LOG.error("Failed to setup cruise workflow run", exc_info=True) # fail the workflow run - await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed( + await mark_observer_task_as_failed( + observer_cruise_id=observer_task.observer_cruise_id, workflow_run_id=workflow_run.workflow_run_id, failure_reason="Skyvern failed to setup the workflow run", + organization_id=organization.organization_id, ) raise @@ -204,9 +206,11 @@ async def initialize_observer_task( except Exception: LOG.warning("Failed to update task 2.0", exc_info=True) # fail the workflow run - await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed( + await mark_observer_task_as_failed( + observer_cruise_id=observer_task.observer_cruise_id, workflow_run_id=workflow_run.workflow_run_id, failure_reason="Skyvern failed to update the task 2.0 after initializing the workflow run", + organization_id=organization.organization_id, ) raise @@ -225,14 +229,18 @@ async def run_observer_task( observer_task = await app.DATABASE.get_observer_cruise(observer_cruise_id, organization_id=organization_id) except Exception: LOG.error( - "Failed to get observer cruise", + "Failed to get observer task", observer_cruise_id=observer_cruise_id, organization_id=organization_id, 
exc_info=True, ) - return await mark_observer_task_as_failed(observer_cruise_id, organization_id=organization_id) + return await mark_observer_task_as_failed( + observer_cruise_id, + organization_id=organization_id, + failure_reason="Failed to get task v2", + ) if not observer_task: - LOG.error("Observer cruise not found", observer_cruise_id=observer_cruise_id, organization_id=organization_id) + LOG.error("Task v2 not found", observer_cruise_id=observer_cruise_id, organization_id=organization_id) raise ObserverCruiseNotFound(observer_cruise_id=observer_cruise_id) workflow, workflow_run = None, None @@ -365,6 +373,25 @@ async def run_observer_task_helper( max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS for i in range(max_iterations): + # check the status of the workflow run + workflow_run = await app.WORKFLOW_SERVICE.get_workflow_run(workflow_run_id, organization_id=organization_id) + if not workflow_run: + LOG.error("Workflow run not found", workflow_run_id=workflow_run_id) + break + + if workflow_run.status == WorkflowRunStatus.canceled: + LOG.info( + "Task v2 is canceled. Stopping task v2", + workflow_run_id=workflow_run_id, + observer_task_id=observer_cruise_id, + ) + await mark_observer_task_as_canceled( + observer_cruise_id, + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + return workflow, workflow_run, observer_task + LOG.info(f"Observer iteration i={i}", workflow_run_id=workflow_run_id, url=url) task_type = "" plan = "" @@ -472,7 +499,8 @@ async def run_observer_task_helper( # parse observer repsonse and run the next task if not task_type: LOG.error("No task type found in observer response", observer_response=observer_response) - await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed( + await mark_observer_task_as_failed( + observer_cruise_id=observer_cruise_id, workflow_run_id=workflow_run_id, failure_reason="Skyvern failed to generate a task. 
Please try again later.", ) @@ -523,14 +551,16 @@ async def run_observer_task_helper( } except Exception: LOG.exception("Failed to generate loop task") - await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed( + await mark_observer_task_as_failed( + observer_cruise_id=observer_cruise_id, workflow_run_id=workflow_run_id, failure_reason="Failed to generate the loop.", ) break else: LOG.info("Unsupported task type", task_type=task_type) - await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed( + await mark_observer_task_as_failed( + observer_cruise_id=observer_cruise_id, workflow_run_id=workflow_run_id, failure_reason=f"Unsupported task block type gets generated: {task_type}", ) @@ -580,6 +610,7 @@ async def run_observer_task_helper( # execute the extraction task workflow_run = await handle_block_result( + observer_cruise_id, block, block_result, workflow, @@ -680,6 +711,7 @@ async def run_observer_task_helper( async def handle_block_result( + observer_cruise_id: str, block: BlockTypeVar, block_result: BlockResult, workflow: Workflow, @@ -697,7 +729,11 @@ async def handle_block_result( block_type_var=block.block_type, block_label=block.label, ) - await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id=workflow_run.workflow_run_id) + await mark_observer_task_as_canceled( + observer_cruise_id=observer_cruise_id, + workflow_run_id=workflow_run_id, + organization_id=workflow_run.organization_id, + ) elif block_result.status == BlockStatus.failed: LOG.error( @@ -826,11 +862,7 @@ async def _generate_loop_task( "Failed to execute the extraction block for the loop task", extraction_block_result=extraction_block_result, ) - # TODO: fail the workflow run - await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed( - workflow_run_id=workflow_run_id, - failure_reason="Failed to extract loop values for the loop. 
Please try again later.", - ) + # workflow run and observer task status update is handled in the upper caller layer raise Exception("extraction_block failed") # validate output parameter try: @@ -848,10 +880,6 @@ async def _generate_loop_task( "Failed to validate the output parameter of the extraction block for the loop task", extraction_block_result=extraction_block_result, ) - await app.WORKFLOW_SERVICE.mark_workflow_run_as_failed( - workflow_run_id=workflow_run_id, - failure_reason="Invalid output parameter of the extraction block for the loop. Please try again later.", - ) raise # update the observer thought @@ -1207,6 +1235,22 @@ async def mark_observer_task_as_completed( return observer_task +async def mark_observer_task_as_canceled( + observer_cruise_id: str, + workflow_run_id: str | None = None, + organization_id: str | None = None, +) -> ObserverTask: + observer_task = await app.DATABASE.update_observer_cruise( + observer_cruise_id, + organization_id=organization_id, + status=ObserverTaskStatus.canceled, + ) + if workflow_run_id: + await app.WORKFLOW_SERVICE.mark_workflow_run_as_canceled(workflow_run_id) + await send_observer_task_webhook(observer_task) + return observer_task + + def _get_extracted_data_from_block_result( block_result: BlockResult, task_type: str,