Fix conditional evaluation using wrong value after template rendering SKY-7985 (#4801)
Co-authored-by: Suchintan Singh <suchintan@skyvern.com>
This commit is contained in:
@@ -5433,15 +5433,14 @@ class ConditionalBlock(Block):
|
|||||||
browser_session_id: str | None = None,
|
browser_session_id: str | None = None,
|
||||||
) -> tuple[list[bool], list[str], str | None, dict | None]:
|
) -> tuple[list[bool], list[str], str | None, dict | None]:
|
||||||
"""
|
"""
|
||||||
Evaluate natural language branch conditions using a single ExtractionBlock.
|
Evaluate natural language branch conditions in batch.
|
||||||
|
|
||||||
All prompt-based conditions are batched into ONE LLM call for performance.
|
All prompt-based conditions are batched into ONE LLM call for performance.
|
||||||
Jinja parts ({{ }}) are pre-rendered before sending to LLM.
|
Jinja parts ({{ }}) are pre-rendered before sending to LLM.
|
||||||
|
|
||||||
ExtractionBlock provides:
|
Evaluation strategy:
|
||||||
- Browser/page access for expressions like "comment count > 100"
|
- If any condition is pure natural language, use ExtractionBlock for browser/page context.
|
||||||
- UI visibility (shows up in workflow timeline with prompt/response)
|
- If all conditions contain Jinja and are pre-rendered, use direct LLM call (no browser context).
|
||||||
- Proper LLM integration with data_schema
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (results, rendered_expressions, extraction_goal, llm_response):
|
A tuple of (results, rendered_expressions, extraction_goal, llm_response):
|
||||||
@@ -5483,6 +5482,9 @@ class ConditionalBlock(Block):
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
rendered_expression = expression
|
rendered_expression = expression
|
||||||
|
# Rendering failed, so this expression is effectively unresolved and must
|
||||||
|
# take the ExtractionBlock path (with context) instead of direct LLM mode.
|
||||||
|
has_any_pure_natlang = True
|
||||||
else:
|
else:
|
||||||
rendered_expression = expression
|
rendered_expression = expression
|
||||||
has_any_pure_natlang = True
|
has_any_pure_natlang = True
|
||||||
@@ -5548,70 +5550,89 @@ class ConditionalBlock(Block):
|
|||||||
"required": ["evaluations"],
|
"required": ["evaluations"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Step 4: Create and execute single ExtractionBlock
|
|
||||||
output_param = OutputParameter(
|
|
||||||
output_parameter_id=str(uuid.uuid4()),
|
|
||||||
key=f"conditional_branch_eval_{generate_random_string()}",
|
|
||||||
workflow_id=self.output_parameter.workflow_id,
|
|
||||||
created_at=datetime.now(),
|
|
||||||
modified_at=datetime.now(),
|
|
||||||
parameter_type=ParameterType.OUTPUT,
|
|
||||||
description=f"Conditional branch evaluation results ({len(branches)} conditions)",
|
|
||||||
)
|
|
||||||
|
|
||||||
extraction_block = ExtractionBlock(
|
|
||||||
label=f"conditional_branch_eval_{generate_random_string()}",
|
|
||||||
data_extraction_goal=extraction_goal,
|
|
||||||
data_schema=data_schema,
|
|
||||||
output_parameter=output_param,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
"Conditional branch ExtractionBlock created (batched)",
|
|
||||||
block_label=self.label,
|
|
||||||
num_conditions=len(branches),
|
|
||||||
extraction_goal_preview=extraction_goal[:500] if extraction_goal else None,
|
|
||||||
has_browser_session=browser_session_id is not None,
|
|
||||||
has_context=context_json is not None,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
extraction_result = await extraction_block.execute(
|
# Step 4: Evaluate conditions.
|
||||||
workflow_run_id=workflow_run_id,
|
if has_any_pure_natlang:
|
||||||
workflow_run_block_id=workflow_run_block_id,
|
output_param = OutputParameter(
|
||||||
organization_id=organization_id,
|
output_parameter_id=str(uuid.uuid4()),
|
||||||
browser_session_id=browser_session_id,
|
key=f"conditional_branch_eval_{generate_random_string()}",
|
||||||
)
|
workflow_id=self.output_parameter.workflow_id,
|
||||||
|
created_at=datetime.now(),
|
||||||
if not extraction_result.success:
|
modified_at=datetime.now(),
|
||||||
LOG.error(
|
parameter_type=ParameterType.OUTPUT,
|
||||||
"Conditional branch ExtractionBlock failed",
|
description=f"Conditional branch evaluation results ({len(branches)} conditions)",
|
||||||
block_label=self.label,
|
)
|
||||||
failure_reason=extraction_result.failure_reason,
|
extraction_block = ExtractionBlock(
|
||||||
|
label=f"conditional_branch_eval_{generate_random_string()}",
|
||||||
|
data_extraction_goal=extraction_goal,
|
||||||
|
data_schema=data_schema,
|
||||||
|
output_parameter=output_param,
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
"Conditional branch ExtractionBlock created (batched)",
|
||||||
|
block_label=self.label,
|
||||||
|
num_conditions=len(branches),
|
||||||
|
extraction_goal_preview=extraction_goal[:500] if extraction_goal else None,
|
||||||
|
has_browser_session=browser_session_id is not None,
|
||||||
|
has_any_pure_natlang=has_any_pure_natlang,
|
||||||
|
using_browser_session=browser_session_id is not None,
|
||||||
|
has_context=context_json is not None,
|
||||||
|
)
|
||||||
|
extraction_result = await extraction_block.execute(
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
workflow_run_block_id=workflow_run_block_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
browser_session_id=browser_session_id,
|
||||||
)
|
)
|
||||||
raise ValueError(f"Branch evaluation failed: {extraction_result.failure_reason}")
|
|
||||||
|
|
||||||
# Record output parameter value if workflow context available
|
if not extraction_result.success:
|
||||||
if workflow_run_context:
|
LOG.error(
|
||||||
try:
|
"Conditional branch ExtractionBlock failed",
|
||||||
await extraction_block.record_output_parameter_value(
|
|
||||||
workflow_run_context=workflow_run_context,
|
|
||||||
workflow_run_id=workflow_run_id,
|
|
||||||
value=extraction_result.output_parameter_value,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
LOG.warning(
|
|
||||||
"Failed to record conditional branch evaluation output",
|
|
||||||
workflow_run_id=workflow_run_id,
|
|
||||||
block_label=self.label,
|
block_label=self.label,
|
||||||
exc_info=True,
|
failure_reason=extraction_result.failure_reason,
|
||||||
)
|
)
|
||||||
|
raise ValueError(f"Branch evaluation failed: {extraction_result.failure_reason}")
|
||||||
|
|
||||||
|
if workflow_run_context:
|
||||||
|
try:
|
||||||
|
await extraction_block.record_output_parameter_value(
|
||||||
|
workflow_run_context=workflow_run_context,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
value=extraction_result.output_parameter_value,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.warning(
|
||||||
|
"Failed to record conditional branch evaluation output",
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
block_label=self.label,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_value = extraction_result.output_parameter_value
|
||||||
|
else:
|
||||||
|
# Do not use ExtractionBlock when every expression has already been Jinja-rendered.
|
||||||
|
# ExtractionBlock may still have page/browser context, which can cause the LLM to
|
||||||
|
# reinterpret resolved literals as on-screen references.
|
||||||
|
LOG.info(
|
||||||
|
"Conditional branch using direct LLM evaluation (no browser context)",
|
||||||
|
block_label=self.label,
|
||||||
|
num_conditions=len(branches),
|
||||||
|
extraction_goal_preview=extraction_goal[:500] if extraction_goal else None,
|
||||||
|
has_context=False,
|
||||||
|
)
|
||||||
|
output_value = await app.LLM_API_HANDLER(
|
||||||
|
prompt=extraction_goal,
|
||||||
|
prompt_name="conditional-prompt-branch-evaluation",
|
||||||
|
force_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Step 5: Extract the evaluation results (result + rendered_condition)
|
# Step 5: Extract the evaluation results (result + rendered_condition)
|
||||||
output_value = extraction_result.output_parameter_value
|
|
||||||
results_array: list[bool] = []
|
results_array: list[bool] = []
|
||||||
llm_rendered_expressions: list[str] = []
|
llm_rendered_expressions: list[str] = []
|
||||||
|
|
||||||
|
if isinstance(output_value, list):
|
||||||
|
output_value = {"evaluations": output_value}
|
||||||
|
|
||||||
if not isinstance(output_value, dict):
|
if not isinstance(output_value, dict):
|
||||||
raise ValueError(f"Unexpected output format: {type(output_value)}")
|
raise ValueError(f"Unexpected output format: {type(output_value)}")
|
||||||
|
|
||||||
@@ -5645,7 +5666,7 @@ class ConditionalBlock(Block):
|
|||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
LOG.error(
|
LOG.error(
|
||||||
"Conditional branch ExtractionBlock execution failed",
|
"Conditional branch prompt evaluation failed",
|
||||||
block_label=self.label,
|
block_label=self.label,
|
||||||
error=str(exc),
|
error=str(exc),
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
|
|||||||
234
tests/unit/workflow/test_conditional_branch_evaluation.py
Normal file
234
tests/unit/workflow/test_conditional_branch_evaluation.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
"""Tests for prompt-based conditional branch evaluation behavior."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import skyvern.forge.sdk.workflow.models.block as block_module
|
||||||
|
from skyvern.forge.sdk.workflow.models.block import (
|
||||||
|
BranchCondition,
|
||||||
|
BranchEvaluationContext,
|
||||||
|
ConditionalBlock,
|
||||||
|
PromptBranchCriteria,
|
||||||
|
)
|
||||||
|
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
|
||||||
|
from skyvern.schemas.workflows import BlockResult
|
||||||
|
|
||||||
|
|
||||||
|
def _output_parameter(key: str) -> OutputParameter:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
return OutputParameter(
|
||||||
|
output_parameter_id=f"{key}_id",
|
||||||
|
key=key,
|
||||||
|
workflow_id="wf",
|
||||||
|
created_at=now,
|
||||||
|
modified_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _conditional_block() -> ConditionalBlock:
|
||||||
|
return ConditionalBlock(
|
||||||
|
label="cond",
|
||||||
|
output_parameter=_output_parameter("conditional_output"),
|
||||||
|
branch_conditions=[
|
||||||
|
BranchCondition(criteria=PromptBranchCriteria(expression="fallback"), next_block_label="next"),
|
||||||
|
BranchCondition(is_default=True, next_block_label=None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extraction_result(output_parameter: OutputParameter, evaluations: list[dict]) -> BlockResult:
|
||||||
|
return BlockResult(
|
||||||
|
success=True,
|
||||||
|
output_parameter=output_parameter,
|
||||||
|
output_parameter_value={"evaluations": evaluations},
|
||||||
|
failure_reason=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_jinja_rendered_prompt_condition_omits_browser_session() -> None:
|
||||||
|
block = _conditional_block()
|
||||||
|
branch = BranchCondition(
|
||||||
|
criteria=PromptBranchCriteria(expression='{{Single_or_Joint__c}} == "Joint"'),
|
||||||
|
next_block_label="joint",
|
||||||
|
)
|
||||||
|
|
||||||
|
evaluation_context = BranchEvaluationContext(
|
||||||
|
workflow_run_context=None,
|
||||||
|
template_renderer=lambda expr: expr.replace("{{Single_or_Joint__c}}", "Joint"),
|
||||||
|
)
|
||||||
|
evaluation_context.build_llm_safe_context_snapshot = MagicMock(return_value={"Single_or_Joint__c": "Joint"}) # type: ignore[method-assign]
|
||||||
|
mock_llm_handler = AsyncMock()
|
||||||
|
mock_llm_handler.return_value = {
|
||||||
|
"evaluations": [{"rendered_condition": 'Joint == "Joint"', "reasoning": "ok", "result": True}]
|
||||||
|
}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.dict(block_module.app.__dict__, {"LLM_API_HANDLER": mock_llm_handler}),
|
||||||
|
patch("skyvern.forge.sdk.workflow.models.block.prompt_engine.load_prompt", return_value="goal") as mock_prompt,
|
||||||
|
patch("skyvern.forge.sdk.workflow.models.block.ExtractionBlock") as mock_extraction_cls,
|
||||||
|
):
|
||||||
|
results, rendered_expressions, _, llm_response = await block._evaluate_prompt_branches(
|
||||||
|
branches=[branch],
|
||||||
|
evaluation_context=evaluation_context,
|
||||||
|
workflow_run_id="wr_test",
|
||||||
|
workflow_run_block_id="wrb_test",
|
||||||
|
organization_id="org_test",
|
||||||
|
browser_session_id="bs_test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert results == [True]
|
||||||
|
assert rendered_expressions == ['Joint == "Joint"']
|
||||||
|
assert llm_response == {
|
||||||
|
"evaluations": [{"rendered_condition": 'Joint == "Joint"', "reasoning": "ok", "result": True}]
|
||||||
|
}
|
||||||
|
mock_llm_handler.assert_awaited_once_with(
|
||||||
|
prompt="goal",
|
||||||
|
prompt_name="conditional-prompt-branch-evaluation",
|
||||||
|
force_dict=True,
|
||||||
|
)
|
||||||
|
mock_extraction_cls.assert_not_called()
|
||||||
|
evaluation_context.build_llm_safe_context_snapshot.assert_not_called() # type: ignore[attr-defined]
|
||||||
|
assert mock_prompt.call_args.kwargs["context_json"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pure_natlang_prompt_condition_uses_browser_session_and_context() -> None:
|
||||||
|
block = _conditional_block()
|
||||||
|
branch = BranchCondition(
|
||||||
|
criteria=PromptBranchCriteria(expression="user selected premium plan"),
|
||||||
|
next_block_label="premium",
|
||||||
|
)
|
||||||
|
|
||||||
|
evaluation_context = BranchEvaluationContext(workflow_run_context=None, template_renderer=lambda expr: expr)
|
||||||
|
evaluation_context.build_llm_safe_context_snapshot = MagicMock(return_value={"plan": "premium"}) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("skyvern.forge.sdk.workflow.models.block.prompt_engine.load_prompt", return_value="goal") as mock_prompt,
|
||||||
|
patch("skyvern.forge.sdk.workflow.models.block.ExtractionBlock") as mock_extraction_cls,
|
||||||
|
):
|
||||||
|
mock_extraction = MagicMock()
|
||||||
|
mock_extraction.execute = AsyncMock(
|
||||||
|
return_value=_extraction_result(
|
||||||
|
block.output_parameter,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"rendered_condition": "user selected premium plan",
|
||||||
|
"reasoning": "ok",
|
||||||
|
"result": True,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_extraction_cls.return_value = mock_extraction
|
||||||
|
|
||||||
|
await block._evaluate_prompt_branches(
|
||||||
|
branches=[branch],
|
||||||
|
evaluation_context=evaluation_context,
|
||||||
|
workflow_run_id="wr_test",
|
||||||
|
workflow_run_block_id="wrb_test",
|
||||||
|
organization_id="org_test",
|
||||||
|
browser_session_id="bs_test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_extraction.execute.call_args.kwargs["browser_session_id"] == "bs_test"
|
||||||
|
evaluation_context.build_llm_safe_context_snapshot.assert_called_once() # type: ignore[attr-defined]
|
||||||
|
assert mock_prompt.call_args.kwargs["context_json"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mixed_prompt_conditions_keep_browser_session() -> None:
|
||||||
|
block = _conditional_block()
|
||||||
|
branches = [
|
||||||
|
BranchCondition(
|
||||||
|
criteria=PromptBranchCriteria(expression="{{var}} == 'value'"),
|
||||||
|
next_block_label="jinja_branch",
|
||||||
|
),
|
||||||
|
BranchCondition(
|
||||||
|
criteria=PromptBranchCriteria(expression="user selected premium plan"),
|
||||||
|
next_block_label="natlang_branch",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
evaluation_context = BranchEvaluationContext(
|
||||||
|
workflow_run_context=None,
|
||||||
|
template_renderer=lambda expr: expr.replace("{{var}}", "value"),
|
||||||
|
)
|
||||||
|
evaluation_context.build_llm_safe_context_snapshot = MagicMock(return_value={"var": "value"}) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
with patch("skyvern.forge.sdk.workflow.models.block.ExtractionBlock") as mock_extraction_cls:
|
||||||
|
mock_extraction = MagicMock()
|
||||||
|
mock_extraction.execute = AsyncMock(
|
||||||
|
return_value=_extraction_result(
|
||||||
|
block.output_parameter,
|
||||||
|
[
|
||||||
|
{"rendered_condition": "value == 'value'", "reasoning": "ok", "result": True},
|
||||||
|
{
|
||||||
|
"rendered_condition": "user selected premium plan",
|
||||||
|
"reasoning": "ok",
|
||||||
|
"result": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_extraction_cls.return_value = mock_extraction
|
||||||
|
|
||||||
|
await block._evaluate_prompt_branches(
|
||||||
|
branches=branches,
|
||||||
|
evaluation_context=evaluation_context,
|
||||||
|
workflow_run_id="wr_test",
|
||||||
|
workflow_run_block_id="wrb_test",
|
||||||
|
organization_id="org_test",
|
||||||
|
browser_session_id="bs_test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_extraction.execute.call_args.kwargs["browser_session_id"] == "bs_test"
|
||||||
|
evaluation_context.build_llm_safe_context_snapshot.assert_called_once() # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_jinja_render_failure_falls_back_to_extraction_block() -> None:
|
||||||
|
block = _conditional_block()
|
||||||
|
branch = BranchCondition(
|
||||||
|
criteria=PromptBranchCriteria(expression='{{Single_or_Joint__c}} == "Joint"'),
|
||||||
|
next_block_label="joint",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _raise_render_error(_: str) -> str:
|
||||||
|
raise RuntimeError("render failed")
|
||||||
|
|
||||||
|
evaluation_context = BranchEvaluationContext(
|
||||||
|
workflow_run_context=None,
|
||||||
|
template_renderer=_raise_render_error,
|
||||||
|
)
|
||||||
|
evaluation_context.build_llm_safe_context_snapshot = MagicMock(return_value={"Single_or_Joint__c": "Joint"}) # type: ignore[method-assign]
|
||||||
|
mock_llm_handler = AsyncMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.dict(block_module.app.__dict__, {"LLM_API_HANDLER": mock_llm_handler}),
|
||||||
|
patch("skyvern.forge.sdk.workflow.models.block.ExtractionBlock") as mock_extraction_cls,
|
||||||
|
):
|
||||||
|
mock_extraction = MagicMock()
|
||||||
|
mock_extraction.execute = AsyncMock(
|
||||||
|
return_value=_extraction_result(
|
||||||
|
block.output_parameter,
|
||||||
|
[{"rendered_condition": '{{Single_or_Joint__c}} == "Joint"', "reasoning": "ok", "result": False}],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mock_extraction_cls.return_value = mock_extraction
|
||||||
|
|
||||||
|
await block._evaluate_prompt_branches(
|
||||||
|
branches=[branch],
|
||||||
|
evaluation_context=evaluation_context,
|
||||||
|
workflow_run_id="wr_test",
|
||||||
|
workflow_run_block_id="wrb_test",
|
||||||
|
organization_id="org_test",
|
||||||
|
browser_session_id="bs_test",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_extraction.execute.assert_awaited_once()
|
||||||
|
assert mock_extraction.execute.call_args.kwargs["browser_session_id"] == "bs_test"
|
||||||
|
mock_llm_handler.assert_not_called()
|
||||||
|
evaluation_context.build_llm_safe_context_snapshot.assert_called_once() # type: ignore[attr-defined]
|
||||||
Reference in New Issue
Block a user