Files
Dorod-Sky/tests/unit/workflow/test_dag_engine.py
2026-02-12 20:43:27 -08:00

250 lines
9.6 KiB
Python

from datetime import UTC, datetime
from unittest.mock import AsyncMock
import pytest
from skyvern.forge import app
from skyvern.forge.sdk.workflow.exceptions import InvalidWorkflowDefinition
from skyvern.forge.sdk.workflow.models.block import (
BranchCondition,
ConditionalBlock,
ExtractionBlock,
JinjaBranchCriteria,
NavigationBlock,
PromptBranchCriteria,
)
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
from skyvern.forge.sdk.workflow.service import WorkflowService
from skyvern.schemas.workflows import BlockStatus
def _output_parameter(key: str) -> OutputParameter:
    """Build an ``OutputParameter`` fixture for *key*.

    The id is derived from the key (``"<key>_id"``) and the workflow id is the
    fixed ``"wf"``. Both timestamps come from a single timezone-aware
    ``datetime.now(UTC)`` call, so ``created_at`` equals ``modified_at``.
    """
    timestamp = datetime.now(UTC)
    fields = {
        "output_parameter_id": f"{key}_id",
        "key": key,
        "workflow_id": "wf",
        "created_at": timestamp,
        "modified_at": timestamp,
    }
    return OutputParameter(**fields)
def _navigation_block(label: str, next_block_label: str | None = None) -> NavigationBlock:
    """Build a minimal ``NavigationBlock`` whose title mirrors its label.

    The block's output parameter is keyed ``"<label>_output"``.
    ``next_block_label`` is passed through unchanged.
    """
    output_parameter = _output_parameter(f"{label}_output")
    return NavigationBlock(
        label=label,
        title=label,
        url="https://example.com",
        navigation_goal="goal",
        next_block_label=next_block_label,
        output_parameter=output_parameter,
    )
def _extraction_block(label: str, next_block_label: str | None = None) -> ExtractionBlock:
    """Build a minimal ``ExtractionBlock`` whose title mirrors its label.

    The block's output parameter is keyed ``"<label>_output"``.
    ``next_block_label`` is passed through unchanged.
    """
    output_parameter = _output_parameter(f"{label}_output")
    return ExtractionBlock(
        label=label,
        title=label,
        url="https://example.com",
        data_extraction_goal="extract data",
        next_block_label=next_block_label,
        output_parameter=output_parameter,
    )
def _conditional_block(
    label: str, branch_conditions: list[BranchCondition], next_block_label: str | None = None
) -> ConditionalBlock:
    """Build a ``ConditionalBlock`` with the given branches.

    The block's output parameter is keyed ``"<label>_output"``.
    ``branch_conditions`` and ``next_block_label`` are passed through unchanged.
    """
    output_parameter = _output_parameter(f"{label}_output")
    return ConditionalBlock(
        label=label,
        branch_conditions=branch_conditions,
        next_block_label=next_block_label,
        output_parameter=output_parameter,
    )
class DummyContext:
    """In-memory stand-in for a workflow run context, used by the tests below.

    Block metadata lives in a plain dict. The value, secret, parameter and
    output stores start empty so tests can populate them. Secret masking and
    output registration are no-ops.
    """

    def __init__(self, workflow_run_id: str) -> None:
        # Per-run stores; tests write into these directly (e.g. ``values``).
        self.blocks_metadata: dict[str, dict] = {}
        self.values: dict[str, object] = {}
        self.secrets: dict[str, object] = {}
        self.parameters: dict[str, object] = {}
        self.workflow_run_outputs: dict[str, object] = {}
        self.include_secrets_in_templates = False
        # Fixed workflow identity, matching the "wf" id used by the fixtures.
        self.workflow_title = "test"
        self.workflow_id = "wf"
        self.workflow_permanent_id = "wf-perm"
        self.workflow_run_id = workflow_run_id

    def update_block_metadata(self, label: str, metadata: dict) -> None:
        """Store *metadata* under *label*, replacing any earlier entry."""
        self.blocks_metadata[label] = metadata

    def get_block_metadata(self, label: str | None) -> dict:
        """Return the stored metadata dict for *label* (not a copy).

        Returns a new empty dict when *label* is ``None`` or unknown.
        """
        return {} if label is None else self.blocks_metadata.get(label, {})

    def mask_secrets_in_data(self, data: object) -> object:
        """Return *data* unchanged; these tests register no secrets."""
        return data

    async def register_output_parameter_value_post_execution(self, parameter: OutputParameter, value: object) -> None:  # noqa: ARG002
        """Accept and discard an output value; nothing is recorded."""
        return None

    def build_workflow_run_summary(self) -> dict:
        """Return an empty summary dict."""
        return {}
