max task steps for task v2 (#1877)

This commit is contained in:
Shuchang Zheng
2025-03-04 01:07:07 -05:00
committed by GitHub
parent 618070840f
commit d31e4bf268
15 changed files with 90 additions and 40 deletions

View File

@@ -42,6 +42,7 @@ import {
generateUniqueEmail, generateUniqueEmail,
} from "../data/sampleTaskData"; } from "../data/sampleTaskData";
import { ExampleCasePill } from "./ExampleCasePill"; import { ExampleCasePill } from "./ExampleCasePill";
import { MAX_STEPS_DEFAULT } from "@/routes/workflows/editor/nodes/Taskv2Node/types";
function createTemplateTaskFromTaskGenerationParameters( function createTemplateTaskFromTaskGenerationParameters(
values: TaskGenerationApiResponse, values: TaskGenerationApiResponse,
@@ -167,7 +168,7 @@ function PromptBox() {
}, },
{ {
headers: { headers: {
"x-max-iterations-override": maxStepsOverride, "x-max-steps-override": maxStepsOverride,
}, },
}, },
); );
@@ -402,6 +403,7 @@ function PromptBox() {
</div> </div>
<Input <Input
value={maxStepsOverride ?? ""} value={maxStepsOverride ?? ""}
placeholder={`Default: ${MAX_STEPS_DEFAULT}`}
onChange={(event) => { onChange={(event) => {
setMaxStepsOverride(event.target.value); setMaxStepsOverride(event.target.value);
}} }}

View File

@@ -41,8 +41,8 @@ export const helpTooltips = {
task: baseHelpTooltipContent, task: baseHelpTooltipContent,
taskv2: { taskv2: {
...baseHelpTooltipContent, ...baseHelpTooltipContent,
maxIterations: maxSteps:
"The maximum number of iterations this task will take to achieve its goal.", "The maximum number of steps this task will take to achieve its goal.",
}, },
navigation: baseHelpTooltipContent, navigation: baseHelpTooltipContent,
extraction: { extraction: {

View File

@@ -19,7 +19,7 @@ import { useIsFirstBlockInWorkflow } from "../../hooks/useIsFirstNodeInWorkflow"
import { NodeActionMenu } from "../NodeActionMenu"; import { NodeActionMenu } from "../NodeActionMenu";
import { WorkflowBlockIcon } from "../WorkflowBlockIcon"; import { WorkflowBlockIcon } from "../WorkflowBlockIcon";
import { EditableNodeTitle } from "../components/EditableNodeTitle"; import { EditableNodeTitle } from "../components/EditableNodeTitle";
import { MAX_ITERATIONS_DEFAULT, type Taskv2Node } from "./types"; import { MAX_STEPS_DEFAULT, type Taskv2Node } from "./types";
function Taskv2Node({ id, data, type }: NodeProps<Taskv2Node>) { function Taskv2Node({ id, data, type }: NodeProps<Taskv2Node>) {
const { updateNodeData } = useReactFlow(); const { updateNodeData } = useReactFlow();
@@ -37,7 +37,7 @@ function Taskv2Node({ id, data, type }: NodeProps<Taskv2Node>) {
url: data.url, url: data.url,
totpVerificationUrl: data.totpVerificationUrl, totpVerificationUrl: data.totpVerificationUrl,
totpIdentifier: data.totpIdentifier, totpIdentifier: data.totpIdentifier,
maxIterations: data.maxIterations, maxSteps: data.maxSteps,
}); });
function handleChange(key: string, value: unknown) { function handleChange(key: string, value: unknown) {
@@ -132,19 +132,17 @@ function Taskv2Node({ id, data, type }: NodeProps<Taskv2Node>) {
<div className="space-y-2"> <div className="space-y-2">
<div className="flex gap-2"> <div className="flex gap-2">
<Label className="text-xs text-slate-300"> <Label className="text-xs text-slate-300">
Max Iterations Max Steps
</Label> </Label>
<HelpTooltip <HelpTooltip content={helpTooltips[type]["maxSteps"]} />
content={helpTooltips[type]["maxIterations"]}
/>
</div> </div>
<Input <Input
type="number" type="number"
placeholder="10" placeholder="10"
className="nopan text-xs" className="nopan text-xs"
value={data.maxIterations ?? MAX_ITERATIONS_DEFAULT} value={data.maxSteps ?? MAX_STEPS_DEFAULT}
onChange={(event) => { onChange={(event) => {
handleChange("maxIterations", Number(event.target.value)); handleChange("maxSteps", Number(event.target.value));
}} }}
/> />
</div> </div>

View File

@@ -1,14 +1,14 @@
import { Node } from "@xyflow/react"; import { Node } from "@xyflow/react";
import { NodeBaseData } from "../types"; import { NodeBaseData } from "../types";
export const MAX_ITERATIONS_DEFAULT = 10; export const MAX_STEPS_DEFAULT = 25;
export type Taskv2NodeData = NodeBaseData & { export type Taskv2NodeData = NodeBaseData & {
prompt: string; prompt: string;
url: string; url: string;
totpVerificationUrl: string | null; totpVerificationUrl: string | null;
totpIdentifier: string | null; totpIdentifier: string | null;
maxIterations: number | null; maxSteps: number | null;
}; };
export type Taskv2Node = Node<Taskv2NodeData, "taskv2">; export type Taskv2Node = Node<Taskv2NodeData, "taskv2">;
@@ -21,7 +21,7 @@ export const taskv2NodeDefaultData: Taskv2NodeData = {
url: "", url: "",
totpIdentifier: null, totpIdentifier: null,
totpVerificationUrl: null, totpVerificationUrl: null,
maxIterations: 10, maxSteps: MAX_STEPS_DEFAULT,
}; };
export function isTaskV2Node(node: Node): node is Taskv2Node { export function isTaskV2Node(node: Node): node is Taskv2Node {

View File

@@ -219,7 +219,7 @@ function convertToNode(
...commonData, ...commonData,
prompt: block.prompt, prompt: block.prompt,
url: block.url ?? "", url: block.url ?? "",
maxIterations: block.max_iterations, maxSteps: block.max_steps,
totpIdentifier: block.totp_identifier, totpIdentifier: block.totp_identifier,
totpVerificationUrl: block.totp_verification_url, totpVerificationUrl: block.totp_verification_url,
}, },
@@ -928,7 +928,7 @@ function getWorkflowBlock(node: WorkflowBlockNode): BlockYAML {
...base, ...base,
block_type: "task_v2", block_type: "task_v2",
prompt: node.data.prompt, prompt: node.data.prompt,
max_iterations: node.data.maxIterations, max_steps: node.data.maxSteps,
totp_identifier: node.data.totpIdentifier, totp_identifier: node.data.totpIdentifier,
totp_verification_url: node.data.totpVerificationUrl, totp_verification_url: node.data.totpVerificationUrl,
url: node.data.url, url: node.data.url,
@@ -1608,7 +1608,7 @@ function convertBlocksToBlockYAML(
block_type: "task_v2", block_type: "task_v2",
prompt: block.prompt, prompt: block.prompt,
url: block.url, url: block.url,
max_iterations: block.max_iterations, max_steps: block.max_steps,
totp_identifier: block.totp_identifier, totp_identifier: block.totp_identifier,
totp_verification_url: block.totp_verification_url, totp_verification_url: block.totp_verification_url,
}; };

View File

@@ -255,7 +255,7 @@ export type Taskv2Block = WorkflowBlockBase & {
url: string | null; url: string | null;
totp_verification_url: string | null; totp_verification_url: string | null;
totp_identifier: string | null; totp_identifier: string | null;
max_iterations: number | null; max_steps: number | null;
}; };
export type ForLoopBlock = WorkflowBlockBase & { export type ForLoopBlock = WorkflowBlockBase & {

View File

@@ -140,7 +140,7 @@ export type Taskv2BlockYAML = BlockYAMLBase & {
prompt: string; prompt: string;
totp_verification_url: string | null; totp_verification_url: string | null;
totp_identifier: string | null; totp_identifier: string | null;
max_iterations: number | null; max_steps: number | null;
}; };
export type ValidationBlockYAML = BlockYAMLBase & { export type ValidationBlockYAML = BlockYAMLBase & {

View File

@@ -20,6 +20,8 @@ class Settings(BaseSettings):
BROWSER_LOADING_TIMEOUT_MS: int = 120000 BROWSER_LOADING_TIMEOUT_MS: int = 120000
OPTION_LOADING_TIMEOUT_MS: int = 600000 OPTION_LOADING_TIMEOUT_MS: int = 600000
MAX_STEPS_PER_RUN: int = 10 MAX_STEPS_PER_RUN: int = 10
MAX_STEPS_PER_TASK_V2: int = 25
MAX_ITERATIONS_PER_TASK_V2: int = 10
MAX_NUM_SCREENSHOTS: int = 10 MAX_NUM_SCREENSHOTS: int = 10
# Ratio should be between 0 and 1. # Ratio should be between 0 and 1.
# If the task has been running for more steps than this ratio of the max steps per run, then we'll log a warning. # If the task has been running for more steps than this ratio of the max steps per run, then we'll log a warning.

View File

@@ -343,6 +343,26 @@ class AgentDB:
LOG.error("SQLAlchemyError", exc_info=True) LOG.error("SQLAlchemyError", exc_info=True)
raise raise
async def get_total_step_count_by_task_ids(
self, task_ids: list[str], organization_id: str | None = None, statuses: list[StepStatus] | None = None
) -> int:
try:
async with self.Session() as session:
query = (
select(func.count())
.where(StepModel.task_id.in_(task_ids))
.filter_by(organization_id=organization_id)
)
if statuses:
query = query.filter(StepModel.status.in_(statuses))
return (await session.scalars(query)).scalar()
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
except Exception:
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> Sequence[StepModel]: async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> Sequence[StepModel]:
try: try:
async with self.Session() as session: async with self.Session() as session:

View File

@@ -52,7 +52,7 @@ class AsyncExecutor(abc.ABC):
background_tasks: BackgroundTasks | None, background_tasks: BackgroundTasks | None,
organization_id: str, organization_id: str,
task_v2_id: str, task_v2_id: str,
max_iterations_override: int | str | None, max_steps_override: int | str | None,
browser_session_id: str | None, browser_session_id: str | None,
**kwargs: dict, **kwargs: dict,
) -> None: ) -> None:
@@ -144,7 +144,7 @@ class BackgroundTaskExecutor(AsyncExecutor):
background_tasks: BackgroundTasks | None, background_tasks: BackgroundTasks | None,
organization_id: str, organization_id: str,
task_v2_id: str, task_v2_id: str,
max_iterations_override: int | str | None, max_steps_override: int | str | None,
browser_session_id: str | None, browser_session_id: str | None,
**kwargs: dict, **kwargs: dict,
) -> None: ) -> None:
@@ -177,6 +177,6 @@ class BackgroundTaskExecutor(AsyncExecutor):
task_v2_service.run_task_v2, task_v2_service.run_task_v2,
organization=organization, organization=organization,
task_v2_id=task_v2_id, task_v2_id=task_v2_id,
max_iterations_override=max_iterations_override, max_steps_override=max_steps_override,
browser_session_id=browser_session_id, browser_session_id=browser_session_id,
) )

View File

@@ -1229,9 +1229,14 @@ async def create_task_v2(
data: TaskV2Request, data: TaskV2Request,
organization: Organization = Depends(org_auth_service.get_current_org), organization: Organization = Depends(org_auth_service.get_current_org),
x_max_iterations_override: Annotated[int | str | None, Header()] = None, x_max_iterations_override: Annotated[int | str | None, Header()] = None,
x_max_steps_override: Annotated[int | str | None, Header()] = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
if x_max_iterations_override: if x_max_iterations_override or x_max_steps_override:
LOG.info("Overriding max iterations for task v2", max_iterations_override=x_max_iterations_override) LOG.info(
"Overriding max steps for task v2",
max_iterations_override=x_max_iterations_override,
max_steps_override=x_max_steps_override,
)
try: try:
task_v2 = await task_v2_service.initialize_task_v2( task_v2 = await task_v2_service.initialize_task_v2(
@@ -1256,7 +1261,7 @@ async def create_task_v2(
background_tasks=background_tasks, background_tasks=background_tasks,
organization_id=organization.organization_id, organization_id=organization.organization_id,
task_v2_id=task_v2.observer_cruise_id, task_v2_id=task_v2.observer_cruise_id,
max_iterations_override=x_max_iterations_override, max_steps_override=x_max_steps_override or x_max_iterations_override,
browser_session_id=data.browser_session_id, browser_session_id=data.browser_session_id,
) )
return task_v2.model_dump(by_alias=True) return task_v2.model_dump(by_alias=True)

View File

@@ -8,6 +8,7 @@ import httpx
import structlog import structlog
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from skyvern.config import settings
from skyvern.exceptions import FailedToSendWebhook, TaskTerminationError, TaskV2NotFound, UrlGenerationFailure from skyvern.exceptions import FailedToSendWebhook, TaskTerminationError, TaskV2NotFound, UrlGenerationFailure
from skyvern.forge import app from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine from skyvern.forge.prompts import prompt_engine
@@ -17,6 +18,7 @@ from skyvern.forge.sdk.core.hashing import generate_url_hash
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
from skyvern.forge.sdk.models import StepStatus
from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.task_runs import TaskRunType from skyvern.forge.sdk.schemas.task_runs import TaskRunType
from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Metadata, TaskV2Status, ThoughtScenario, ThoughtType from skyvern.forge.sdk.schemas.task_v2 import TaskV2, TaskV2Metadata, TaskV2Status, ThoughtScenario, ThoughtType
@@ -215,7 +217,7 @@ async def run_task_v2(
organization: Organization, organization: Organization,
task_v2_id: str, task_v2_id: str,
request_id: str | None = None, request_id: str | None = None,
max_iterations_override: str | int | None = None, max_steps_override: str | int | None = None,
browser_session_id: str | None = None, browser_session_id: str | None = None,
) -> TaskV2: ) -> TaskV2:
organization_id = organization.organization_id organization_id = organization.organization_id
@@ -243,7 +245,7 @@ async def run_task_v2(
organization=organization, organization=organization,
task_v2=task_v2, task_v2=task_v2,
request_id=request_id, request_id=request_id,
max_iterations_override=max_iterations_override, max_steps_override=max_steps_override,
browser_session_id=browser_session_id, browser_session_id=browser_session_id,
) )
except TaskTerminationError as e: except TaskTerminationError as e:
@@ -292,7 +294,7 @@ async def run_task_v2_helper(
organization: Organization, organization: Organization,
task_v2: TaskV2, task_v2: TaskV2,
request_id: str | None = None, request_id: str | None = None,
max_iterations_override: str | int | None = None, max_steps_override: str | int | None = None,
browser_session_id: str | None = None, browser_session_id: str | None = None,
) -> tuple[Workflow, WorkflowRun, TaskV2] | tuple[None, None, TaskV2]: ) -> tuple[Workflow, WorkflowRun, TaskV2] | tuple[None, None, TaskV2]:
organization_id = organization.organization_id organization_id = organization.organization_id
@@ -320,15 +322,15 @@ async def run_task_v2_helper(
) )
return None, None, task_v2 return None, None, task_v2
int_max_iterations_override = None int_max_steps_override = None
if max_iterations_override: if max_steps_override:
try: try:
int_max_iterations_override = int(max_iterations_override) int_max_steps_override = int(max_steps_override)
LOG.info("max_iterationss_override is set", max_iterations_override=int_max_iterations_override) LOG.info("max_steps_override is set", max_steps=int_max_steps_override)
except ValueError: except ValueError:
LOG.info( LOG.info(
"max_iterations_override isn't an integer, won't override", "max_steps_override isn't an integer, won't override",
max_iterations_override=max_iterations_override, max_steps_override=max_steps_override,
) )
workflow_run_id = task_v2.workflow_run_id workflow_run_id = task_v2.workflow_run_id
@@ -375,8 +377,8 @@ async def run_task_v2_helper(
yaml_blocks: list[BLOCK_YAML_TYPES] = [] yaml_blocks: list[BLOCK_YAML_TYPES] = []
yaml_parameters: list[PARAMETER_YAML_TYPES] = [] yaml_parameters: list[PARAMETER_YAML_TYPES] = []
max_iterations = int_max_iterations_override or DEFAULT_MAX_ITERATIONS max_steps = int_max_steps_override or settings.MAX_STEPS_PER_TASK_V2
for i in range(max_iterations): for i in range(DEFAULT_MAX_ITERATIONS):
# validate the task execution # validate the task execution
await app.AGENT_FUNCTION.validate_task_execution( await app.AGENT_FUNCTION.validate_task_execution(
organization_id=organization_id, organization_id=organization_id,
@@ -704,10 +706,28 @@ async def run_task_v2_helper(
screenshots=completion_screenshots, screenshots=completion_screenshots,
) )
break break
# total step number validation
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
total_step_count = await app.DATABASE.get_total_step_count_by_task_ids(
task_ids=[task.task_id for task in workflow_run_tasks],
organization_id=organization_id,
statuses=[StepStatus.completed],
)
if total_step_count >= max_steps:
LOG.info("Task v2 failed - run out of steps", max_steps=max_steps, workflow_run_id=workflow_run_id)
await mark_task_v2_as_failed(
task_v2_id=task_v2_id,
workflow_run_id=workflow_run_id,
failure_reason=f'Reached the max number of {max_steps} steps. If you need more steps, update the "Max Steps Override" configuration when running the task. Or add/update the "x-max-steps-override" header with your desired number of steps in the API request.',
organization_id=organization_id,
)
return workflow, workflow_run, task_v2
else: else:
LOG.info( LOG.info(
"Task v2 failed - run out of iterations", "Task v2 failed - run out of iterations",
max_iterations=max_iterations, max_iterations=DEFAULT_MAX_ITERATIONS,
max_steps=max_steps,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )
task_v2 = await mark_task_v2_as_failed( task_v2 = await mark_task_v2_as_failed(

View File

@@ -2122,7 +2122,8 @@ class TaskV2Block(Block):
url: str | None = None url: str | None = None
totp_verification_url: str | None = None totp_verification_url: str | None = None
totp_identifier: str | None = None totp_identifier: str | None = None
max_iterations: int = 10 max_iterations: int = settings.MAX_ITERATIONS_PER_TASK_V2
max_steps: int = settings.MAX_STEPS_PER_TASK_V2
def get_all_parameters( def get_all_parameters(
self, self,
@@ -2175,7 +2176,7 @@ class TaskV2Block(Block):
organization=organization, organization=organization,
task_v2_id=task_v2.observer_cruise_id, task_v2_id=task_v2.observer_cruise_id,
request_id=None, request_id=None,
max_iterations_override=self.max_iterations, max_steps_override=self.max_steps,
browser_session_id=browser_session_id, browser_session_id=browser_session_id,
) )
result_dict = None result_dict = None

View File

@@ -337,7 +337,8 @@ class TaskV2BlockYAML(BlockYAML):
url: str | None = None url: str | None = None
totp_verification_url: str | None = None totp_verification_url: str | None = None
totp_identifier: str | None = None totp_identifier: str | None = None
max_iterations: int = 10 max_iterations: int = settings.MAX_ITERATIONS_PER_TASK_V2
max_steps: int = settings.MAX_STEPS_PER_TASK_V2
PARAMETER_YAML_SUBCLASSES = ( PARAMETER_YAML_SUBCLASSES = (

View File

@@ -1855,6 +1855,7 @@ class WorkflowService:
totp_verification_url=block_yaml.totp_verification_url, totp_verification_url=block_yaml.totp_verification_url,
totp_identifier=block_yaml.totp_identifier, totp_identifier=block_yaml.totp_identifier,
max_iterations=block_yaml.max_iterations, max_iterations=block_yaml.max_iterations,
max_steps=block_yaml.max_steps,
output_parameter=output_parameter, output_parameter=output_parameter,
) )
elif block_yaml.block_type == BlockType.GOTO_URL: elif block_yaml.block_type == BlockType.GOTO_URL: