From 08ca5a0b454a3e7224a1f2fa97998ed3f8ae13ce Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Sat, 20 Dec 2025 03:18:50 +0800 Subject: [PATCH] batch task/workflow update (#4344) --- skyvern/forge/sdk/db/agent_db.py | 60 ++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/skyvern/forge/sdk/db/agent_db.py b/skyvern/forge/sdk/db/agent_db.py index 1405a1a8..a507015d 100644 --- a/skyvern/forge/sdk/db/agent_db.py +++ b/skyvern/forge/sdk/db/agent_db.py @@ -770,6 +770,34 @@ class AgentDB(BaseAlchemyDB): LOG.error("UnexpectedError", exc_info=True) raise + async def bulk_update_tasks( + self, + task_ids: list[str], + status: TaskStatus | None = None, + failure_reason: str | None = None, + ) -> None: + """Bulk update tasks by their IDs. + + Args: + task_ids: List of task IDs to update + status: Optional status to set for all tasks + failure_reason: Optional failure reason to set for all tasks + """ + if not task_ids: + return + + async with self.Session() as session: + update_values = {} + if status: + update_values["status"] = status.value + if failure_reason: + update_values["failure_reason"] = failure_reason + + if update_values: + update_stmt = update(TaskModel).where(TaskModel.task_id.in_(task_ids)).values(**update_values) + await session.execute(update_stmt) + await session.commit() + async def get_tasks( self, page: int = 1, @@ -2609,6 +2637,38 @@ class AgentDB(BaseAlchemyDB): else: raise WorkflowRunNotFound(workflow_run_id) + async def bulk_update_workflow_runs( + self, + workflow_run_ids: list[str], + status: WorkflowRunStatus | None = None, + failure_reason: str | None = None, + ) -> None: + """Bulk update workflow runs by their IDs. + + Args: + workflow_run_ids: List of workflow run IDs to update + status: Optional status to set for all workflow runs + failure_reason: Optional failure reason to set for all workflow runs + """ + if not workflow_run_ids: + return + + async with self.Session() as session: + update_values = {} + if status: + update_values["status"] = status.value + if failure_reason: + update_values["failure_reason"] = failure_reason + + if update_values: + update_stmt = ( + update(WorkflowRunModel) + .where(WorkflowRunModel.workflow_run_id.in_(workflow_run_ids)) + .values(**update_values) + ) + await session.execute(update_stmt) + await session.commit() + async def clear_workflow_run_failure_reason(self, workflow_run_id: str, organization_id: str) -> WorkflowRun: async with self.Session() as session: workflow_run = (