def test_build_workflow_graph_infers_default_edges() -> None:
    """Blocks without an explicit ``next_block_label`` are chained in list order."""
    service = WorkflowService()
    blocks = [_navigation_block("first"), _navigation_block("second")]

    start_label, label_to_block, default_next_map = service._build_workflow_graph(blocks)

    assert start_label == "first"
    assert set(label_to_block) == {"first", "second"}
    # The first block defaults to its successor; the last block has none.
    assert default_next_map["first"] == "second"
    assert default_next_map["second"] is None
def test_build_workflow_graph_rejects_cycles() -> None:
    """A two-block loop (first -> second -> first) must be rejected."""
    service = WorkflowService()
    cyclic_blocks = [
        _navigation_block("first", next_block_label="second"),
        _navigation_block("second", next_block_label="first"),
    ]

    with pytest.raises(InvalidWorkflowDefinition):
        service._build_workflow_graph(cyclic_blocks)
def test_build_workflow_graph_requires_single_root() -> None:
    """A graph with more than one start block (no incoming edge) must be rejected.

    ``first`` and ``second`` both point explicitly at ``terminal``, and nothing
    points at either of them. The graph therefore has two roots and no cycle,
    so the expected exception has to come from the single-root requirement.
    """
    service = WorkflowService()
    # Every block except the last names its successor explicitly. Sequential
    # defaulting (see test_build_workflow_graph_infers_default_edges) then has
    # no ``None`` label to fill in before the final block, so it cannot add an
    # edge that joins the two roots or closes a cycle. A cycle would let this
    # test pass via cycle detection instead of the root check.
    blocks = [
        _navigation_block("first", next_block_label="terminal"),
        _navigation_block("second", next_block_label="terminal"),
        _navigation_block("terminal"),
    ]

    with pytest.raises(InvalidWorkflowDefinition):
        service._build_workflow_graph(blocks)
def test_build_workflow_graph_conditional_blocks_no_sequential_defaulting() -> None:
    """
    Workflows containing a conditional block must not get sequential defaulting.

    The blocks array is deliberately ordered differently from execution order:

        Execution: start -> extract -> conditional -> (branch_a OR branch_b) -> terminal
        Array:     [start, extract, conditional, terminal, branch_a, branch_b]

    If defaulting were applied, ``terminal`` (which has no successor) would be
    pointed at ``branch_a``, the next entry in the array. Together with
    ``branch_a -> terminal`` that would close a cycle.
    """
    service = WorkflowService()

    jinja_branch = BranchCondition(
        criteria=JinjaBranchCriteria(expression="{{ true }}"),
        next_block_label="branch_a",
        is_default=False,
    )
    fallback_branch = BranchCondition(criteria=None, next_block_label="branch_b", is_default=True)

    blocks = [
        _navigation_block("start", next_block_label="extract"),
        _extraction_block("extract", next_block_label="conditional"),
        # The conditional's own next_block_label should be ignored in favour of its branches.
        _conditional_block(
            "conditional",
            branch_conditions=[jinja_branch, fallback_branch],
            next_block_label="terminal",
        ),
        # "terminal" sits before the branch targets in the array.
        _extraction_block("terminal", next_block_label=None),
        _navigation_block("branch_a", next_block_label="terminal"),
        _navigation_block("branch_b", next_block_label="terminal"),
    ]

    # Must build without raising a cycle error.
    start_label, label_to_block, default_next_map = service._build_workflow_graph(blocks)

    assert start_label == "start"
    assert set(label_to_block) == {"start", "extract", "conditional", "terminal", "branch_a", "branch_b"}
    # No defaulting: "terminal" stays unlinked rather than pointing at "branch_a".
    assert default_next_map["terminal"] is None
    # The branch targets keep their explicit successor.
    for branch_label in ("branch_a", "branch_b"):
        assert default_next_map[branch_label] == "terminal"
@pytest.mark.asyncio
async def test_evaluate_conditional_block_records_branch_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that a truthy Jinja branch is taken and recorded.

    Three places are checked: the block result, the run context's block
    metadata, and the last ``update_workflow_run_block`` call, including the
    executed-branch tracking fields.
    """
    output_param = _output_parameter("conditional_output")
    block = ConditionalBlock(
        label="cond",
        output_parameter=output_param,
        branch_conditions=[
            BranchCondition(criteria=JinjaBranchCriteria(expression="{{ flag }}"), next_block_label="next"),
            BranchCondition(is_default=True, next_block_label=None),
        ],
    )

    ctx = DummyContext(workflow_run_id="run-1")
    # Value for "{{ flag }}"; presumably rendered from the context's values store.
    ctx.values["flag"] = True
    monkeypatch.setattr(app.WORKFLOW_CONTEXT_MANAGER, "get_workflow_run_context", lambda workflow_run_id: ctx)

    # app.DATABASE is a mock here (reset_mock / call_args are Mock APIs),
    # presumably installed by shared test setup. Reset it so the call
    # inspected below comes from this test only.
    app.DATABASE.update_workflow_run_block.reset_mock()
    app.DATABASE.create_or_update_workflow_run_output_parameter.reset_mock()

    result = await block.execute(
        workflow_run_id="run-1",
        workflow_run_block_id="wrb-1",
        organization_id="org-1",
    )

    metadata = result.output_parameter_value
    assert metadata["branch_taken"] == "next"
    assert metadata["next_block_label"] == "next"
    assert result.status == BlockStatus.completed
    assert ctx.blocks_metadata["cond"]["branch_taken"] == "next"

    # call_args is the most recent call, or None if the mock was never called.
    # Guard it so that a missing call fails with a clear message instead of an
    # AttributeError on None.
    call_args = app.DATABASE.update_workflow_run_block.call_args
    assert call_args is not None, "ConditionalBlock.execute did not call update_workflow_run_block"
    assert call_args.kwargs["workflow_run_block_id"] == "wrb-1"
    assert call_args.kwargs["output"] == metadata
    assert call_args.kwargs["status"] == BlockStatus.completed
    assert call_args.kwargs["failure_reason"] is None
    assert call_args.kwargs["organization_id"] == "org-1"
    # Execution-tracking fields for the branch that was taken.
    assert call_args.kwargs["executed_branch_expression"] == "{{ flag }}"
    assert call_args.kwargs["executed_branch_result"] is True
    assert call_args.kwargs["executed_branch_next_block"] == "next"
    # executed_branch_id is expected to be a generated id string (described as a UUID).
    assert isinstance(call_args.kwargs["executed_branch_id"], str)
@pytest.mark.asyncio
async def test_prompt_branch_uses_batched_evaluation(monkeypatch: pytest.MonkeyPatch) -> None:
    """Prompt-criteria branches should be evaluated through ``_evaluate_prompt_branches``.

    The evaluator is stubbed to report the prompt branch as true. The block
    must then take that branch, label it as a prompt criterion, and await the
    stub exactly once.
    """
    prompt_branch = BranchCondition(
        criteria=PromptBranchCriteria(expression="Check if urgent"),
        next_block_label="next",
    )
    block = ConditionalBlock(
        label="cond_prompt",
        output_parameter=_output_parameter("conditional_output_prompt"),
        branch_conditions=[prompt_branch, BranchCondition(is_default=True, next_block_label=None)],
    )

    ctx = DummyContext(workflow_run_id="run-2")
    monkeypatch.setattr(app.WORKFLOW_CONTEXT_MANAGER, "get_workflow_run_context", lambda workflow_run_id: ctx)

    # Stubbed return tuple: (results, rendered_expressions, extraction_goal, llm_response).
    batched_eval = AsyncMock(return_value=([True], ["Check if urgent"], "test prompt", None))
    monkeypatch.setattr(ConditionalBlock, "_evaluate_prompt_branches", batched_eval)

    result = await block.execute(
        workflow_run_id="run-2",
        workflow_run_block_id="wrb-2",
        organization_id="org-2",
    )

    assert result.status == BlockStatus.completed
    metadata = result.output_parameter_value
    assert metadata["branch_taken"] == "next"
    assert metadata["criteria_type"] == "prompt"
    # Prompt branches are expected to be resolved in a single awaited call.
    batched_eval.assert_awaited_once()