Implement a runs endpoint that can return workflow runs or tasks (#1708)
This commit is contained in:
@@ -1387,6 +1387,48 @@ class AgentDB:
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_all_runs(
|
||||||
|
self, organization_id: str, page: int = 1, page_size: int = 10, status: list[WorkflowRunStatus] | None = None
|
||||||
|
) -> list[WorkflowRun | Task]:
|
||||||
|
try:
|
||||||
|
async with self.Session() as session:
|
||||||
|
# temporary limit to 10 pages
|
||||||
|
if page > 10:
|
||||||
|
return []
|
||||||
|
|
||||||
|
limit = page * page_size
|
||||||
|
|
||||||
|
workflow_run_query = select(WorkflowRunModel).filter(
|
||||||
|
WorkflowRunModel.organization_id == organization_id
|
||||||
|
)
|
||||||
|
if status:
|
||||||
|
workflow_run_query = workflow_run_query.filter(WorkflowRunModel.status.in_(status))
|
||||||
|
workflow_run_query = workflow_run_query.order_by(WorkflowRunModel.created_at.desc()).limit(limit)
|
||||||
|
workflow_run_query_result = (await session.scalars(workflow_run_query)).all()
|
||||||
|
workflow_runs = [
|
||||||
|
convert_to_workflow_run(run, debug_enabled=self.debug_enabled) for run in workflow_run_query_result
|
||||||
|
]
|
||||||
|
|
||||||
|
task_query = select(TaskModel).filter(TaskModel.organization_id == organization_id)
|
||||||
|
if status:
|
||||||
|
task_query = task_query.filter(TaskModel.status.in_(status))
|
||||||
|
task_query = task_query.order_by(TaskModel.created_at.desc()).limit(limit)
|
||||||
|
task_query_result = (await session.scalars(task_query)).all()
|
||||||
|
tasks = [convert_to_task(task, debug_enabled=self.debug_enabled) for task in task_query_result]
|
||||||
|
|
||||||
|
runs = workflow_runs + tasks
|
||||||
|
|
||||||
|
runs.sort(key=lambda x: x.created_at, reverse=True)
|
||||||
|
|
||||||
|
lower = (page - 1) * page_size
|
||||||
|
upper = page * page_size
|
||||||
|
|
||||||
|
return runs[lower:upper]
|
||||||
|
|
||||||
|
except SQLAlchemyError:
|
||||||
|
LOG.error("SQLAlchemyError", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
async def get_workflow_run(self, workflow_run_id: str, organization_id: str | None = None) -> WorkflowRun | None:
|
async def get_workflow_run(self, workflow_run_id: str, organization_id: str | None = None) -> WorkflowRun | None:
|
||||||
try:
|
try:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
|
|||||||
@@ -435,6 +435,24 @@ async def get_agent_tasks(
|
|||||||
return ORJSONResponse([(await app.agent.build_task_response(task=task)).model_dump() for task in tasks])
|
return ORJSONResponse([(await app.agent.build_task_response(task=task)).model_dump() for task in tasks])
|
||||||
|
|
||||||
|
|
||||||
|
@base_router.get("/runs", response_model=list[WorkflowRun | Task])
|
||||||
|
@base_router.get("/runs/", response_model=list[WorkflowRun | Task], include_in_schema=False)
|
||||||
|
async def get_runs(
|
||||||
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
page_size: int = Query(10, ge=1),
|
||||||
|
status: Annotated[list[WorkflowRunStatus] | None, Query()] = None,
|
||||||
|
) -> Response:
|
||||||
|
analytics.capture("skyvern-oss-agent-runs-get")
|
||||||
|
|
||||||
|
# temporary limit to 100 runs
|
||||||
|
if page > 10:
|
||||||
|
return []
|
||||||
|
|
||||||
|
runs = await app.DATABASE.get_all_runs(current_org.organization_id, page=page, page_size=page_size, status=status)
|
||||||
|
return ORJSONResponse([run.model_dump() for run in runs])
|
||||||
|
|
||||||
|
|
||||||
@base_router.get("/internal/tasks", tags=["agent"], response_model=list[Task])
|
@base_router.get("/internal/tasks", tags=["agent"], response_model=list[Task])
|
||||||
@base_router.get(
|
@base_router.get(
|
||||||
"/internal/tasks/",
|
"/internal/tasks/",
|
||||||
|
|||||||
Reference in New Issue
Block a user