max task steps for task v2 (#1877)

This commit is contained in:
Shuchang Zheng
2025-03-04 01:07:07 -05:00
committed by GitHub
parent 618070840f
commit d31e4bf268
15 changed files with 90 additions and 40 deletions

View File

@@ -343,6 +343,26 @@ class AgentDB:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_total_step_count_by_task_ids(
self, task_ids: list[str], organization_id: str | None = None, statuses: list[StepStatus] | None = None
) -> int:
try:
async with self.Session() as session:
query = (
select(func.count())
.where(StepModel.task_id.in_(task_ids))
.filter_by(organization_id=organization_id)
)
if statuses:
query = query.filter(StepModel.status.in_(statuses))
return (await session.scalars(query)).scalar()
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> Sequence[StepModel]:
try:
async with self.Session() as session:

View File

@@ -52,7 +52,7 @@ class AsyncExecutor(abc.ABC):
background_tasks: BackgroundTasks | None,
organization_id: str,
task_v2_id: str,
max_iterations_override: int | str | None,
max_steps_override: int | str | None,
browser_session_id: str | None,
**kwargs: dict,
) -> None:
@@ -144,7 +144,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
background_tasks: BackgroundTasks | None,
organization_id: str,
task_v2_id: str,
max_iterations_override: int | str | None,
max_steps_override: int | str | None,
browser_session_id: str | None,
**kwargs: dict,
) -> None:
@@ -177,6 +177,6 @@ class BackgroundTaskExecutor(AsyncExecutor):
task_v2_service.run_task_v2,
organization=organization,
task_v2_id=task_v2_id,
max_iterations_override=max_iterations_override,
max_steps_override=max_steps_override,
browser_session_id=browser_session_id,
)

View File

