max task steps for task v2 (#1877)
This commit is contained in:
@@ -343,6 +343,26 @@ class AgentDB:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_total_step_count_by_task_ids(
|
||||
self, task_ids: list[str], organization_id: str | None = None, statuses: list[StepStatus] | None = None
|
||||
) -> int:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(func.count())
|
||||
.where(StepModel.task_id.in_(task_ids))
|
||||
.filter_by(organization_id=organization_id)
|
||||
)
|
||||
if statuses:
|
||||
query = query.filter(StepModel.status.in_(statuses))
|
||||
return (await session.scalars(query)).scalar()
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> Sequence[StepModel]:
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
|
||||
@@ -52,7 +52,7 @@ class AsyncExecutor(abc.ABC):
|
||||
background_tasks: BackgroundTasks | None,
|
||||
organization_id: str,
|
||||
task_v2_id: str,
|
||||
max_iterations_override: int | str | None,
|
||||
max_steps_override: int | str | None,
|
||||
browser_session_id: str | None,
|
||||
**kwargs: dict,
|
||||
) -> None:
|
||||
@@ -144,7 +144,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
background_tasks: BackgroundTasks | None,
|
||||
organization_id: str,
|
||||
task_v2_id: str,
|
||||
max_iterations_override: int | str | None,
|
||||
max_steps_override: int | str | None,
|
||||
browser_session_id: str | None,
|
||||
**kwargs: dict,
|
||||
) -> None:
|
||||
@@ -177,6 +177,6 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||
task_v2_service.run_task_v2,
|
||||
organization=organization,
|
||||
task_v2_id=task_v2_id,
|
||||
max_iterations_override=max_iterations_override,
|
||||
max_steps_override=max_steps_override,
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
|
||||
@@ -1229,9 +1229,14 @@ async def create_task_v2(
|
||||
data: TaskV2Request,
|
||||
organization: Organization = Depends(org_auth_service.get_current_org),
|
||||
x_max_iterations_override: Annotated[int | str | None, Header()] = None,
|
||||
x_max_steps_override: Annotated[int | str | None, Header()] = None,
|
||||
) -> dict[str, Any]:
|
||||
if x_max_iterations_override:
|
||||
LOG.info("Overriding max iterations for task v2", max_iterations_override=x_max_iterations_override)
|
||||
if x_max_iterations_override or x_max_steps_override:
|
||||
LOG.info(
|
||||
"Overriding max steps for task v2",
|
||||
max_iterations_override=x_max_iterations_override,
|
||||
max_steps_override=x_max_steps_override,
|
||||
)
|
||||
|
||||
try:
|
||||
task_v2 = await task_v2_service.initialize_task_v2(
|
||||
@@ -1256,7 +1261,7 @@ async def create_task_v2(
|
||||
background_tasks=background_tasks,
|
||||
organization_id=organization.organization_id,
|
||||
task_v2_id=task_v2.observer_cruise_id,
|
||||
max_iterations_override=x_max_iterations_override,
|
||||
max_steps_override=x_max_steps_override or x_max_iterations_override,
|
||||
browser_session_id=data.browser_session_id,
|
||||
)
|
||||
return task_v2.model_dump(by_alias=True)
|
||||
|
||||
@@ -8,6 +8,7 @@ import httpx
|
||||
import structlog
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import FailedToSendWebhook, TaskTerminationError, TaskV2NotFound, UrlGenerationFailure
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.prompts import prompt_engine
|
||||
@@ -17,6 +18,7 @@ from skyvern.forge.sdk.core.hashing import generate_url_hash
|
||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
|
||||
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
|
||||
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
||||
from skyvern.forge.sdk.models import StepStatus
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization
|
||||
from skyvern.forge.sdk.schemas.task_runs import TaskRunType
|
||||
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Metadata, TaskV2Status, ThoughtScenario, ThoughtType
|
||||
@@ -215,7 +217,7 @@ async def run_task_v2(
|
||||
organization: Organization,
|
||||
task_v2_id: str,
|
||||
request_id: str | None = None,
|
||||
max_iterations_override: str | int | None = None,
|
||||
max_steps_override: str | int | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
) -> TaskV2:
|
||||
organization_id = organization.organization_id
|
||||
@@ -243,7 +245,7 @@ async def run_task_v2(
|
||||
organization=organization,
|
||||
task_v2=task_v2,
|
||||
request_id=request_id,
|
||||
max_iterations_override=max_iterations_override,
|
||||
max_steps_override=max_steps_override,
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
except TaskTerminationError as e:
|
||||
@@ -292,7 +294,7 @@ async def run_task_v2_helper(
|
||||
organization: Organization,
|
||||
task_v2: TaskV2,
|
||||
request_id: str | None = None,
|
||||
max_iterations_override: str | int | None = None,
|
||||
max_steps_override: str | int | None = None,
|
||||
browser_session_id: str | None = None,
|
||||
) -> tuple[Workflow, WorkflowRun, TaskV2] | tuple[None, None, TaskV2]:
|
||||
organization_id = organization.organization_id
|
||||
@@ -320,15 +322,15 @@ async def run_task_v2_helper(
|
||||
)
|
||||
return None, None, task_v2
|
||||
|
||||
int_max_iterations_override = None
|
||||
if max_iterations_override:
|
||||
int_max_steps_override = None
|
||||
if max_steps_override:
|
||||
try:
|
||||
int_max_iterations_override = int(max_iterations_override)
|
||||
LOG.info("max_iterationss_override is set", max_iterations_override=int_max_iterations_override)
|
||||
int_max_steps_override = int(max_steps_override)
|
||||
LOG.info("max_steps_override is set", max_steps=int_max_steps_override)
|
||||
except ValueError:
|
||||
LOG.info(
|
||||
"max_iterations_override isn't an integer, won't override",
|
||||
max_iterations_override=max_iterations_override,
|
||||
"max_steps_override isn't an integer, won't override",
|
||||
max_steps_override=max_steps_override,
|
||||
)
|
||||
|
||||
workflow_run_id = task_v2.workflow_run_id
|
||||
@@ -375,8 +377,8 @@ async def run_task_v2_helper(
|
||||
yaml_blocks: list[BLOCK_YAML_TYPES] = []
|
||||
yaml_parameters: list[PARAMETER_YAML_TYPES] = []
|
||||
|
||||
max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS
|
||||
for i in range(max_iterations):
|
||||
max_steps = int_max_steps_override or settings.MAX_STEPS_PER_TASK_V2
|
||||
for i in range(DEFAULT_MAX_ITERATIONS):
|
||||
# validate the task execution
|
||||
await app.AGENT_FUNCTION.validate_task_execution(
|
||||
organization_id=organization_id,
|
||||
@@ -704,10 +706,28 @@ async def run_task_v2_helper(
|
||||
screenshots=completion_screenshots,
|
||||
)
|
||||
break
|
||||
|
||||
# total step number validation
|
||||
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
|
||||
total_step_count = await app.DATABASE.get_total_step_count_by_task_ids(
|
||||
task_ids=[task.task_id for task in workflow_run_tasks],
|
||||
organization_id=organization_id,
|
||||
statuses=[StepStatus.completed],
|
||||
)
|
||||
if total_step_count >= max_steps:
|
||||
LOG.info("Task v2 failed - run out of steps", max_steps=max_steps, workflow_run_id=workflow_run_id)
|
||||
await mark_task_v2_as_failed(
|
||||
task_v2_id=task_v2_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
failure_reason=f'Reached the max number of {max_steps} steps. If you need more steps, update the "Max Steps Override" configuration when running the task. Or add/update the "x-max-steps-override" header with your desired number of steps in the API request.',
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return workflow, workflow_run, task_v2
|
||||
else:
|
||||
LOG.info(
|
||||
"Task v2 failed - run out of iterations",
|
||||
max_iterations=max_iterations,
|
||||
max_iterations=DEFAULT_MAX_ITERATIONS,
|
||||
max_steps=max_steps,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
task_v2 = await mark_task_v2_as_failed(
|
||||
|
||||
@@ -2122,7 +2122,8 @@ class TaskV2Block(Block):
|
||||
url: str | None = None
|
||||
totp_verification_url: str | None = None
|
||||
totp_identifier: str | None = None
|
||||
max_iterations: int = 10
|
||||
max_iterations: int = settings.MAX_ITERATIONS_PER_TASK_V2
|
||||
max_steps: int = settings.MAX_STEPS_PER_TASK_V2
|
||||
|
||||
def get_all_parameters(
|
||||
self,
|
||||
@@ -2175,7 +2176,7 @@ class TaskV2Block(Block):
|
||||
organization=organization,
|
||||
task_v2_id=task_v2.observer_cruise_id,
|
||||
request_id=None,
|
||||
max_iterations_override=self.max_iterations,
|
||||
max_steps_override=self.max_steps,
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
result_dict = None
|
||||
|
||||
@@ -337,7 +337,8 @@ class TaskV2BlockYAML(BlockYAML):
|
||||
url: str | None = None
|
||||
totp_verification_url: str | None = None
|
||||
totp_identifier: str | None = None
|
||||
max_iterations: int = 10
|
||||
max_iterations: int = settings.MAX_ITERATIONS_PER_TASK_V2
|
||||
max_steps: int = settings.MAX_STEPS_PER_TASK_V2
|
||||
|
||||
|
||||
PARAMETER_YAML_SUBCLASSES = (
|
||||
|
||||
@@ -1855,6 +1855,7 @@ class WorkflowService:
|
||||
totp_verification_url=block_yaml.totp_verification_url,
|
||||
totp_identifier=block_yaml.totp_identifier,
|
||||
max_iterations=block_yaml.max_iterations,
|
||||
max_steps=block_yaml.max_steps,
|
||||
output_parameter=output_parameter,
|
||||
)
|
||||
elif block_yaml.block_type == BlockType.GOTO_URL:
|
||||
|
||||
Reference in New Issue
Block a user