From 39c5f6233e7e1bea3d3006013d2751ede59b15e1 Mon Sep 17 00:00:00 2001
From: Kerem Yilmaz
Date: Thu, 19 Sep 2024 11:15:07 -0700
Subject: [PATCH] Fix workflow reset issue upon update failure (#858)

---
 skyvern/forge/sdk/db/client.py             | 26 +++++++++++++++----
 skyvern/forge/sdk/routes/agent_protocol.py | 29 ++++++++++++++--------
 skyvern/forge/sdk/workflow/exceptions.py   | 16 ++++++++++++
 skyvern/forge/sdk/workflow/service.py      | 22 +++++++++++++++-
 4 files changed, 77 insertions(+), 16 deletions(-)

diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py
index 3df69add..9321cdcd 100644
--- a/skyvern/forge/sdk/db/client.py
+++ b/skyvern/forge/sdk/db/client.py
@@ -852,6 +852,23 @@ class AgentDB:
             await session.refresh(workflow)
             return convert_to_workflow(workflow, self.debug_enabled)
 
+    async def soft_delete_workflow_by_id(self, workflow_id: str, organization_id: str) -> None:
+        try:
+            async with self.Session() as session:
+                # soft delete the workflow by setting the deleted_at field to the current time
+                update_deleted_at_query = (
+                    update(WorkflowModel)
+                    .where(WorkflowModel.workflow_id == workflow_id)
+                    .where(WorkflowModel.organization_id == organization_id)
+                    .where(WorkflowModel.deleted_at.is_(None))
+                    .values(deleted_at=datetime.utcnow())
+                )
+                await session.execute(update_deleted_at_query)
+                await session.commit()
+        except SQLAlchemyError:
+            LOG.error("SQLAlchemyError in soft_delete_workflow_by_id", exc_info=True)
+            raise
+
     async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None:
         try:
             async with self.Session() as session:
@@ -872,13 +889,12 @@ class AgentDB:
         workflow_permanent_id: str,
         organization_id: str | None = None,
         version: int | None = None,
+        exclude_deleted: bool = True,
     ) -> Workflow | None:
         try:
-            get_workflow_query = (
-                select(WorkflowModel)
-                .filter_by(workflow_permanent_id=workflow_permanent_id)
-                .filter(WorkflowModel.deleted_at.is_(None))
-            )
+            get_workflow_query = select(WorkflowModel).filter_by(workflow_permanent_id=workflow_permanent_id)
+            if exclude_deleted:
+                get_workflow_query = get_workflow_query.filter(WorkflowModel.deleted_at.is_(None))
             if organization_id:
                 get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
             if version:
diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py
index acd762d9..e16dafb9 100644
--- a/skyvern/forge/sdk/routes/agent_protocol.py
+++ b/skyvern/forge/sdk/routes/agent_protocol.py
@@ -43,6 +43,7 @@ from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, Task
 from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus
 from skyvern.forge.sdk.services import org_auth_service
 from skyvern.forge.sdk.settings_manager import SettingsManager
