get_total_unique_step_order_count_by_task_ids (#1880)

This commit is contained in:
Shuchang Zheng
2025-03-04 02:04:18 -05:00
committed by GitHub
parent da78fb2edb
commit 2cef654a9a
2 changed files with 12 additions and 9 deletions

View File

@@ -3,7 +3,7 @@ from datetime import datetime, timedelta
from typing import Any, List, Optional, Sequence
import structlog
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy import and_, delete, distinct, func, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
@@ -343,18 +343,23 @@ class AgentDB:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_total_step_count_by_task_ids(
self, task_ids: list[str], organization_id: str | None = None, statuses: list[StepStatus] | None = None
async def get_total_unique_step_order_count_by_task_ids(
self,
task_ids: list[str],
organization_id: str | None = None,
) -> int:
"""
Get the total count of unique (step.task_id, step.order) pairs of StepModel for the given task ids
Basically translate this sql query into a SQLAlchemy query: select count(distinct(s.task_id, s.order)) from steps s
where s.task_id in task_ids
"""
try:
async with self.Session() as session:
query = (
select(func.count())
select(func.count(distinct(StepModel.task_id, StepModel.order)))
.where(StepModel.task_id.in_(task_ids))
.filter_by(organization_id=organization_id)
)
if statuses:
query = query.filter(StepModel.status.in_(statuses))
return (await session.scalars(query)).scalar()
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)