@@ -1229,9 +1229,14 @@ async def create_task_v2(
data: TaskV2Request,
organization: Organization = Depends(org_auth_service.get_current_org),
x_max_iterations_override: Annotated[int | str | None, Header()] = None,
x_max_steps_override: Annotated[int | str | None, Header()] = None,
) -> dict[str, Any]:
if x_max_iterations_override:
LOG.info("Overriding max iterations for task v2", max_iterations_override=x_max_iterations_override)
if x_max_iterations_override or x_max_steps_override:
LOG.info(
"Overriding max steps for task v2",
max_iterations_override=x_max_iterations_override,
max_steps_override=x_max_steps_override,
)
try:
task_v2 = await task_v2_service.initialize_task_v2(
@@ -1256,7 +1261,7 @@ async def create_task_v2(
background_tasks=background_tasks,
organization_id=organization.organization_id,
task_v2_id=task_v2.observer_cruise_id,
max_iterations_override=x_max_iterations_override,
max_steps_override=x_max_steps_override or x_max_iterations_override,
browser_session_id=data.browser_session_id,
)
return task_v2.model_dump(by_alias=True)

View File

@@ -8,6 +8,7 @@ import httpx
import structlog
from sqlalchemy.exc import OperationalError
from skyvern.config import settings
from skyvern.exceptions import FailedToSendWebhook, TaskTerminationError, TaskV2NotFound, UrlGenerationFailure
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
@@ -17,6 +18,7 @@ from skyvern.forge.sdk.core.hashing import generate_url_hash
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.models import StepStatus
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Metadata, TaskV2Status, ThoughtScenario, ThoughtType
@@ -215,7 +217,7 @@ async def run_task_v2(
organization: Organization,
task_v2_id: str,
request_id: str | None = None,
max_iterations_override: str | int | None = None,
max_steps_override: str | int | None = None,
browser_session_id: str | None = None,
) -> TaskV2:
organization_id = organization.organization_id
@@ -243,7 +245,7 @@ async def run_task_v2(
organization=organization,
task_v2=task_v2,
request_id=request_id,
max_iterations_override=max_iterations_override,
max_steps_override=max_steps_override,
browser_session_id=browser_session_id,
)
except TaskTerminationError as e:
@@ -292,7 +294,7 @@ async def run_task_v2_helper(
organization: Organization,
task_v2: TaskV2,
request_id: str | None = None,
max_iterations_override: str | int | None = None,
max_steps_override: str | int | None = None,
browser_session_id: str | None = None,
) -> tuple[Workflow, WorkflowRun, TaskV2] | tuple[None, None, TaskV2]:
organization_id = organization.organization_id
@@ -320,15 +322,15 @@ async def run_task_v2_helper(
)
return None, None, task_v2
int_max_iterations_override = None
if max_iterations_override:
int_max_steps_override = None
if max_steps_override:
try:
int_max_iterations_override = int(max_iterations_override)
LOG.info("max_iterationss_override is set", max_iterations_override=int_max_iterations_override)
int_max_steps_override = int(max_steps_override)
LOG.info("max_steps_override is set", max_steps=int_max_steps_override)
except ValueError:
LOG.info(
"max_iterations_override isn't an integer, won't override",
max_iterations_override=max_iterations_override,
"max_steps_override isn't an integer, won't override",
max_steps_override=max_steps_override,
)
workflow_run_id = task_v2.workflow_run_id
@@ -375,8 +377,8 @@ async def run_task_v2_helper(
yaml_blocks: list[BLOCK_YAML_TYPES] = []
yaml_parameters: list[PARAMETER_YAML_TYPES] = []
max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS
for i in range(max_iterations):
max_steps = int_max_steps_override or settings.MAX_STEPS_PER_TASK_V2
for i in range(DEFAULT_MAX_ITERATIONS):
# validate the task execution
await app.AGENT_FUNCTION.validate_task_execution(
organization_id=organization_id,
@@ -704,10 +706,28 @@ async def run_task_v2_helper(
screenshots=completion_screenshots,
)
break
# total step number validation
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
total_step_count = await app.DATABASE.get_total_step_count_by_task_ids(
task_ids=[task.task_id for task in workflow_run_tasks],
organization_id=organization_id,
statuses=[StepStatus.completed],
)
if total_step_count >= max_steps:
LOG.info("Task v2 failed - run out of steps", max_steps=max_steps, workflow_run_id=workflow_run_id)
await mark_task_v2_as_failed(
task_v2_id=task_v2_id,
workflow_run_id=workflow_run_id,
failure_reason=f'Reached the max number of {max_steps} steps. If you need more steps, update the "Max Steps Override" configuration when running the task. Or add/update the "x-max-steps-override" header with your desired number of steps in the API request.',
organization_id=organization_id,
)
return workflow, workflow_run, task_v2
else:
LOG.info(
"Task v2 failed - run out of iterations",
max_iterations=max_iterations,
max_iterations=DEFAULT_MAX_ITERATIONS,
max_steps=max_steps,
workflow_run_id=workflow_run_id,
)
task_v2 = await mark_task_v2_as_failed(

View File

@@ -2122,7 +2122,8 @@ class TaskV2Block(Block):
url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
max_iterations: int = 10
max_iterations: int = settings.MAX_ITERATIONS_PER_TASK_V2
max_steps: int = settings.MAX_STEPS_PER_TASK_V2
def get_all_parameters(
self,
@@ -2175,7 +2176,7 @@ class TaskV2Block(Block):
organization=organization,
task_v2_id=task_v2.observer_cruise_id,
request_id=None,
max_iterations_override=self.max_iterations,
max_steps_override=self.max_steps,
browser_session_id=browser_session_id,
)
result_dict = None

View File

@@ -337,7 +337,8 @@ class TaskV2BlockYAML(BlockYAML):
url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
max_iterations: int = 10
max_iterations: int = settings.MAX_ITERATIONS_PER_TASK_V2
max_steps: int = settings.MAX_STEPS_PER_TASK_V2
PARAMETER_YAML_SUBCLASSES = (

View File

@@ -1855,6 +1855,7 @@ class WorkflowService:
totp_verification_url=block_yaml.totp_verification_url,
totp_identifier=block_yaml.totp_identifier,
max_iterations=block_yaml.max_iterations,
max_steps=block_yaml.max_steps,
output_parameter=output_parameter,
)
elif block_yaml.block_type == BlockType.GOTO_URL: