From 2cef654a9a426f36a645e451bc938e3e619361e4 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Tue, 4 Mar 2025 02:04:18 -0500 Subject: [PATCH] get_total_unique_step_order_count_by_task_ids (#1880) --- skyvern/forge/sdk/db/client.py | 17 +++++++++++------ skyvern/forge/sdk/services/task_v2_service.py | 4 +--- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index a5a9e24c..75b3fecd 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta from typing import Any, List, Optional, Sequence import structlog -from sqlalchemy import and_, delete, func, select, update +from sqlalchemy import and_, delete, distinct, func, select, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine @@ -343,18 +343,23 @@ class AgentDB: LOG.error("SQLAlchemyError", exc_info=True) raise - async def get_total_step_count_by_task_ids( - self, task_ids: list[str], organization_id: str | None = None, statuses: list[StepStatus] | None = None + async def get_total_unique_step_order_count_by_task_ids( + self, + task_ids: list[str], + organization_id: str | None = None, ) -> int: + """ + Get the total count of unique (step.task_id, step.order) pairs of StepModel for the given task ids + Basically translate this sql query into a SQLAlchemy query: select count(distinct(s.task_id, s.order)) from steps s + where s.task_id in task_ids + """ try: async with self.Session() as session: query = ( - select(func.count()) + select(func.count(distinct(StepModel.task_id, StepModel.order))) .where(StepModel.task_id.in_(task_ids)) .filter_by(organization_id=organization_id) ) - if statuses: - query = query.filter(StepModel.status.in_(statuses)) return (await session.scalars(query)).scalar() except SQLAlchemyError: LOG.error("SQLAlchemyError", exc_info=True) diff --git a/skyvern/forge/sdk/services/task_v2_service.py b/skyvern/forge/sdk/services/task_v2_service.py index 0316df69..13ec5f30 100644 --- a/skyvern/forge/sdk/services/task_v2_service.py +++ b/skyvern/forge/sdk/services/task_v2_service.py @@ -18,7 +18,6 @@ from skyvern.forge.sdk.core.hashing import generate_url_hash from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType -from skyvern.forge.sdk.models import StepStatus from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.task_runs import TaskRunType from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Metadata, TaskV2Status, ThoughtScenario, ThoughtType @@ -709,10 +708,9 @@ async def run_task_v2_helper( # total step number validation workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id) - total_step_count = await app.DATABASE.get_total_step_count_by_task_ids( + total_step_count = await app.DATABASE.get_total_unique_step_order_count_by_task_ids( task_ids=[task.task_id for task in workflow_run_tasks], organization_id=organization_id, - statuses=[StepStatus.completed], ) if total_step_count >= max_steps: LOG.info("Task v2 failed - run out of steps", max_steps=max_steps, workflow_run_id=workflow_run_id)