From b0d9f9ce5fc763f1839370740ad4085313406c27 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Mon, 21 Oct 2024 10:34:42 -0700 Subject: [PATCH] Add sorting to task api (#1018) --- skyvern/forge/sdk/db/client.py | 13 +++++++++++-- skyvern/forge/sdk/routes/agent_protocol.py | 16 +++++++++++++++- skyvern/forge/sdk/schemas/tasks.py | 10 ++++++++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 4f86dddc..4e605b29 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -51,7 +51,7 @@ from skyvern.forge.sdk.db.utils import ( ) from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus from skyvern.forge.sdk.schemas.task_generations import TaskGeneration -from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus +from skyvern.forge.sdk.schemas.tasks import OrderBy, ProxyLocation, SortDirection, Task, TaskStatus from skyvern.forge.sdk.schemas.totp_codes import TOTPCode from skyvern.forge.sdk.workflow.models.parameter import ( AWSSecretParameter, @@ -461,6 +461,8 @@ class AgentDB: workflow_run_id: str | None = None, organization_id: str | None = None, only_standalone_tasks: bool = False, + order_by_column: OrderBy = OrderBy.created_at, + order: SortDirection = SortDirection.desc, ) -> list[Task]: """ Get all tasks. @@ -469,6 +471,8 @@ class AgentDB: :param task_status: :param workflow_run_id: :param only_standalone_tasks: + :param order_by_column: + :param order: :return: """ if page < 1: @@ -484,7 +488,12 @@ class AgentDB: query = query.filter(TaskModel.workflow_run_id == workflow_run_id) if only_standalone_tasks: query = query.filter(TaskModel.workflow_run_id.is_(None)) - query = query.order_by(TaskModel.created_at.desc()).limit(page_size).offset(db_page * page_size) + order_by_col = getattr(TaskModel, order_by_column) + query = ( + query.order_by(order_by_col.desc() if order == SortDirection.desc else order_by_col.asc()) + .limit(page_size) + .offset(db_page * page_size) + ) tasks = (await session.scalars(query)).all() return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks] except SQLAlchemyError: diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 07d6b7b0..fdd2eb9f 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -40,7 +40,15 @@ from skyvern.forge.sdk.schemas.organizations import ( OrganizationUpdate, ) from skyvern.forge.sdk.schemas.task_generations import GenerateTaskRequest, TaskGeneration, TaskGenerationBase -from skyvern.forge.sdk.schemas.tasks import CreateTaskResponse, Task, TaskRequest, TaskResponse, TaskStatus +from skyvern.forge.sdk.schemas.tasks import ( + CreateTaskResponse, + OrderBy, + SortDirection, + Task, + TaskRequest, + TaskResponse, + TaskStatus, +) from skyvern.forge.sdk.services import org_auth_service from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.workflow.exceptions import FailedToCreateWorkflow, FailedToUpdateWorkflow @@ -385,6 +393,8 @@ async def get_agent_tasks( workflow_run_id: Annotated[str | None, Query()] = None, current_org: Organization = Depends(org_auth_service.get_current_org), only_standalone_tasks: bool = Query(False), + sort: OrderBy = Query(OrderBy.created_at), + order: SortDirection = Query(SortDirection.desc), ) -> Response: """ Get all tasks. @@ -393,6 +403,8 @@ async def get_agent_tasks( :param task_status: Task status filter :param workflow_run_id: Workflow run id filter :param only_standalone_tasks: Only standalone tasks, tasks which are part of a workflow run will be filtered out + :param order: Direction to sort by, ascending or descending + :param sort: Column to sort by, created_at or modified_at :return: List of tasks with pagination without steps populated. Steps can be populated by calling the get_agent_task endpoint. """ @@ -409,6 +421,8 @@ async def get_agent_tasks( workflow_run_id=workflow_run_id, organization_id=current_org.organization_id, only_standalone_tasks=only_standalone_tasks, + order=order, + order_by_column=sort, ) return ORJSONResponse([task.to_task_response().model_dump() for task in tasks]) diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index df0b5104..709860aa 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -290,3 +290,13 @@ class TaskOutput(BaseModel): class CreateTaskResponse(BaseModel): task_id: str + + +class OrderBy(StrEnum): + created_at = "created_at" + modified_at = "modified_at" + + +class SortDirection(StrEnum): + asc = "asc" + desc = "desc"