ConditionalBlock spec Update + Implementation for BranchCondition and BranchCriteria (#4120)
This commit is contained in:
@@ -3819,19 +3819,81 @@ class HttpRequestBlock(Block):
|
||||
)
|
||||
|
||||
|
||||
class BranchEvaluationContext(BaseModel):
|
||||
class BranchEvaluationContext:
|
||||
"""Collection of runtime data that BranchCriteria evaluators can consume."""
|
||||
|
||||
workflow_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
block_outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
environment: dict[str, Any] | None = None
|
||||
llm_results: dict[str, Any] | None = None
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workflow_run_context: WorkflowRunContext | None = None,
|
||||
block_label: str | None = None,
|
||||
) -> None:
|
||||
self.workflow_run_context = workflow_run_context
|
||||
self.block_label = block_label
|
||||
|
||||
def build_template_data(self) -> dict[str, Any]:
|
||||
"""Build Jinja template data mirroring block parameter rendering context."""
|
||||
if self.workflow_run_context is None:
|
||||
return {
|
||||
"params": {},
|
||||
"outputs": {},
|
||||
"environment": {},
|
||||
"env": {},
|
||||
"llm": {},
|
||||
}
|
||||
|
||||
ctx = self.workflow_run_context
|
||||
template_data = ctx.values.copy()
|
||||
if ctx.include_secrets_in_templates:
|
||||
template_data.update(ctx.secrets)
|
||||
|
||||
credential_params: list[tuple[str, dict[str, Any]]] = []
|
||||
for key, value in template_data.items():
|
||||
if isinstance(value, dict) and "context" in value and "username" in value and "password" in value:
|
||||
credential_params.append((key, value))
|
||||
|
||||
for key, value in credential_params:
|
||||
username_secret_id = value.get("username", "")
|
||||
password_secret_id = value.get("password", "")
|
||||
real_username = template_data.get(username_secret_id, "")
|
||||
real_password = template_data.get(password_secret_id, "")
|
||||
template_data[f"{key}_real_username"] = real_username
|
||||
template_data[f"{key}_real_password"] = real_password
|
||||
|
||||
if self.block_label:
|
||||
block_reference_data: dict[str, Any] = ctx.get_block_metadata(self.block_label)
|
||||
if self.block_label in template_data:
|
||||
current_value = template_data[self.block_label]
|
||||
if isinstance(current_value, dict):
|
||||
block_reference_data.update(current_value)
|
||||
template_data[self.block_label] = block_reference_data
|
||||
|
||||
if "current_index" in block_reference_data:
|
||||
template_data["current_index"] = block_reference_data["current_index"]
|
||||
if "current_item" in block_reference_data:
|
||||
template_data["current_item"] = block_reference_data["current_item"]
|
||||
if "current_value" in block_reference_data:
|
||||
template_data["current_value"] = block_reference_data["current_value"]
|
||||
|
||||
template_data.setdefault("workflow_title", ctx.workflow_title)
|
||||
template_data.setdefault("workflow_id", ctx.workflow_id)
|
||||
template_data.setdefault("workflow_permanent_id", ctx.workflow_permanent_id)
|
||||
template_data.setdefault("workflow_run_id", ctx.workflow_run_id)
|
||||
|
||||
template_data.setdefault("params", template_data.get("params", {}))
|
||||
template_data.setdefault("outputs", template_data.get("outputs", {}))
|
||||
template_data.setdefault("environment", template_data.get("environment", {}))
|
||||
template_data.setdefault("env", template_data.get("environment"))
|
||||
template_data.setdefault("llm", template_data.get("llm", {}))
|
||||
|
||||
return template_data
|
||||
|
||||
|
||||
class BranchCriteria(BaseModel, abc.ABC):
|
||||
"""Abstract interface describing how a branch condition should be evaluated."""
|
||||
|
||||
criteria_type: str
|
||||
expression: str
|
||||
description: str | None = None
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -3844,26 +3906,61 @@ class BranchCriteria(BaseModel, abc.ABC):
|
||||
return False
|
||||
|
||||
|
||||
class JinjaBranchCriteria(BranchCriteria):
|
||||
"""Jinja2-templated branch criteria (only supported criteria type for now)."""
|
||||
|
||||
criteria_type: Literal["jinja2_template"] = "jinja2_template"
|
||||
|
||||
async def evaluate(self, context: BranchEvaluationContext) -> bool:
|
||||
# Build the template context explicitly to avoid surprises in templates.
|
||||
template_data = context.build_template_data()
|
||||
|
||||
try:
|
||||
template = jinja_sandbox_env.from_string(self.expression)
|
||||
except Exception as exc:
|
||||
raise FailedToFormatJinjaStyleParameter(
|
||||
template=self.expression,
|
||||
msg=str(exc),
|
||||
) from exc
|
||||
|
||||
if settings.WORKFLOW_TEMPLATING_STRICTNESS == "strict":
|
||||
if missing := get_missing_variables(self.expression, template_data):
|
||||
raise MissingJinjaVariables(template=self.expression, variables=missing)
|
||||
|
||||
try:
|
||||
rendered = template.render(template_data)
|
||||
except Exception as exc:
|
||||
raise FailedToFormatJinjaStyleParameter(
|
||||
template=self.expression,
|
||||
msg=str(exc),
|
||||
) from exc
|
||||
|
||||
return bool(rendered)
|
||||
|
||||
|
||||
class BranchCondition(BaseModel):
|
||||
"""Represents a single conditional branch edge within a ConditionalBlock."""
|
||||
|
||||
criteria: BranchCriteria | None = None
|
||||
next_block_label: str | None = None
|
||||
description: str | None = None
|
||||
order: int = Field(ge=0)
|
||||
is_default: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_condition(cls, condition: BranchCondition) -> BranchCondition:
|
||||
if isinstance(condition.criteria, dict):
|
||||
condition.criteria = JinjaBranchCriteria(**condition.criteria)
|
||||
if condition.criteria is None and not condition.is_default:
|
||||
raise ValueError("Branches without criteria must be marked as default.")
|
||||
if condition.criteria is not None and not isinstance(condition.criteria, JinjaBranchCriteria):
|
||||
raise ValueError("Only Jinja2 branch criteria are supported in this version.")
|
||||
if condition.criteria is not None and condition.is_default:
|
||||
raise ValueError("Default branches may not define criteria.")
|
||||
return condition
|
||||
|
||||
|
||||
class ConditionalBlock(Block):
|
||||
"""Branching block that selects the next block label based on ordered conditions."""
|
||||
"""Branching block that selects the next block label based on list-ordered conditions."""
|
||||
|
||||
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
|
||||
# Parameter 1 of Literal[...] cannot be of type "Any"
|
||||
@@ -3876,15 +3973,10 @@ class ConditionalBlock(Block):
|
||||
if not block.branches:
|
||||
raise ValueError("Conditional blocks require at least one branch.")
|
||||
|
||||
orders = [branch.order for branch in block.branches]
|
||||
if len(orders) != len(set(orders)):
|
||||
raise ValueError("Branch order must be unique within a conditional block.")
|
||||
|
||||
default_branches = [branch for branch in block.branches if branch.is_default]
|
||||
if len(default_branches) > 1:
|
||||
raise ValueError("Only one default branch is permitted per conditional block.")
|
||||
|
||||
block.branches = sorted(block.branches, key=lambda branch: branch.order)
|
||||
return block
|
||||
|
||||
def get_all_parameters(
|
||||
@@ -3912,7 +4004,7 @@ class ConditionalBlock(Block):
|
||||
|
||||
@property
|
||||
def ordered_branches(self) -> list[BranchCondition]:
|
||||
"""Convenience accessor that returns branches sorted by order."""
|
||||
"""Convenience accessor that returns branches in author-specified list order."""
|
||||
return list(self.branches)
|
||||
|
||||
def get_default_branch(self) -> BranchCondition | None:
|
||||
|
||||
Reference in New Issue
Block a user