diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index a82f241d..e472b8ed 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta, timezone from typing import Any, List, Sequence import structlog -from sqlalchemy import and_, delete, distinct, func, or_, pool, select, tuple_, update +from sqlalchemy import and_, asc, delete, distinct, func, or_, pool, select, tuple_, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine @@ -2404,15 +2404,20 @@ class AgentDB: - organization_id - totp_identifier 2. make sure created_at is within the valid lifespan - 3. sort by created_at desc + 3. sort by task_id/workflow_id/workflow_run_id nullslast and created_at desc """ + all_null = and_( + TOTPCodeModel.task_id.is_(None), + TOTPCodeModel.workflow_id.is_(None), + TOTPCodeModel.workflow_run_id.is_(None), + ) async with self.Session() as session: query = ( select(TOTPCodeModel) .filter_by(organization_id=organization_id) .filter_by(totp_identifier=totp_identifier) .filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes)) - .order_by(TOTPCodeModel.created_at.desc()) + .order_by(asc(all_null), TOTPCodeModel.created_at.desc()) ) totp_code = (await session.scalars(query)).all() return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]