batch task/workflow update (#4344)
This commit is contained in:
@@ -770,6 +770,34 @@ class AgentDB(BaseAlchemyDB):
|
|||||||
LOG.error("UnexpectedError", exc_info=True)
|
LOG.error("UnexpectedError", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def bulk_update_tasks(
|
||||||
|
self,
|
||||||
|
task_ids: list[str],
|
||||||
|
status: TaskStatus | None = None,
|
||||||
|
failure_reason: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Bulk update tasks by their IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_ids: List of task IDs to update
|
||||||
|
status: Optional status to set for all tasks
|
||||||
|
failure_reason: Optional failure reason to set for all tasks
|
||||||
|
"""
|
||||||
|
if not task_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self.Session() as session:
|
||||||
|
update_values = {}
|
||||||
|
if status:
|
||||||
|
update_values["status"] = status.value
|
||||||
|
if failure_reason:
|
||||||
|
update_values["failure_reason"] = failure_reason
|
||||||
|
|
||||||
|
if update_values:
|
||||||
|
update_stmt = update(TaskModel).where(TaskModel.task_id.in_(task_ids)).values(**update_values)
|
||||||
|
await session.execute(update_stmt)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def get_tasks(
|
async def get_tasks(
|
||||||
self,
|
self,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
@@ -2609,6 +2637,38 @@ class AgentDB(BaseAlchemyDB):
|
|||||||
else:
|
else:
|
||||||
raise WorkflowRunNotFound(workflow_run_id)
|
raise WorkflowRunNotFound(workflow_run_id)
|
||||||
|
|
||||||
|
async def bulk_update_workflow_runs(
|
||||||
|
self,
|
||||||
|
workflow_run_ids: list[str],
|
||||||
|
status: WorkflowRunStatus | None = None,
|
||||||
|
failure_reason: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Bulk update workflow runs by their IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_ids: List of workflow run IDs to update
|
||||||
|
status: Optional status to set for all workflow runs
|
||||||
|
failure_reason: Optional failure reason to set for all workflow runs
|
||||||
|
"""
|
||||||
|
if not workflow_run_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self.Session() as session:
|
||||||
|
update_values = {}
|
||||||
|
if status:
|
||||||
|
update_values["status"] = status.value
|
||||||
|
if failure_reason:
|
||||||
|
update_values["failure_reason"] = failure_reason
|
||||||
|
|
||||||
|
if update_values:
|
||||||
|
update_stmt = (
|
||||||
|
update(WorkflowRunModel)
|
||||||
|
.where(WorkflowRunModel.workflow_run_id.in_(workflow_run_ids))
|
||||||
|
.values(**update_values)
|
||||||
|
)
|
||||||
|
await session.execute(update_stmt)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def clear_workflow_run_failure_reason(self, workflow_run_id: str, organization_id: str) -> WorkflowRun:
|
async def clear_workflow_run_failure_reason(self, workflow_run_id: str, organization_id: str) -> WorkflowRun:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
workflow_run = (
|
workflow_run = (
|
||||||
|
|||||||
Reference in New Issue
Block a user