diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py
index 65676195..7a4f322e 100644
--- a/skyvern/forge/sdk/workflow/models/block.py
+++ b/skyvern/forge/sdk/workflow/models/block.py
@@ -5433,15 +5433,14 @@ class ConditionalBlock(Block):
         browser_session_id: str | None = None,
     ) -> tuple[list[bool], list[str], str | None, dict | None]:
         """
-        Evaluate natural language branch conditions using a single ExtractionBlock.
+        Evaluate natural language branch conditions in batch.
 
         All prompt-based conditions are batched into ONE LLM call for performance.
         Jinja parts ({{ }}) are pre-rendered before sending to LLM.
 
-        ExtractionBlock provides:
-        - Browser/page access for expressions like "comment count > 100"
-        - UI visibility (shows up in workflow timeline with prompt/response)
-        - Proper LLM integration with data_schema
+        Evaluation strategy:
+        - If any condition is pure natural language (or fails to render), evaluate via ExtractionBlock so the LLM has browser/page context.
+        - If every condition contains Jinja and renders successfully, make a direct LLM call (no browser context).
 
         Returns:
             A tuple of (results, rendered_expressions, extraction_goal, llm_response):
@@ -5483,6 +5482,9 @@ class ConditionalBlock(Block):
                         exc_info=True,
                     )
                     rendered_expression = expression
+                    # Rendering failed, so this expression is effectively unresolved and must
+                    # take the ExtractionBlock path (with context) instead of the direct LLM path.
+                    has_any_pure_natlang = True
             else:
                 rendered_expression = expression
                 has_any_pure_natlang = True
@@ -5548,70 +5550,89 @@ class ConditionalBlock(Block):
             "required": ["evaluations"],
         }
 
-        # Step 4: Create and execute single ExtractionBlock
-        output_param = OutputParameter(
-            output_parameter_id=str(uuid.uuid4()),
-            key=f"conditional_branch_eval_{generate_random_string()}",
-            workflow_id=self.output_parameter.workflow_id,
-            created_at=datetime.now(),
-            modified_at=datetime.now(),
-            parameter_type=ParameterType.OUTPUT,
-            description=f"Conditional branch evaluation results ({len(branches)} conditions)",
-        )
-
-        extraction_block = ExtractionBlock(
-            label=f"conditional_branch_eval_{generate_random_string()}",
-            data_extraction_goal=extraction_goal,
-            data_schema=data_schema,
-            output_parameter=output_param,
-        )
-
-        LOG.info(
-            "Conditional branch ExtractionBlock created (batched)",
-            block_label=self.label,
-            num_conditions=len(branches),
-            extraction_goal_preview=extraction_goal[:500] if extraction_goal else None,
-            has_browser_session=browser_session_id is not None,
-            has_context=context_json is not None,
-        )
-
         try:
-            extraction_result = await extraction_block.execute(
-                workflow_run_id=workflow_run_id,
-                workflow_run_block_id=workflow_run_block_id,
-                organization_id=organization_id,
-                browser_session_id=browser_session_id,
-            )
-
-            if not extraction_result.success:
-                LOG.error(
-                    "Conditional branch ExtractionBlock failed",
-                    block_label=self.label,
-                    failure_reason=extraction_result.failure_reason,
+            # Step 4: Evaluate conditions (ExtractionBlock when browser context is needed, direct LLM call otherwise).
+            if has_any_pure_natlang:
+                output_param = OutputParameter(
+                    output_parameter_id=str(uuid.uuid4()),
+                    key=f"conditional_branch_eval_{generate_random_string()}",
+                    workflow_id=self.output_parameter.workflow_id,
+                    created_at=datetime.now(),
+                    modified_at=datetime.now(),
+                    parameter_type=ParameterType.OUTPUT,
+                    description=f"Conditional branch evaluation results ({len(branches)} conditions)",
+                )
+                extraction_block = ExtractionBlock(
+                    label=f"conditional_branch_eval_{generate_random_string()}",
+                    data_extraction_goal=extraction_goal,
+                    data_schema=data_schema,
+                    output_parameter=output_param,
+                )
+                LOG.info(
+                    "Conditional branch ExtractionBlock created (batched)",
+                    block_label=self.label,
+                    num_conditions=len(branches),
+                    extraction_goal_preview=extraction_goal[:500] if extraction_goal else None,
+                    has_browser_session=browser_session_id is not None,
+                    has_any_pure_natlang=has_any_pure_natlang,
+                    using_browser_session=browser_session_id is not None,
+                    has_context=context_json is not None,
+                )
+                extraction_result = await extraction_block.execute(
+                    workflow_run_id=workflow_run_id,
+                    workflow_run_block_id=workflow_run_block_id,
+                    organization_id=organization_id,
+                    browser_session_id=browser_session_id,
                 )
-                raise ValueError(f"Branch evaluation failed: {extraction_result.failure_reason}")
 
-            # Record output parameter value if workflow context available
-            if workflow_run_context:
-                try:
-                    await extraction_block.record_output_parameter_value(
-                        workflow_run_context=workflow_run_context,
-                        workflow_run_id=workflow_run_id,
-                        value=extraction_result.output_parameter_value,
-                    )
-                except Exception:
-                    LOG.warning(
-                        "Failed to record conditional branch evaluation output",
-                        workflow_run_id=workflow_run_id,
+                if not extraction_result.success:
+                    LOG.error(
+                        "Conditional branch ExtractionBlock failed",
                         block_label=self.label,
-                        exc_info=True,
+                        failure_reason=extraction_result.failure_reason,
                     )
+                    raise ValueError(f"Branch evaluation failed: {extraction_result.failure_reason}")
+
+                if workflow_run_context:
+                    try:
+                        await extraction_block.record_output_parameter_value(
+                            workflow_run_context=workflow_run_context,
+                            workflow_run_id=workflow_run_id,
+                            value=extraction_result.output_parameter_value,
+                        )
+                    except Exception:
+                        LOG.warning(
+                            "Failed to record conditional branch evaluation output",
+                            workflow_run_id=workflow_run_id,
+                            block_label=self.label,
+                            exc_info=True,
+                        )
+
+                output_value = extraction_result.output_parameter_value
+            else:
+                # Do not use ExtractionBlock when every expression has already been Jinja-rendered.
+                # ExtractionBlock may still have page/browser context, which can cause the LLM to
+                # reinterpret resolved literals as on-screen references.
+                LOG.info(
+                    "Conditional branch using direct LLM evaluation (no browser context)",
+                    block_label=self.label,
+                    num_conditions=len(branches),
+                    extraction_goal_preview=extraction_goal[:500] if extraction_goal else None,
+                    has_context=False,
+                )
+                output_value = await app.LLM_API_HANDLER(
+                    prompt=extraction_goal,
+                    prompt_name="conditional-prompt-branch-evaluation",
+                    force_dict=True,
+                )
 
             # Step 5: Extract the evaluation results (result + rendered_condition)
-            output_value = extraction_result.output_parameter_value
             results_array: list[bool] = []
             llm_rendered_expressions: list[str] = []
 
+            if isinstance(output_value, list):
+                output_value = {"evaluations": output_value}
+
             if not isinstance(output_value, dict):
                 raise ValueError(f"Unexpected output format: {type(output_value)}")
@@ -5645,7 +5666,7 @@ class ConditionalBlock(Block):
 
         except Exception as exc:
             LOG.error(
-                "Conditional branch ExtractionBlock execution failed",
+                "Conditional branch prompt evaluation failed",
                 block_label=self.label,
                 error=str(exc),
                 exc_info=True,
diff --git a/tests/unit/workflow/test_conditional_branch_evaluation.py b/tests/unit/workflow/test_conditional_branch_evaluation.py
new file mode 100644
index 00000000..d8cf5d77
--- /dev/null
+++ b/tests/unit/workflow/test_conditional_branch_evaluation.py
@@ -0,0 +1,234 @@
+"""Tests for prompt-based conditional branch evaluation behavior."""
+
+from datetime import UTC, datetime
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+import skyvern.forge.sdk.workflow.models.block as block_module
+from skyvern.forge.sdk.workflow.models.block import (
+    BranchCondition,
+    BranchEvaluationContext,
+    ConditionalBlock,
+    PromptBranchCriteria,
+)
+from skyvern.forge.sdk.workflow.models.parameter import OutputParameter
+from skyvern.schemas.workflows import BlockResult
+
+
+def _output_parameter(key: str) -> OutputParameter:
+    now = datetime.now(UTC)
+    return OutputParameter(
+        output_parameter_id=f"{key}_id",
+        key=key,
+        workflow_id="wf",
+        created_at=now,
+        modified_at=now,
+    )
+
+
+def _conditional_block() -> ConditionalBlock:
+    return ConditionalBlock(
+        label="cond",
+        output_parameter=_output_parameter("conditional_output"),
+        branch_conditions=[
+            BranchCondition(criteria=PromptBranchCriteria(expression="fallback"), next_block_label="next"),
+            BranchCondition(is_default=True, next_block_label=None),
+        ],
+    )
+
+
+def _extraction_result(output_parameter: OutputParameter, evaluations: list[dict]) -> BlockResult:
+    return BlockResult(
+        success=True,
+        output_parameter=output_parameter,
+        output_parameter_value={"evaluations": evaluations},
+        failure_reason=None,
+    )
+
+
+@pytest.mark.asyncio
+async def test_jinja_rendered_prompt_condition_omits_browser_session() -> None:
+    block = _conditional_block()
+    branch = BranchCondition(
+        criteria=PromptBranchCriteria(expression='{{Single_or_Joint__c}} == "Joint"'),
+        next_block_label="joint",
+    )
+
+    evaluation_context = BranchEvaluationContext(
+        workflow_run_context=None,
+        template_renderer=lambda expr: expr.replace("{{Single_or_Joint__c}}", "Joint"),
+    )
+    evaluation_context.build_llm_safe_context_snapshot = MagicMock(return_value={"Single_or_Joint__c": "Joint"})  # type: ignore[method-assign]
+    mock_llm_handler = AsyncMock()
+    mock_llm_handler.return_value = {
+        "evaluations": [{"rendered_condition": 'Joint == "Joint"', "reasoning": "ok", "result": True}]
+    }
+
+    with (
+        patch.dict(block_module.app.__dict__, {"LLM_API_HANDLER": mock_llm_handler}),
+        patch("skyvern.forge.sdk.workflow.models.block.prompt_engine.load_prompt", return_value="goal") as mock_prompt,
+        patch("skyvern.forge.sdk.workflow.models.block.ExtractionBlock") as mock_extraction_cls,
+    ):
+        results, rendered_expressions, _, llm_response = await block._evaluate_prompt_branches(
+            branches=[branch],
+            evaluation_context=evaluation_context,
+            workflow_run_id="wr_test",
+            workflow_run_block_id="wrb_test",
+            organization_id="org_test",
+            browser_session_id="bs_test",
+        )
+
+    assert results == [True]
+    assert rendered_expressions == ['Joint == "Joint"']
+    assert llm_response == {
+        "evaluations": [{"rendered_condition": 'Joint == "Joint"', "reasoning": "ok", "result": True}]
+    }
+    mock_llm_handler.assert_awaited_once_with(
+        prompt="goal",
+        prompt_name="conditional-prompt-branch-evaluation",
+        force_dict=True,
+    )
+    mock_extraction_cls.assert_not_called()
+    evaluation_context.build_llm_safe_context_snapshot.assert_not_called()  # type: ignore[attr-defined]
+    assert mock_prompt.call_args.kwargs["context_json"] is None
+
+
+@pytest.mark.asyncio
+async def test_pure_natlang_prompt_condition_uses_browser_session_and_context() -> None:
+    block = _conditional_block()
+    branch = BranchCondition(
+        criteria=PromptBranchCriteria(expression="user selected premium plan"),
+        next_block_label="premium",
+    )
+
+    evaluation_context = BranchEvaluationContext(workflow_run_context=None, template_renderer=lambda expr: expr)
+    evaluation_context.build_llm_safe_context_snapshot = MagicMock(return_value={"plan": "premium"})  # type: ignore[method-assign]
+
+    with (
+        patch("skyvern.forge.sdk.workflow.models.block.prompt_engine.load_prompt", return_value="goal") as mock_prompt,
+        patch("skyvern.forge.sdk.workflow.models.block.ExtractionBlock") as mock_extraction_cls,
+    ):
+        mock_extraction = MagicMock()
+        mock_extraction.execute = AsyncMock(
+            return_value=_extraction_result(
+                block.output_parameter,
+                [
+                    {
+                        "rendered_condition": "user selected premium plan",
+                        "reasoning": "ok",
+                        "result": True,
+                    }
+                ],
+            )
+        )
+        mock_extraction_cls.return_value = mock_extraction
+
+        await block._evaluate_prompt_branches(
+            branches=[branch],
+            evaluation_context=evaluation_context,
+            workflow_run_id="wr_test",
+            workflow_run_block_id="wrb_test",
+            organization_id="org_test",
+            browser_session_id="bs_test",
+        )
+
+    assert mock_extraction.execute.call_args.kwargs["browser_session_id"] == "bs_test"
+    evaluation_context.build_llm_safe_context_snapshot.assert_called_once()  # type: ignore[attr-defined]
+    assert mock_prompt.call_args.kwargs["context_json"] is not None
+
+
+@pytest.mark.asyncio
+async def test_mixed_prompt_conditions_keep_browser_session() -> None:
+    block = _conditional_block()
+    branches = [
+        BranchCondition(
+            criteria=PromptBranchCriteria(expression="{{var}} == 'value'"),
+            next_block_label="jinja_branch",
+        ),
+        BranchCondition(
+            criteria=PromptBranchCriteria(expression="user selected premium plan"),
+            next_block_label="natlang_branch",
+        ),
+    ]
+
+    evaluation_context = BranchEvaluationContext(
+        workflow_run_context=None,
+        template_renderer=lambda expr: expr.replace("{{var}}", "value"),
+    )
+    evaluation_context.build_llm_safe_context_snapshot = MagicMock(return_value={"var": "value"})  # type: ignore[method-assign]
+
+    with patch("skyvern.forge.sdk.workflow.models.block.ExtractionBlock") as mock_extraction_cls:
+        mock_extraction = MagicMock()
+        mock_extraction.execute = AsyncMock(
+            return_value=_extraction_result(
+                block.output_parameter,
+                [
+                    {"rendered_condition": "value == 'value'", "reasoning": "ok", "result": True},
+                    {
+                        "rendered_condition": "user selected premium plan",
+                        "reasoning": "ok",
+                        "result": False,
+                    },
+                ],
+            )
+        )
+        mock_extraction_cls.return_value = mock_extraction
+
+        await block._evaluate_prompt_branches(
+            branches=branches,
+            evaluation_context=evaluation_context,
+            workflow_run_id="wr_test",
+            workflow_run_block_id="wrb_test",
+            organization_id="org_test",
+            browser_session_id="bs_test",
+        )
+
+    assert mock_extraction.execute.call_args.kwargs["browser_session_id"] == "bs_test"
+    evaluation_context.build_llm_safe_context_snapshot.assert_called_once()  # type: ignore[attr-defined]
+
+
+@pytest.mark.asyncio
+async def test_jinja_render_failure_falls_back_to_extraction_block() -> None:
+    block = _conditional_block()
+    branch = BranchCondition(
+        criteria=PromptBranchCriteria(expression='{{Single_or_Joint__c}} == "Joint"'),
+        next_block_label="joint",
+    )
+
+    def _raise_render_error(_: str) -> str:
+        raise RuntimeError("render failed")
+
+    evaluation_context = BranchEvaluationContext(
+        workflow_run_context=None,
+        template_renderer=_raise_render_error,
+    )
+    evaluation_context.build_llm_safe_context_snapshot = MagicMock(return_value={"Single_or_Joint__c": "Joint"})  # type: ignore[method-assign]
+    mock_llm_handler = AsyncMock()
+
+    with (
+        patch.dict(block_module.app.__dict__, {"LLM_API_HANDLER": mock_llm_handler}),
+        patch("skyvern.forge.sdk.workflow.models.block.ExtractionBlock") as mock_extraction_cls,
+    ):
+        mock_extraction = MagicMock()
+        mock_extraction.execute = AsyncMock(
+            return_value=_extraction_result(
+                block.output_parameter,
+                [{"rendered_condition": '{{Single_or_Joint__c}} == "Joint"', "reasoning": "ok", "result": False}],
+            )
+        )
+        mock_extraction_cls.return_value = mock_extraction
+
+        await block._evaluate_prompt_branches(
+            branches=[branch],
+            evaluation_context=evaluation_context,
+            workflow_run_id="wr_test",
+            workflow_run_block_id="wrb_test",
+            organization_id="org_test",
+            browser_session_id="bs_test",
+        )
+
+    mock_extraction.execute.assert_awaited_once()
+    assert mock_extraction.execute.call_args.kwargs["browser_session_id"] == "bs_test"
+    mock_llm_handler.assert_not_called()
+    evaluation_context.build_llm_safe_context_snapshot.assert_called_once()  # type: ignore[attr-defined]