From 1d1a4d72ea5c6ccbdbde847bdf4c456672fd7f73 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Mon, 12 May 2025 08:30:37 -0700 Subject: [PATCH] =?UTF-8?q?backend:=20normalize=20returns=20for=20evals=20?= =?UTF-8?q?(items,=20total);=20ensure=20title=20to=20=E2=80=A6=20(#2329)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- skyvern/forge/sdk/db/client.py | 63 +++++++++++++++++++++++++-- skyvern/forge/sdk/db/utils.py | 3 +- skyvern/forge/sdk/schemas/tasks.py | 1 + skyvern/forge/sdk/workflow/service.py | 10 +++++ 4 files changed, 73 insertions(+), 4 deletions(-) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index d3e7cef2..8ccc902a 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -650,7 +650,11 @@ class AgentDB: try: async with self.Session() as session: db_page = page - 1 # offset logic is 0 based - query = select(TaskModel).filter(TaskModel.organization_id == organization_id) + query = ( + select(TaskModel, WorkflowRunModel.workflow_permanent_id) + .join(WorkflowRunModel, TaskModel.workflow_run_id == WorkflowRunModel.workflow_run_id, isouter=True) + .filter(TaskModel.organization_id == organization_id) + ) if task_status: query = query.filter(TaskModel.status.in_(task_status)) if workflow_run_id: @@ -665,8 +669,42 @@ class AgentDB: .limit(page_size) .offset(db_page * page_size) ) - tasks = (await session.scalars(query)).all() - return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks] + + results = (await session.execute(query)).all() + + return [ + convert_to_task(task, debug_enabled=self.debug_enabled, workflow_permanent_id=workflow_permanent_id) + for task, workflow_permanent_id in results + ] + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + except Exception: + LOG.error("UnexpectedError", exc_info=True) + raise + + async def get_tasks_count( + self, + organization_id: str, + task_status: list[TaskStatus] | None = None, + workflow_run_id: str | None = None, + only_standalone_tasks: bool = False, + application: str | None = None, + ) -> int: + try: + async with self.Session() as session: + count_query = ( + select(func.count()).select_from(TaskModel).filter(TaskModel.organization_id == organization_id) + ) + if task_status: + count_query = count_query.filter(TaskModel.status.in_(task_status)) + if workflow_run_id: + count_query = count_query.filter(TaskModel.workflow_run_id == workflow_run_id) + if only_standalone_tasks: + count_query = count_query.filter(TaskModel.workflow_run_id.is_(None)) + if application: + count_query = count_query.filter(TaskModel.application == application) + return (await session.execute(count_query)).scalar_one() except SQLAlchemyError: LOG.error("SQLAlchemyError", exc_info=True) raise @@ -1527,6 +1565,25 @@ class AgentDB: LOG.error("SQLAlchemyError", exc_info=True) raise + async def get_workflow_runs_count( + self, + organization_id: str, + status: list[WorkflowRunStatus] | None = None, + ) -> int: + try: + async with self.Session() as session: + count_query = ( + select(func.count()) + .select_from(WorkflowRunModel) + .filter(WorkflowRunModel.organization_id == organization_id) + ) + if status: + count_query = count_query.filter(WorkflowRunModel.status.in_(status)) + return (await session.execute(count_query)).scalar_one() + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + async def get_workflow_runs_for_workflow_permanent_id( self, workflow_permanent_id: str, diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index 7c331644..590a233b 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -58,7 +58,7 @@ def _custom_json_serializer(*args, **kwargs) -> str: return json.dumps(*args, default=pydantic.json.pydantic_encoder, **kwargs) -def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task: +def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_permanent_id: str | None = None) -> Task: if debug_enabled: LOG.debug("Converting TaskModel to Task", task_id=task_obj.task_id) task = Task( @@ -83,6 +83,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task: proxy_location=(ProxyLocation(task_obj.proxy_location) if task_obj.proxy_location else None), extracted_information_schema=task_obj.extracted_information_schema, workflow_run_id=task_obj.workflow_run_id, + workflow_permanent_id=workflow_permanent_id, order=task_obj.order, retry=task_obj.retry, max_steps_per_run=task_obj.max_steps_per_run, diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index 724f3e98..e09b63c8 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -226,6 +226,7 @@ class Task(TaskBase): ) organization_id: str | None = None workflow_run_id: str | None = None + workflow_permanent_id: str | None = None order: int | None = None retry: int | None = None max_steps_per_run: int | None = None diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 0c6783a5..fdd7da1c 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -677,6 +677,16 @@ class WorkflowService: organization_id=organization_id, page=page, page_size=page_size, status=status ) + async def get_workflow_runs_count( + self, + organization_id: str, + status: list[WorkflowRunStatus] | None = None, + ) -> int: + return await app.DATABASE.get_workflow_runs_count( + organization_id=organization_id, + status=status, + ) + async def get_workflow_runs_for_workflow_permanent_id( self, workflow_permanent_id: str,