From 9821e8f95eb3a9e1f68ffb3b0538ecf8cb977754 Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Tue, 8 Oct 2024 23:09:41 -0700 Subject: [PATCH] Utilize all Workflow Run statuses (#935) --- skyvern/forge/sdk/workflow/models/block.py | 89 ++++++++++++--- skyvern/forge/sdk/workflow/models/workflow.py | 12 +++ skyvern/forge/sdk/workflow/service.py | 102 +++++++++--------- 3 files changed, 136 insertions(+), 67 deletions(-) diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 038ac805..f94ff3aa 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -68,11 +68,19 @@ class BlockType(StrEnum): FILE_URL_PARSER = "file_url_parser" +class BlockStatus(StrEnum): + completed = "completed" + failed = "failed" + terminated = "terminated" + canceled = "canceled" + + @dataclass(frozen=True) class BlockResult: success: bool output_parameter: OutputParameter output_parameter_value: dict[str, Any] | list | str | None = None + status: BlockStatus | None = None class Block(BaseModel, abc.ABC): @@ -107,11 +115,13 @@ class Block(BaseModel, abc.ABC): self, success: bool, output_parameter_value: dict[str, Any] | list | str | None = None, + status: BlockStatus | None = None, ) -> BlockResult: return BlockResult( success=success, output_parameter=self.output_parameter, output_parameter_value=output_parameter_value, + status=status, ) @classmethod @@ -144,7 +154,7 @@ class Block(BaseModel, abc.ABC): workflow_run_context = self.get_workflow_run_context(workflow_run_id) if not workflow_run_context.has_value(self.output_parameter.key): await self.record_output_parameter_value(workflow_run_context, workflow_run_id) - return self.build_block_result(success=False) + return self.build_block_result(success=False, status=BlockStatus.failed) @abc.abstractmethod def get_all_parameters( @@ -333,6 +343,12 @@ class TaskBlock(Block): if not updated_task.status.is_final(): raise UnexpectedTaskStatus(task_id=updated_task.task_id, status=updated_task.status) + block_status_mapping = { + TaskStatus.completed: BlockStatus.completed, + TaskStatus.terminated: BlockStatus.terminated, + TaskStatus.failed: BlockStatus.failed, + TaskStatus.canceled: BlockStatus.canceled, + } if updated_task.status == TaskStatus.completed or updated_task.status == TaskStatus.terminated: LOG.info( "Task completed", @@ -346,7 +362,23 @@ class TaskBlock(Block): task_output = TaskOutput.from_task(updated_task) output_parameter_value = task_output.model_dump() await self.record_output_parameter_value(workflow_run_context, workflow_run_id, output_parameter_value) - return self.build_block_result(success=success, output_parameter_value=output_parameter_value) + return self.build_block_result( + success=success, + output_parameter_value=output_parameter_value, + status=block_status_mapping[updated_task.status], + ) + elif updated_task.status == TaskStatus.canceled: + LOG.info( + "Task canceled, cancelling block", + task_id=updated_task.task_id, + task_status=updated_task.status, + workflow_run_id=workflow_run_id, + workflow_id=workflow.workflow_id, + organization_id=workflow.organization_id, + ) + return self.build_block_result( + success=False, output_parameter_value=None, status=block_status_mapping[updated_task.status] + ) else: current_retry += 1 will_retry = current_retry <= self.max_retries @@ -368,10 +400,14 @@ class TaskBlock(Block): await self.record_output_parameter_value( workflow_run_context, workflow_run_id, output_parameter_value ) - return self.build_block_result(success=False, output_parameter_value=output_parameter_value) + return self.build_block_result( + success=False, + output_parameter_value=output_parameter_value, + status=block_status_mapping[updated_task.status], + ) await self.record_output_parameter_value(workflow_run_context, workflow_run_id) - return self.build_block_result(success=False) + return self.build_block_result(success=False, status=BlockStatus.failed) class ForLoopBlock(Block): @@ -455,7 +491,7 @@ class ForLoopBlock(Block): return [parameter_value] async def execute(self, workflow_run_id: str, **kwargs: dict) -> BlockResult: - outputs_with_loop_values = [] + outputs_with_loop_values: list[list[dict[str, Any]]] = [] success = False workflow_run_context = self.get_workflow_run_context(workflow_run_id) loop_over_values = self.get_loop_over_parameter_values(workflow_run_context) @@ -467,13 +503,13 @@ class ForLoopBlock(Block): ) if not loop_over_values or len(loop_over_values) == 0: LOG.info( - "No loop_over values found", + "No loop_over values found, terminating block", block_type=self.block_type, workflow_run_id=workflow_run_id, num_loop_over_values=len(loop_over_values), ) await self.record_output_parameter_value(workflow_run_context, workflow_run_id, []) - return self.build_block_result(success=False) + return self.build_block_result(success=False, status=BlockStatus.terminated) for loop_idx, loop_over_value in enumerate(loop_over_values): context_parameters_with_value = self.get_loop_block_context_parameters(workflow_run_id, loop_over_value) for context_parameter in context_parameters_with_value: @@ -483,6 +519,21 @@ class ForLoopBlock(Block): original_loop_block = loop_block loop_block = loop_block.copy() block_output = await loop_block.execute_safe(workflow_run_id=workflow_run_id) + if block_output.status == BlockStatus.canceled: + LOG.info( + f"ForLoopBlock: Block with type {loop_block.block_type} at index {block_idx} was canceled for workflow run {workflow_run_id}, canceling for loop", + block_type=loop_block.block_type, + workflow_run_id=workflow_run_id, + block_idx=block_idx, + block_result=block_output, + ) + await self.record_output_parameter_value( + workflow_run_context, workflow_run_id, outputs_with_loop_values + ) + return self.build_block_result( + success=False, output_parameter_value=outputs_with_loop_values, status=BlockStatus.canceled + ) + loop_block = original_loop_block block_outputs.append(block_output) if not block_output.success and not loop_block.continue_on_failure: @@ -518,8 +569,16 @@ class ForLoopBlock(Block): ) break + is_any_block_terminated = any([block_output.status == BlockStatus.terminated for block_output in block_outputs]) + for_loop_block_status = BlockStatus.completed + if is_any_block_terminated: + for_loop_block_status = BlockStatus.terminated + elif not success: + for_loop_block_status = BlockStatus.failed await self.record_output_parameter_value(workflow_run_context, workflow_run_id, outputs_with_loop_values) - return self.build_block_result(success=success, output_parameter_value=outputs_with_loop_values) + return self.build_block_result( + success=success, output_parameter_value=outputs_with_loop_values, status=for_loop_block_status + ) class CodeBlock(Block): @@ -578,7 +637,7 @@ async def user_code(): result = {"result": result_container.get("result")} await self.record_output_parameter_value(workflow_run_context, workflow_run_id, result) - return self.build_block_result(success=True, output_parameter_value=result) + return self.build_block_result(success=True, output_parameter_value=result, status=BlockStatus.completed) class TextPromptBlock(Block): @@ -640,7 +699,7 @@ class TextPromptBlock(Block): response = await self.send_prompt(self.prompt, parameter_values) await self.record_output_parameter_value(workflow_run_context, workflow_run_id, response) - return self.build_block_result(success=True, output_parameter_value=response) + return self.build_block_result(success=True, output_parameter_value=response, status=BlockStatus.completed) class DownloadToS3Block(Block): @@ -697,7 +756,7 @@ class DownloadToS3Block(Block): LOG.info("DownloadToS3Block: File downloaded and uploaded to S3", uri=uri) await self.record_output_parameter_value(workflow_run_context, workflow_run_id, uri) - return self.build_block_result(success=True, output_parameter_value=uri) + return self.build_block_result(success=True, output_parameter_value=uri, status=BlockStatus.completed) class UploadToS3Block(Block): @@ -771,7 +830,7 @@ class UploadToS3Block(Block): LOG.info("UploadToS3Block: File(s) uploaded to S3", file_path=self.path) await self.record_output_parameter_value(workflow_run_context, workflow_run_id, s3_uris) - return self.build_block_result(success=True, output_parameter_value=s3_uris) + return self.build_block_result(success=True, output_parameter_value=s3_uris, status=BlockStatus.completed) class SendEmailBlock(Block): @@ -1039,14 +1098,14 @@ class SendEmailBlock(Block): LOG.error("SendEmailBlock: Failed to send email", exc_info=True) result_dict = {"success": False, "error": str(e)} await self.record_output_parameter_value(workflow_run_context, workflow_run_id, result_dict) - return self.build_block_result(success=False, output_parameter_value=result_dict) + return self.build_block_result(success=False, output_parameter_value=result_dict, status=BlockStatus.failed) finally: if smtp_host: smtp_host.quit() result_dict = {"success": True} await self.record_output_parameter_value(workflow_run_context, workflow_run_id, result_dict) - return self.build_block_result(success=True, output_parameter_value=result_dict) + return self.build_block_result(success=True, output_parameter_value=result_dict, status=BlockStatus.completed) class FileType(StrEnum): @@ -1109,7 +1168,7 @@ class FileParserBlock(Block): parsed_data.append(row) # Record the parsed data await self.record_output_parameter_value(workflow_run_context, workflow_run_id, parsed_data) - return self.build_block_result(success=True, output_parameter_value=parsed_data) + return self.build_block_result(success=True, output_parameter_value=parsed_data, status=BlockStatus.completed) BlockSubclasses = Union[ diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index 06c77970..2b13a534 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -62,11 +62,23 @@ class Workflow(BaseModel): class WorkflowRunStatus(StrEnum): created = "created" + queued = "queued" running = "running" failed = "failed" terminated = "terminated" + canceled = "canceled" + timed_out = "timed_out" completed = "completed" + def is_final(self) -> bool: + return self in [ + WorkflowRunStatus.failed, + WorkflowRunStatus.terminated, + WorkflowRunStatus.canceled, + WorkflowRunStatus.timed_out, + WorkflowRunStatus.completed, + ] + class WorkflowRun(BaseModel): workflow_run_id: str diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 8fcf578d..d58d4dfd 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -1,5 +1,4 @@ import json -from collections import Counter from datetime import datetime import requests @@ -13,7 +12,7 @@ from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.models import Organization, Step -from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus +from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task from skyvern.forge.sdk.workflow.exceptions import ( ContextParameterSourceNotDefined, InvalidWorkflowDefinition, @@ -21,6 +20,7 @@ from skyvern.forge.sdk.workflow.exceptions import ( WorkflowDefinitionHasReservedParameterKeys, ) from skyvern.forge.sdk.workflow.models.block import ( + BlockStatus, BlockType, BlockTypeVar, CodeBlock, @@ -191,7 +191,18 @@ class WorkflowService: block_idx=block_idx, ) block_result = await block.execute_safe(workflow_run_id=workflow_run_id) - if not block_result.success: + if block_result.status == BlockStatus.canceled: + LOG.info( + f"Block with type {block.block_type} at index {block_idx} was canceled for workflow run {workflow_run_id}, cancelling workflow run", + block_type=block.block_type, + workflow_run_id=workflow_run.workflow_run_id, + block_idx=block_idx, + block_result=block_result, + ) + await self.mark_workflow_run_as_canceled(workflow_run_id=workflow_run.workflow_run_id) + # We're not sending a webhook here because the workflow run is manually marked as canceled. + return workflow_run + elif block_result.status == BlockStatus.failed: LOG.error( f"Block with type {block.block_type} at index {block_idx} failed for workflow run {workflow_run_id}", block_type=block.block_type, @@ -216,7 +227,31 @@ class WorkflowService: api_key=api_key, ) return workflow_run - + elif block_result.status == BlockStatus.terminated: + LOG.info( + f"Block with type {block.block_type} at index {block_idx} was terminated for workflow run {workflow_run_id}, marking workflow run as terminated", + block_type=block.block_type, + workflow_run_id=workflow_run.workflow_run_id, + block_idx=block_idx, + block_result=block_result, + ) + if block.continue_on_failure: + LOG.warning( + f"Block with type {block.block_type} at index {block_idx} was terminated for workflow run {workflow_run_id}, but will continue executing the workflow run", + block_type=block.block_type, + workflow_run_id=workflow_run.workflow_run_id, + block_idx=block_idx, + block_result=block_result, + continue_on_failure=block.continue_on_failure, + ) + else: + await self.mark_workflow_run_as_terminated(workflow_run_id=workflow_run.workflow_run_id) + await self.send_workflow_response( + workflow=workflow, + workflow_run=workflow_run, + api_key=api_key, + ) + return workflow_run except Exception: LOG.exception( f"Error while executing workflow run {workflow_run.workflow_run_id}", @@ -230,54 +265,6 @@ class WorkflowService: await self.send_workflow_response(workflow=workflow, workflow_run=workflow_run, api_key=api_key) return workflow_run - async def handle_workflow_status(self, workflow_run: WorkflowRun, tasks: list[Task]) -> WorkflowRun: - task_counts_by_status = Counter(task.status for task in tasks) - - # Create a mapping of status to (action, log_func, log_message) - status_action_mapping = { - TaskStatus.running: ( - None, - LOG.error, - "has running tasks, this should not happen", - ), - TaskStatus.terminated: ( - self.mark_workflow_run_as_terminated, - LOG.warning, - "has terminated tasks, marking as terminated", - ), - TaskStatus.failed: ( - self.mark_workflow_run_as_failed, - LOG.warning, - "has failed tasks, marking as failed", - ), - TaskStatus.completed: ( - self.mark_workflow_run_as_completed, - LOG.info, - "tasks are completed, marking as completed", - ), - } - - for status, (action, log_func, log_message) in status_action_mapping.items(): - if task_counts_by_status.get(status, 0) > 0: - if action is not None: - await action(workflow_run_id=workflow_run.workflow_run_id) - if log_func and log_message: - log_func( - f"Workflow run {workflow_run.workflow_run_id} {log_message}", - workflow_run_id=workflow_run.workflow_run_id, - task_counts_by_status=task_counts_by_status, - ) - return workflow_run - - # Handle unexpected state - LOG.error( - f"Workflow run {workflow_run.workflow_run_id} has tasks in an unexpected state, marking as failed", - workflow_run_id=workflow_run.workflow_run_id, - task_counts_by_status=task_counts_by_status, - ) - await self.mark_workflow_run_as_failed(workflow_run_id=workflow_run.workflow_run_id) - return workflow_run - async def create_workflow( self, organization_id: str, @@ -459,6 +446,17 @@ class WorkflowService: status=WorkflowRunStatus.terminated, ) + async def mark_workflow_run_as_canceled(self, workflow_run_id: str) -> None: + LOG.info( + f"Marking workflow run {workflow_run_id} as canceled", + workflow_run_id=workflow_run_id, + workflow_status="canceled", + ) + await app.DATABASE.update_workflow_run( + workflow_run_id=workflow_run_id, + status=WorkflowRunStatus.canceled, + ) + async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun: workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id) if not workflow_run: