Add loop-scoped DAG execution for conditionals inside for-loops - backend (#4302)
This commit is contained in:
@@ -10,7 +10,7 @@ import re
|
||||
import smtplib
|
||||
import textwrap
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, deque
|
||||
from datetime import datetime
|
||||
from email.message import EmailMessage
|
||||
from pathlib import Path
|
||||
@@ -77,6 +77,7 @@ from skyvern.forge.sdk.workflow.exceptions import (
|
||||
InsecureCodeDetected,
|
||||
InvalidEmailClientConfiguration,
|
||||
InvalidFileType,
|
||||
InvalidWorkflowDefinition,
|
||||
MissingJinjaVariables,
|
||||
NoIterableValueFound,
|
||||
NoValidEmailRecipient,
|
||||
@@ -1336,6 +1337,71 @@ class ForLoopBlock(Block):
|
||||
output_parameter=output_param,
|
||||
)
|
||||
|
||||
def _build_loop_graph(
|
||||
self, blocks: list[BlockTypeVar]
|
||||
) -> tuple[str, dict[str, BlockTypeVar], dict[str, str | None]]:
|
||||
label_to_block: dict[str, BlockTypeVar] = {}
|
||||
default_next_map: dict[str, str | None] = {}
|
||||
|
||||
for block in blocks:
|
||||
if block.label in label_to_block:
|
||||
raise InvalidWorkflowDefinition(f"Duplicate block label detected in loop: {block.label}")
|
||||
label_to_block[block.label] = block
|
||||
default_next_map[block.label] = block.next_block_label
|
||||
|
||||
has_conditional_blocks = any(block.block_type == BlockType.CONDITIONAL for block in blocks)
|
||||
if not has_conditional_blocks:
|
||||
for idx, block in enumerate(blocks[:-1]):
|
||||
if default_next_map.get(block.label) is None:
|
||||
default_next_map[block.label] = blocks[idx + 1].label
|
||||
|
||||
adjacency: dict[str, set[str]] = {label: set() for label in label_to_block}
|
||||
incoming: dict[str, int] = {label: 0 for label in label_to_block}
|
||||
|
||||
def _add_edge(source: str, target: str | None) -> None:
|
||||
if not target:
|
||||
return
|
||||
if target not in label_to_block:
|
||||
raise InvalidWorkflowDefinition(
|
||||
f"Block {source} references unknown next_block_label {target} inside loop {self.label}"
|
||||
)
|
||||
# Allow multiple branches of a conditional to point to the same target
|
||||
# without double-counting the incoming edge.
|
||||
if target not in adjacency[source]:
|
||||
adjacency[source].add(target)
|
||||
incoming[target] += 1
|
||||
|
||||
for label, block in label_to_block.items():
|
||||
if block.block_type == BlockType.CONDITIONAL:
|
||||
for branch in block.ordered_branches:
|
||||
_add_edge(label, branch.next_block_label)
|
||||
else:
|
||||
_add_edge(label, default_next_map.get(label))
|
||||
|
||||
roots = [label for label, count in incoming.items() if count == 0]
|
||||
if not roots:
|
||||
raise InvalidWorkflowDefinition(f"No entry block found for loop {self.label}")
|
||||
if len(roots) > 1:
|
||||
raise InvalidWorkflowDefinition(
|
||||
f"Multiple entry blocks detected in loop {self.label} ({', '.join(sorted(roots))}); only one entry block is supported."
|
||||
)
|
||||
|
||||
queue: deque[str] = deque([roots[0]])
|
||||
visited_count = 0
|
||||
in_degree = dict(incoming)
|
||||
while queue:
|
||||
node = queue.popleft()
|
||||
visited_count += 1
|
||||
for neighbor in adjacency[node]:
|
||||
in_degree[neighbor] -= 1
|
||||
if in_degree[neighbor] == 0:
|
||||
queue.append(neighbor)
|
||||
|
||||
if visited_count != len(label_to_block):
|
||||
raise InvalidWorkflowDefinition(f"Loop {self.label} contains a cycle; DAG traversal is required.")
|
||||
|
||||
return roots[0], label_to_block, default_next_map
|
||||
|
||||
async def execute_loop_helper(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
@@ -1349,6 +1415,8 @@ class ForLoopBlock(Block):
|
||||
block_outputs: list[BlockResult] = []
|
||||
current_block: BlockTypeVar | None = None
|
||||
|
||||
start_label, label_to_block, default_next_map = self._build_loop_graph(self.loop_blocks)
|
||||
|
||||
for loop_idx, loop_over_value in enumerate(loop_over_values):
|
||||
# Check max_iterations limit
|
||||
if loop_idx >= DEFAULT_MAX_LOOP_ITERATIONS:
|
||||
@@ -1379,7 +1447,6 @@ class ForLoopBlock(Block):
|
||||
|
||||
each_loop_output_values: list[dict[str, Any]] = []
|
||||
|
||||
# Track steps for current iteration
|
||||
iteration_step_count = 0
|
||||
LOG.info(
|
||||
f"ForLoopBlock: Starting iteration {loop_idx} with max_steps_per_iteration={DEFAULT_MAX_STEPS_PER_ITERATION}",
|
||||
@@ -1388,7 +1455,32 @@ class ForLoopBlock(Block):
|
||||
max_steps_per_iteration=DEFAULT_MAX_STEPS_PER_ITERATION,
|
||||
)
|
||||
|
||||
for block_idx, loop_block in enumerate(self.loop_blocks):
|
||||
block_idx = 0
|
||||
current_label: str | None = start_label
|
||||
while current_label:
|
||||
loop_block = label_to_block.get(current_label)
|
||||
if not loop_block:
|
||||
LOG.error(
|
||||
"Unable to find loop block with label in loop graph",
|
||||
workflow_run_id=workflow_run_id,
|
||||
loop_label=self.label,
|
||||
current_label=current_label,
|
||||
)
|
||||
failure_block_result = await self.build_block_result(
|
||||
success=False,
|
||||
status=BlockStatus.failed,
|
||||
failure_reason=f"Unable to find block with label {current_label} inside loop {self.label}",
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
block_outputs.append(failure_block_result)
|
||||
outputs_with_loop_values.append(each_loop_output_values)
|
||||
return LoopBlockExecutedResult(
|
||||
outputs_with_loop_values=outputs_with_loop_values,
|
||||
block_outputs=block_outputs,
|
||||
last_block=current_block,
|
||||
)
|
||||
|
||||
metadata: BlockMetadata = {
|
||||
"current_index": loop_idx,
|
||||
"current_value": loop_over_value,
|
||||
@@ -1515,6 +1607,38 @@ class ForLoopBlock(Block):
|
||||
)
|
||||
|
||||
if block_output.success or loop_block.continue_on_failure:
|
||||
next_label: str | None = None
|
||||
if loop_block.block_type == BlockType.CONDITIONAL:
|
||||
branch_metadata = (
|
||||
block_output.output_parameter_value
|
||||
if isinstance(block_output.output_parameter_value, dict)
|
||||
else None
|
||||
)
|
||||
next_label = (branch_metadata or {}).get("next_block_label")
|
||||
else:
|
||||
next_label = default_next_map.get(loop_block.label)
|
||||
|
||||
if not next_label:
|
||||
break
|
||||
|
||||
if next_label not in label_to_block:
|
||||
failure_block_result = await self.build_block_result(
|
||||
success=False,
|
||||
status=BlockStatus.failed,
|
||||
failure_reason=f"Next block label {next_label} not found inside loop {self.label}",
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
block_outputs.append(failure_block_result)
|
||||
outputs_with_loop_values.append(each_loop_output_values)
|
||||
return LoopBlockExecutedResult(
|
||||
outputs_with_loop_values=outputs_with_loop_values,
|
||||
block_outputs=block_outputs,
|
||||
last_block=current_block,
|
||||
)
|
||||
|
||||
current_label = next_label
|
||||
block_idx += 1
|
||||
continue
|
||||
|
||||
if loop_block.next_loop_on_failure or self.next_loop_on_failure:
|
||||
@@ -1528,6 +1652,8 @@ class ForLoopBlock(Block):
|
||||
)
|
||||
break
|
||||
|
||||
break
|
||||
|
||||
outputs_with_loop_values.append(each_loop_output_values)
|
||||
|
||||
return LoopBlockExecutedResult(
|
||||
@@ -1616,14 +1742,29 @@ class ForLoopBlock(Block):
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
loop_executed_result = await self.execute_loop_helper(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
workflow_run_context=workflow_run_context,
|
||||
loop_over_values=loop_over_values,
|
||||
organization_id=organization_id,
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
try:
|
||||
loop_executed_result = await self.execute_loop_helper(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
workflow_run_context=workflow_run_context,
|
||||
loop_over_values=loop_over_values,
|
||||
organization_id=organization_id,
|
||||
browser_session_id=browser_session_id,
|
||||
)
|
||||
except InvalidWorkflowDefinition as exc:
|
||||
LOG.error(
|
||||
"Loop graph validation failed",
|
||||
error=str(exc),
|
||||
workflow_run_id=workflow_run_id,
|
||||
loop_label=self.label,
|
||||
)
|
||||
return await self.build_block_result(
|
||||
success=False,
|
||||
failure_reason=str(exc),
|
||||
status=BlockStatus.failed,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await self.record_output_parameter_value(
|
||||
workflow_run_context, workflow_run_id, loop_executed_result.outputs_with_loop_values
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user