+from skyvern.forge.sdk.workflow.exceptions import FailedToCreateWorkflow, FailedToUpdateWorkflow
 from skyvern.forge.sdk.workflow.models.workflow import (
     RunWorkflowResponse,
     Workflow,
@@ -664,10 +665,14 @@ async def create_workflow(
     except yaml.YAMLError:
         raise HTTPException(status_code=422, detail="Invalid YAML")
 
-    workflow_create_request = WorkflowCreateYAMLRequest.model_validate(workflow_yaml)
-    return await app.WORKFLOW_SERVICE.create_workflow_from_request(
-        organization_id=current_org.organization_id, request=workflow_create_request
-    )
+    try:
+        workflow_create_request = WorkflowCreateYAMLRequest.model_validate(workflow_yaml)
+        return await app.WORKFLOW_SERVICE.create_workflow_from_request(
+            organization_id=current_org.organization_id, request=workflow_create_request
+        )
+    except Exception as e:
+        LOG.error("Failed to create workflow", exc_info=True)
+        raise FailedToCreateWorkflow(str(e))
 
 
 @base_router.put(
@@ -704,12 +709,16 @@ async def update_workflow(
     except yaml.YAMLError:
         raise HTTPException(status_code=422, detail="Invalid YAML")
 
-    workflow_create_request = WorkflowCreateYAMLRequest.model_validate(workflow_yaml)
-    return await app.WORKFLOW_SERVICE.create_workflow_from_request(
-        organization_id=current_org.organization_id,
-        request=workflow_create_request,
-        workflow_permanent_id=workflow_permanent_id,
-    )
+    try:
+        workflow_create_request = WorkflowCreateYAMLRequest.model_validate(workflow_yaml)
+        return await app.WORKFLOW_SERVICE.create_workflow_from_request(
+            organization_id=current_org.organization_id,
+            request=workflow_create_request,
+            workflow_permanent_id=workflow_permanent_id,
+        )
+    except Exception as e:
+        LOG.exception("Failed to update workflow", workflow_permanent_id=workflow_permanent_id)
+        raise FailedToUpdateWorkflow(workflow_permanent_id, f"<{type(e).__name__}: {str(e)}>")
 
 
 @base_router.delete("/workflows/{workflow_permanent_id}")
diff --git a/skyvern/forge/sdk/workflow/exceptions.py b/skyvern/forge/sdk/workflow/exceptions.py
index f234dab6..f9e64def 100644
--- a/skyvern/forge/sdk/workflow/exceptions.py
+++ b/skyvern/forge/sdk/workflow/exceptions.py
@@ -20,6 +20,22 @@ class WorkflowDefinitionHasDuplicateBlockLabels(BaseWorkflowHTTPException):
         )
 
 
+class FailedToCreateWorkflow(BaseWorkflowHTTPException):
+    def __init__(self, error_message: str) -> None:
+        super().__init__(
+            f"Failed to create workflow. Error: {error_message}",
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        )
+
+
+class FailedToUpdateWorkflow(BaseWorkflowHTTPException):
+    def __init__(self, workflow_permanent_id: str, error_message: str) -> None:
+        super().__init__(
+            f"Failed to update workflow with ID {workflow_permanent_id}. Error: {error_message}",
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        )
+
+
 class OutputParameterKeyCollisionError(BaseWorkflowHTTPException):
     def __init__(self, key: str, retry_count: int | None = None) -> None:
         message = f"Output parameter key {key} already exists in the context manager."
diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py
index 3e95f6fd..ceeb040b 100644
--- a/skyvern/forge/sdk/workflow/service.py
+++ b/skyvern/forge/sdk/workflow/service.py
@@ -318,11 +318,13 @@ class WorkflowService:
         workflow_permanent_id: str,
         organization_id: str | None = None,
         version: int | None = None,
+        exclude_deleted: bool = True,
     ) -> Workflow:
         workflow = await app.DATABASE.get_workflow_by_permanent_id(
             workflow_permanent_id,
             organization_id=organization_id,
             version=version,
+            exclude_deleted=exclude_deleted,
         )
         if not workflow:
             raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id, version=version)
@@ -376,6 +378,16 @@ class WorkflowService:
             organization_id=organization_id,
         )
 
+    async def delete_workflow_by_id(
+        self,
+        workflow_id: str,
+        organization_id: str,
+    ) -> None:
+        await app.DATABASE.soft_delete_workflow_by_id(
+            workflow_id=workflow_id,
+            organization_id=organization_id,
+        )
+
     async def get_workflow_runs(self, organization_id: str, page: int = 1, page_size: int = 10) -> list[WorkflowRun]:
         return await app.DATABASE.get_workflow_runs(organization_id=organization_id, page=page, page_size=page_size)
 
@@ -820,11 +832,13 @@ class WorkflowService:
             organization_id=organization_id,
             title=request.title,
         )
+        new_workflow_id: str | None = None
         try:
             if workflow_permanent_id:
                 existing_latest_workflow = await self.get_workflow_by_permanent_id(
                     workflow_permanent_id=workflow_permanent_id,
                     organization_id=organization_id,
+                    exclude_deleted=False,
                 )
                 existing_version = existing_latest_workflow.version
                 workflow = await self.create_workflow(
@@ -854,6 +868,8 @@ class WorkflowService:
                     persist_browser_session=request.persist_browser_session,
                     is_saved_task=request.is_saved_task,
                 )
+            # Keeping track of the new workflow id to delete it if an error occurs during the creation process
+            new_workflow_id = workflow.workflow_id
             # Create parameters from the request
             parameters: dict[str, PARAMETER_TYPE] = {}
             duplicate_parameter_keys = set()
@@ -991,7 +1007,11 @@ class WorkflowService:
             )
             return workflow
         except Exception as e:
-            LOG.exception(f"Failed to create workflow from request, title: {request.title}")
+            if new_workflow_id:
+                LOG.error(f"Failed to create workflow from request, deleting workflow {new_workflow_id}")
+                await self.delete_workflow_by_id(workflow_id=new_workflow_id, organization_id=organization_id)
+            else:
+                LOG.exception(f"Failed to create workflow from request, title: {request.title}")
             raise e
 
     @staticmethod