ConditionalBlock spec Update + Implementation for BranchCondition and BranchCriteria (#4120)

This commit is contained in:
Shuchang Zheng
2025-11-27 11:52:37 -08:00
committed by GitHub
parent 5b530cab52
commit ea40b64fdc

View File

@@ -3819,19 +3819,81 @@ class HttpRequestBlock(Block):
) )
class BranchEvaluationContext(BaseModel): class BranchEvaluationContext:
"""Collection of runtime data that BranchCriteria evaluators can consume.""" """Collection of runtime data that BranchCriteria evaluators can consume."""
workflow_parameters: dict[str, Any] = Field(default_factory=dict) def __init__(
block_outputs: dict[str, Any] = Field(default_factory=dict) self,
environment: dict[str, Any] | None = None *,
llm_results: dict[str, Any] | None = None workflow_run_context: WorkflowRunContext | None = None,
block_label: str | None = None,
) -> None:
self.workflow_run_context = workflow_run_context
self.block_label = block_label
def build_template_data(self) -> dict[str, Any]:
"""Build Jinja template data mirroring block parameter rendering context."""
if self.workflow_run_context is None:
return {
"params": {},
"outputs": {},
"environment": {},
"env": {},
"llm": {},
}
ctx = self.workflow_run_context
template_data = ctx.values.copy()
if ctx.include_secrets_in_templates:
template_data.update(ctx.secrets)
credential_params: list[tuple[str, dict[str, Any]]] = []
for key, value in template_data.items():
if isinstance(value, dict) and "context" in value and "username" in value and "password" in value:
credential_params.append((key, value))
for key, value in credential_params:
username_secret_id = value.get("username", "")
password_secret_id = value.get("password", "")
real_username = template_data.get(username_secret_id, "")
real_password = template_data.get(password_secret_id, "")
template_data[f"{key}_real_username"] = real_username
template_data[f"{key}_real_password"] = real_password
if self.block_label:
block_reference_data: dict[str, Any] = ctx.get_block_metadata(self.block_label)
if self.block_label in template_data:
current_value = template_data[self.block_label]
if isinstance(current_value, dict):
block_reference_data.update(current_value)
template_data[self.block_label] = block_reference_data
if "current_index" in block_reference_data:
template_data["current_index"] = block_reference_data["current_index"]
if "current_item" in block_reference_data:
template_data["current_item"] = block_reference_data["current_item"]
if "current_value" in block_reference_data:
template_data["current_value"] = block_reference_data["current_value"]
template_data.setdefault("workflow_title", ctx.workflow_title)
template_data.setdefault("workflow_id", ctx.workflow_id)
template_data.setdefault("workflow_permanent_id", ctx.workflow_permanent_id)
template_data.setdefault("workflow_run_id", ctx.workflow_run_id)
template_data.setdefault("params", template_data.get("params", {}))
template_data.setdefault("outputs", template_data.get("outputs", {}))
template_data.setdefault("environment", template_data.get("environment", {}))
template_data.setdefault("env", template_data.get("environment"))
template_data.setdefault("llm", template_data.get("llm", {}))
return template_data
class BranchCriteria(BaseModel, abc.ABC): class BranchCriteria(BaseModel, abc.ABC):
"""Abstract interface describing how a branch condition should be evaluated.""" """Abstract interface describing how a branch condition should be evaluated."""
criteria_type: str criteria_type: str
expression: str
description: str | None = None description: str | None = None
@abc.abstractmethod @abc.abstractmethod
@@ -3844,26 +3906,61 @@ class BranchCriteria(BaseModel, abc.ABC):
return False return False
class JinjaBranchCriteria(BranchCriteria):
"""Jinja2-templated branch criteria (only supported criteria type for now)."""
criteria_type: Literal["jinja2_template"] = "jinja2_template"
async def evaluate(self, context: BranchEvaluationContext) -> bool:
# Build the template context explicitly to avoid surprises in templates.
template_data = context.build_template_data()
try:
template = jinja_sandbox_env.from_string(self.expression)
except Exception as exc:
raise FailedToFormatJinjaStyleParameter(
template=self.expression,
msg=str(exc),
) from exc
if settings.WORKFLOW_TEMPLATING_STRICTNESS == "strict":
if missing := get_missing_variables(self.expression, template_data):
raise MissingJinjaVariables(template=self.expression, variables=missing)
try:
rendered = template.render(template_data)
except Exception as exc:
raise FailedToFormatJinjaStyleParameter(
template=self.expression,
msg=str(exc),
) from exc
return bool(rendered)
class BranchCondition(BaseModel): class BranchCondition(BaseModel):
"""Represents a single conditional branch edge within a ConditionalBlock.""" """Represents a single conditional branch edge within a ConditionalBlock."""
criteria: BranchCriteria | None = None criteria: BranchCriteria | None = None
next_block_label: str | None = None next_block_label: str | None = None
description: str | None = None description: str | None = None
order: int = Field(ge=0)
is_default: bool = False is_default: bool = False
@model_validator(mode="after") @model_validator(mode="after")
def validate_condition(cls, condition: BranchCondition) -> BranchCondition: def validate_condition(cls, condition: BranchCondition) -> BranchCondition:
if isinstance(condition.criteria, dict):
condition.criteria = JinjaBranchCriteria(**condition.criteria)
if condition.criteria is None and not condition.is_default: if condition.criteria is None and not condition.is_default:
raise ValueError("Branches without criteria must be marked as default.") raise ValueError("Branches without criteria must be marked as default.")
if condition.criteria is not None and not isinstance(condition.criteria, JinjaBranchCriteria):
raise ValueError("Only Jinja2 branch criteria are supported in this version.")
if condition.criteria is not None and condition.is_default: if condition.criteria is not None and condition.is_default:
raise ValueError("Default branches may not define criteria.") raise ValueError("Default branches may not define criteria.")
return condition return condition
class ConditionalBlock(Block): class ConditionalBlock(Block):
"""Branching block that selects the next block label based on ordered conditions.""" """Branching block that selects the next block label based on list-ordered conditions."""
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error: # There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
# Parameter 1 of Literal[...] cannot be of type "Any" # Parameter 1 of Literal[...] cannot be of type "Any"
@@ -3876,15 +3973,10 @@ class ConditionalBlock(Block):
if not block.branches: if not block.branches:
raise ValueError("Conditional blocks require at least one branch.") raise ValueError("Conditional blocks require at least one branch.")
orders = [branch.order for branch in block.branches]
if len(orders) != len(set(orders)):
raise ValueError("Branch order must be unique within a conditional block.")
default_branches = [branch for branch in block.branches if branch.is_default] default_branches = [branch for branch in block.branches if branch.is_default]
if len(default_branches) > 1: if len(default_branches) > 1:
raise ValueError("Only one default branch is permitted per conditional block.") raise ValueError("Only one default branch is permitted per conditional block.")
block.branches = sorted(block.branches, key=lambda branch: branch.order)
return block return block
def get_all_parameters( def get_all_parameters(
@@ -3912,7 +4004,7 @@ class ConditionalBlock(Block):
@property @property
def ordered_branches(self) -> list[BranchCondition]: def ordered_branches(self) -> list[BranchCondition]:
"""Convenience accessor that returns branches sorted by order.""" """Convenience accessor that returns branches in author-specified list order."""
return list(self.branches) return list(self.branches)
def get_default_branch(self) -> BranchCondition | None: def get_default_branch(self) -> BranchCondition | None: