From e0e868445d6049f2b1eb62a3486786564be8f03b Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Tue, 4 Feb 2025 03:59:10 +0800 Subject: [PATCH] Implement a runs endpoint that can return workflow runs or tasks (#1708) --- skyvern/forge/sdk/db/client.py | 42 ++++++++++++++++++++++ skyvern/forge/sdk/routes/agent_protocol.py | 18 ++++++++++ 2 files changed, 60 insertions(+) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 16d6c51f..254c2040 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1387,6 +1387,48 @@ class AgentDB: ) return None + async def get_all_runs( + self, organization_id: str, page: int = 1, page_size: int = 10, status: list[WorkflowRunStatus] | None = None + ) -> list[WorkflowRun | Task]: + try: + async with self.Session() as session: + # temporary limit to 10 pages + if page > 10: + return [] + + limit = page * page_size + + workflow_run_query = select(WorkflowRunModel).filter( + WorkflowRunModel.organization_id == organization_id + ) + if status: + workflow_run_query = workflow_run_query.filter(WorkflowRunModel.status.in_(status)) + workflow_run_query = workflow_run_query.order_by(WorkflowRunModel.created_at.desc()).limit(limit) + workflow_run_query_result = (await session.scalars(workflow_run_query)).all() + workflow_runs = [ + convert_to_workflow_run(run, debug_enabled=self.debug_enabled) for run in workflow_run_query_result + ] + + task_query = select(TaskModel).filter(TaskModel.organization_id == organization_id) + if status: + task_query = task_query.filter(TaskModel.status.in_(status)) + task_query = task_query.order_by(TaskModel.created_at.desc()).limit(limit) + task_query_result = (await session.scalars(task_query)).all() + tasks = [convert_to_task(task, debug_enabled=self.debug_enabled) for task in task_query_result] + + runs = workflow_runs + tasks + + runs.sort(key=lambda x: x.created_at, reverse=True) + + lower = (page - 1) * page_size + upper = page * page_size + + return runs[lower:upper] + + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + async def get_workflow_run(self, workflow_run_id: str, organization_id: str | None = None) -> WorkflowRun | None: try: async with self.Session() as session: diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index d8073115..01584c6b 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -435,6 +435,24 @@ async def get_agent_tasks( return ORJSONResponse([(await app.agent.build_task_response(task=task)).model_dump() for task in tasks]) +@base_router.get("/runs", response_model=list[WorkflowRun | Task]) +@base_router.get("/runs/", response_model=list[WorkflowRun | Task], include_in_schema=False) +async def get_runs( + current_org: Organization = Depends(org_auth_service.get_current_org), + page: int = Query(1, ge=1), + page_size: int = Query(10, ge=1), + status: Annotated[list[WorkflowRunStatus] | None, Query()] = None, +) -> Response: + analytics.capture("skyvern-oss-agent-runs-get") + + # temporary limit to 100 runs + if page > 10: + return [] + + runs = await app.DATABASE.get_all_runs(current_org.organization_id, page=page, page_size=page_size, status=status) + return ORJSONResponse([run.model_dump() for run in runs]) + + @base_router.get("/internal/tasks", tags=["agent"], response_model=list[Task]) @base_router.get( "/internal/tasks/",