From 5c37ebbb9e9c7dd298394e658449ac8afc809254 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Fri, 24 Jan 2025 23:31:26 +0800 Subject: [PATCH] Add status filter to workflow runs endpoints (#1637) --- skyvern/forge/sdk/db/client.py | 44 +++++++++++----------- skyvern/forge/sdk/routes/agent_protocol.py | 5 +++ skyvern/forge/sdk/workflow/service.py | 16 ++++++-- 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index bc3788c8..67df35be 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1340,40 +1340,42 @@ class AgentDB: LOG.error("SQLAlchemyError", exc_info=True) raise - async def get_workflow_runs(self, organization_id: str, page: int = 1, page_size: int = 10) -> list[WorkflowRun]: + async def get_workflow_runs( + self, organization_id: str, page: int = 1, page_size: int = 10, status: list[WorkflowRunStatus] | None = None + ) -> list[WorkflowRun]: try: async with self.Session() as session: db_page = page - 1 # offset logic is 0 based - workflow_runs = ( - await session.scalars( - select(WorkflowRunModel) - .filter(WorkflowRunModel.organization_id == organization_id) - .order_by(WorkflowRunModel.created_at.desc()) - .limit(page_size) - .offset(db_page * page_size) - ) - ).all() + query = select(WorkflowRunModel).filter(WorkflowRunModel.organization_id == organization_id) + if status: + query = query.filter(WorkflowRunModel.status.in_(status)) + query = query.order_by(WorkflowRunModel.created_at.desc()).limit(page_size).offset(db_page * page_size) + workflow_runs = (await session.scalars(query)).all() return [convert_to_workflow_run(run) for run in workflow_runs] except SQLAlchemyError: LOG.error("SQLAlchemyError", exc_info=True) raise async def get_workflow_runs_for_workflow_permanent_id( - self, workflow_permanent_id: str, organization_id: str, page: int = 1, page_size: int = 10 + self, + workflow_permanent_id: str, + organization_id: str, + page: int = 1, + page_size: int = 10, + status: list[WorkflowRunStatus] | None = None, ) -> list[WorkflowRun]: try: async with self.Session() as session: db_page = page - 1 # offset logic is 0 based - workflow_runs = ( - await session.scalars( - select(WorkflowRunModel) - .filter(WorkflowRunModel.workflow_permanent_id == workflow_permanent_id) - .filter(WorkflowRunModel.organization_id == organization_id) - .order_by(WorkflowRunModel.created_at.desc()) - .limit(page_size) - .offset(db_page * page_size) - ) - ).all() + query = ( + select(WorkflowRunModel) + .filter(WorkflowRunModel.workflow_permanent_id == workflow_permanent_id) + .filter(WorkflowRunModel.organization_id == organization_id) + ) + if status: + query = query.filter(WorkflowRunModel.status.in_(status)) + query = query.order_by(WorkflowRunModel.created_at.desc()).limit(page_size).offset(db_page * page_size) + workflow_runs = (await session.scalars(query)).all() return [convert_to_workflow_run(run) for run in workflow_runs] except SQLAlchemyError: LOG.error("SQLAlchemyError", exc_info=True) diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 5d491e19..d6c4e836 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -67,6 +67,7 @@ from skyvern.forge.sdk.workflow.models.workflow import ( Workflow, WorkflowRequestBody, WorkflowRun, + WorkflowRunStatus, WorkflowRunStatusResponse, ) from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest @@ -677,6 +678,7 @@ async def execute_workflow( async def get_workflow_runs( page: int = Query(1, ge=1), page_size: int = Query(10, ge=1), + status: Annotated[list[WorkflowRunStatus] | None, Query()] = None, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> list[WorkflowRun]: analytics.capture("skyvern-oss-agent-workflow-runs-get") @@ -684,6 +686,7 @@ async def get_workflow_runs( organization_id=current_org.organization_id, page=page, page_size=page_size, + status=status, ) @@ -700,6 +703,7 @@ async def get_workflow_runs_for_workflow_permanent_id( workflow_permanent_id: str, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1), + status: Annotated[list[WorkflowRunStatus] | None, Query()] = None, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> list[WorkflowRun]: analytics.capture("skyvern-oss-agent-workflow-runs-get") @@ -708,6 +712,7 @@ async def get_workflow_runs_for_workflow_permanent_id( organization_id=current_org.organization_id, page=page, page_size=page_size, + status=status, ) diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 3c2c5f18..c761760c 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -575,17 +575,27 @@ class WorkflowService: organization_id=organization_id, ) - async def get_workflow_runs(self, organization_id: str, page: int = 1, page_size: int = 10) -> list[WorkflowRun]: - return await app.DATABASE.get_workflow_runs(organization_id=organization_id, page=page, page_size=page_size) + async def get_workflow_runs( + self, organization_id: str, page: int = 1, page_size: int = 10, status: list[WorkflowRunStatus] | None = None + ) -> list[WorkflowRun]: + return await app.DATABASE.get_workflow_runs( + organization_id=organization_id, page=page, page_size=page_size, status=status + ) async def get_workflow_runs_for_workflow_permanent_id( - self, workflow_permanent_id: str, organization_id: str, page: int = 1, page_size: int = 10 + self, + workflow_permanent_id: str, + organization_id: str, + page: int = 1, + page_size: int = 10, + status: list[WorkflowRunStatus] | None = None, ) -> list[WorkflowRun]: return await app.DATABASE.get_workflow_runs_for_workflow_permanent_id( workflow_permanent_id=workflow_permanent_id, organization_id=organization_id, page=page, page_size=page_size, + status=status, ) async def create_workflow_run(