workflow runtime API (#1421)

This commit is contained in:
Shuchang Zheng
2024-12-22 20:54:53 -08:00
committed by GitHub
parent 2e37542218
commit 94a3779bd7
5 changed files with 137 additions and 79 deletions

View File

@@ -250,6 +250,28 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_tasks_by_ids(
self,
task_ids: list[str],
organization_id: str | None = None,
) -> list[Task]:
try:
async with self.Session() as session:
tasks = (
await session.scalars(
select(TaskModel)
.filter(TaskModel.task_id.in_(task_ids))
.filter_by(organization_id=organization_id)
)
).all()
return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_step(self, task_id: str, step_id: str, organization_id: str | None = None) -> Step | None:
try:
async with self.Session() as session:
@@ -1883,7 +1905,7 @@ class AgentDB:
return ObserverThought.model_validate(observer_thought)
return None
async def get_observer_cruise_thoughts(
async def get_observer_thoughts(
self,
observer_cruise_id: str,
organization_id: str | None = None,
@@ -2079,3 +2101,24 @@ class AgentDB:
task = await self.get_task(task_id, organization_id=organization_id)
return convert_to_workflow_run_block(workflow_run_block, task=task)
raise NotFoundError(f"WorkflowRunBlock {workflow_run_block_id} not found")
async def get_workflow_run_blocks(
self,
workflow_run_id: str,
organization_id: str | None = None,
) -> list[WorkflowRunBlock]:
async with self.Session() as session:
workflow_run_blocks = (
await session.scalars(
select(WorkflowRunBlockModel)
.filter_by(workflow_run_id=workflow_run_id)
.filter_by(organization_id=organization_id)
.order_by(WorkflowRunBlockModel.created_at)
)
).all()
tasks = await self.get_tasks_by_workflow_run_id(workflow_run_id)
tasks_dict = {task.task_id: task for task in tasks}
return [
convert_to_workflow_run_block(workflow_run_block, task=tasks_dict.get(workflow_run_block.task_id))
for workflow_run_block in workflow_run_blocks
]