get_total_unique_step_order_count_by_task_ids (#1880)
This commit is contained in:
@@ -3,7 +3,7 @@ from datetime import datetime, timedelta
|
|||||||
from typing import Any, List, Optional, Sequence
|
from typing import Any, List, Optional, Sequence
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from sqlalchemy import and_, delete, func, select, update
|
from sqlalchemy import and_, delete, distinct, func, select, update
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
@@ -343,18 +343,23 @@ class AgentDB:
|
|||||||
LOG.error("SQLAlchemyError", exc_info=True)
|
LOG.error("SQLAlchemyError", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_total_step_count_by_task_ids(
|
async def get_total_unique_step_order_count_by_task_ids(
|
||||||
self, task_ids: list[str], organization_id: str | None = None, statuses: list[StepStatus] | None = None
|
self,
|
||||||
|
task_ids: list[str],
|
||||||
|
organization_id: str | None = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
|
"""
|
||||||
|
Get the total count of unique (step.task_id, step.order) pairs of StepModel for the given task ids
|
||||||
|
Basically translate this sql query into a SQLAlchemy query: select count(distinct(s.task_id, s.order)) from steps s
|
||||||
|
where s.task_id in task_ids
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
query = (
|
query = (
|
||||||
select(func.count())
|
select(func.count(distinct(StepModel.task_id, StepModel.order)))
|
||||||
.where(StepModel.task_id.in_(task_ids))
|
.where(StepModel.task_id.in_(task_ids))
|
||||||
.filter_by(organization_id=organization_id)
|
.filter_by(organization_id=organization_id)
|
||||||
)
|
)
|
||||||
if statuses:
|
|
||||||
query = query.filter(StepModel.status.in_(statuses))
|
|
||||||
return (await session.scalars(query)).scalar()
|
return (await session.scalars(query)).scalar()
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
LOG.error("SQLAlchemyError", exc_info=True)
|
LOG.error("SQLAlchemyError", exc_info=True)
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from skyvern.forge.sdk.core.hashing import generate_url_hash
|
|||||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
|
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
|
||||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||||
from skyvern.forge.sdk.models import StepStatus
|
|
||||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||||
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
|
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
|
||||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Metadata, TaskV2Status, ThoughtScenario, ThoughtType
|
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Metadata, TaskV2Status, ThoughtScenario, ThoughtType
|
||||||
@@ -709,10 +708,9 @@ async def run_task_v2_helper(
|
|||||||
|
|
||||||
# total step number validation
|
# total step number validation
|
||||||
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
|
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
|
||||||
total_step_count = await app.DATABASE.get_total_step_count_by_task_ids(
|
total_step_count = await app.DATABASE.get_total_unique_step_order_count_by_task_ids(
|
||||||
task_ids=[task.task_id for task in workflow_run_tasks],
|
task_ids=[task.task_id for task in workflow_run_tasks],
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
statuses=[StepStatus.completed],
|
|
||||||
)
|
)
|
||||||
if total_step_count >= max_steps:
|
if total_step_count >= max_steps:
|
||||||
LOG.info("Task v2 failed - run out of steps", max_steps=max_steps, workflow_run_id=workflow_run_id)
|
LOG.info("Task v2 failed - run out of steps", max_steps=max_steps, workflow_run_id=workflow_run_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user