ConditionalBlock spec Update + Implementation for BranchCondition and BranchCriteria (#4120)
This commit is contained in:
@@ -3819,19 +3819,81 @@ class HttpRequestBlock(Block):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BranchEvaluationContext(BaseModel):
|
class BranchEvaluationContext:
|
||||||
"""Collection of runtime data that BranchCriteria evaluators can consume."""
|
"""Collection of runtime data that BranchCriteria evaluators can consume."""
|
||||||
|
|
||||||
workflow_parameters: dict[str, Any] = Field(default_factory=dict)
|
def __init__(
|
||||||
block_outputs: dict[str, Any] = Field(default_factory=dict)
|
self,
|
||||||
environment: dict[str, Any] | None = None
|
*,
|
||||||
llm_results: dict[str, Any] | None = None
|
workflow_run_context: WorkflowRunContext | None = None,
|
||||||
|
block_label: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.workflow_run_context = workflow_run_context
|
||||||
|
self.block_label = block_label
|
||||||
|
|
||||||
|
def build_template_data(self) -> dict[str, Any]:
|
||||||
|
"""Build Jinja template data mirroring block parameter rendering context."""
|
||||||
|
if self.workflow_run_context is None:
|
||||||
|
return {
|
||||||
|
"params": {},
|
||||||
|
"outputs": {},
|
||||||
|
"environment": {},
|
||||||
|
"env": {},
|
||||||
|
"llm": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = self.workflow_run_context
|
||||||
|
template_data = ctx.values.copy()
|
||||||
|
if ctx.include_secrets_in_templates:
|
||||||
|
template_data.update(ctx.secrets)
|
||||||
|
|
||||||
|
credential_params: list[tuple[str, dict[str, Any]]] = []
|
||||||
|
for key, value in template_data.items():
|
||||||
|
if isinstance(value, dict) and "context" in value and "username" in value and "password" in value:
|
||||||
|
credential_params.append((key, value))
|
||||||
|
|
||||||
|
for key, value in credential_params:
|
||||||
|
username_secret_id = value.get("username", "")
|
||||||
|
password_secret_id = value.get("password", "")
|
||||||
|
real_username = template_data.get(username_secret_id, "")
|
||||||
|
real_password = template_data.get(password_secret_id, "")
|
||||||
|
template_data[f"{key}_real_username"] = real_username
|
||||||
|
template_data[f"{key}_real_password"] = real_password
|
||||||
|
|
||||||
|
if self.block_label:
|
||||||
|
block_reference_data: dict[str, Any] = ctx.get_block_metadata(self.block_label)
|
||||||
|
if self.block_label in template_data:
|
||||||
|
current_value = template_data[self.block_label]
|
||||||
|
if isinstance(current_value, dict):
|
||||||
|
block_reference_data.update(current_value)
|
||||||
|
template_data[self.block_label] = block_reference_data
|
||||||
|
|
||||||
|
if "current_index" in block_reference_data:
|
||||||
|
template_data["current_index"] = block_reference_data["current_index"]
|
||||||
|
if "current_item" in block_reference_data:
|
||||||
|
template_data["current_item"] = block_reference_data["current_item"]
|
||||||
|
if "current_value" in block_reference_data:
|
||||||
|
template_data["current_value"] = block_reference_data["current_value"]
|
||||||
|
|
||||||
|
template_data.setdefault("workflow_title", ctx.workflow_title)
|
||||||
|
template_data.setdefault("workflow_id", ctx.workflow_id)
|
||||||
|
template_data.setdefault("workflow_permanent_id", ctx.workflow_permanent_id)
|
||||||
|
template_data.setdefault("workflow_run_id", ctx.workflow_run_id)
|
||||||
|
|
||||||
|
template_data.setdefault("params", template_data.get("params", {}))
|
||||||
|
template_data.setdefault("outputs", template_data.get("outputs", {}))
|
||||||
|
template_data.setdefault("environment", template_data.get("environment", {}))
|
||||||
|
template_data.setdefault("env", template_data.get("environment"))
|
||||||
|
template_data.setdefault("llm", template_data.get("llm", {}))
|
||||||
|
|
||||||
|
return template_data
|
||||||
|
|
||||||
|
|
||||||
class BranchCriteria(BaseModel, abc.ABC):
|
class BranchCriteria(BaseModel, abc.ABC):
|
||||||
"""Abstract interface describing how a branch condition should be evaluated."""
|
"""Abstract interface describing how a branch condition should be evaluated."""
|
||||||
|
|
||||||
criteria_type: str
|
criteria_type: str
|
||||||
|
expression: str
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -3844,26 +3906,61 @@ class BranchCriteria(BaseModel, abc.ABC):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class JinjaBranchCriteria(BranchCriteria):
|
||||||
|
"""Jinja2-templated branch criteria (only supported criteria type for now)."""
|
||||||
|
|
||||||
|
criteria_type: Literal["jinja2_template"] = "jinja2_template"
|
||||||
|
|
||||||
|
async def evaluate(self, context: BranchEvaluationContext) -> bool:
|
||||||
|
# Build the template context explicitly to avoid surprises in templates.
|
||||||
|
template_data = context.build_template_data()
|
||||||
|
|
||||||
|
try:
|
||||||
|
template = jinja_sandbox_env.from_string(self.expression)
|
||||||
|
except Exception as exc:
|
||||||
|
raise FailedToFormatJinjaStyleParameter(
|
||||||
|
template=self.expression,
|
||||||
|
msg=str(exc),
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
if settings.WORKFLOW_TEMPLATING_STRICTNESS == "strict":
|
||||||
|
if missing := get_missing_variables(self.expression, template_data):
|
||||||
|
raise MissingJinjaVariables(template=self.expression, variables=missing)
|
||||||
|
|
||||||
|
try:
|
||||||
|
rendered = template.render(template_data)
|
||||||
|
except Exception as exc:
|
||||||
|
raise FailedToFormatJinjaStyleParameter(
|
||||||
|
template=self.expression,
|
||||||
|
msg=str(exc),
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
return bool(rendered)
|
||||||
|
|
||||||
|
|
||||||
class BranchCondition(BaseModel):
|
class BranchCondition(BaseModel):
|
||||||
"""Represents a single conditional branch edge within a ConditionalBlock."""
|
"""Represents a single conditional branch edge within a ConditionalBlock."""
|
||||||
|
|
||||||
criteria: BranchCriteria | None = None
|
criteria: BranchCriteria | None = None
|
||||||
next_block_label: str | None = None
|
next_block_label: str | None = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
order: int = Field(ge=0)
|
|
||||||
is_default: bool = False
|
is_default: bool = False
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_condition(cls, condition: BranchCondition) -> BranchCondition:
|
def validate_condition(cls, condition: BranchCondition) -> BranchCondition:
|
||||||
|
if isinstance(condition.criteria, dict):
|
||||||
|
condition.criteria = JinjaBranchCriteria(**condition.criteria)
|
||||||
if condition.criteria is None and not condition.is_default:
|
if condition.criteria is None and not condition.is_default:
|
||||||
raise ValueError("Branches without criteria must be marked as default.")
|
raise ValueError("Branches without criteria must be marked as default.")
|
||||||
|
if condition.criteria is not None and not isinstance(condition.criteria, JinjaBranchCriteria):
|
||||||
|
raise ValueError("Only Jinja2 branch criteria are supported in this version.")
|
||||||
if condition.criteria is not None and condition.is_default:
|
if condition.criteria is not None and condition.is_default:
|
||||||
raise ValueError("Default branches may not define criteria.")
|
raise ValueError("Default branches may not define criteria.")
|
||||||
return condition
|
return condition
|
||||||
|
|
||||||
|
|
||||||
class ConditionalBlock(Block):
|
class ConditionalBlock(Block):
|
||||||
"""Branching block that selects the next block label based on ordered conditions."""
|
"""Branching block that selects the next block label based on list-ordered conditions."""
|
||||||
|
|
||||||
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
|
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
|
||||||
# Parameter 1 of Literal[...] cannot be of type "Any"
|
# Parameter 1 of Literal[...] cannot be of type "Any"
|
||||||
@@ -3876,15 +3973,10 @@ class ConditionalBlock(Block):
|
|||||||
if not block.branches:
|
if not block.branches:
|
||||||
raise ValueError("Conditional blocks require at least one branch.")
|
raise ValueError("Conditional blocks require at least one branch.")
|
||||||
|
|
||||||
orders = [branch.order for branch in block.branches]
|
|
||||||
if len(orders) != len(set(orders)):
|
|
||||||
raise ValueError("Branch order must be unique within a conditional block.")
|
|
||||||
|
|
||||||
default_branches = [branch for branch in block.branches if branch.is_default]
|
default_branches = [branch for branch in block.branches if branch.is_default]
|
||||||
if len(default_branches) > 1:
|
if len(default_branches) > 1:
|
||||||
raise ValueError("Only one default branch is permitted per conditional block.")
|
raise ValueError("Only one default branch is permitted per conditional block.")
|
||||||
|
|
||||||
block.branches = sorted(block.branches, key=lambda branch: branch.order)
|
|
||||||
return block
|
return block
|
||||||
|
|
||||||
def get_all_parameters(
|
def get_all_parameters(
|
||||||
@@ -3912,7 +4004,7 @@ class ConditionalBlock(Block):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def ordered_branches(self) -> list[BranchCondition]:
|
def ordered_branches(self) -> list[BranchCondition]:
|
||||||
"""Convenience accessor that returns branches sorted by order."""
|
"""Convenience accessor that returns branches in author-specified list order."""
|
||||||
return list(self.branches)
|
return list(self.branches)
|
||||||
|
|
||||||
def get_default_branch(self) -> BranchCondition | None:
|
def get_default_branch(self) -> BranchCondition | None:
|
||||||
|
|||||||
Reference in New Issue
Block a user