From ea40b64fdc9e802563f3829da45a2d8b27c2eeda Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Thu, 27 Nov 2025 11:52:37 -0800 Subject: [PATCH] ConditionalBlock spec Update + Implementation for BranchCondition and BranchCriteria (#4120) --- skyvern/forge/sdk/workflow/models/block.py | 118 ++++++++++++++++++--- 1 file changed, 105 insertions(+), 13 deletions(-) diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index d290e4c8..3d49f3ba 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -3819,19 +3819,81 @@ class HttpRequestBlock(Block): ) -class BranchEvaluationContext(BaseModel): +class BranchEvaluationContext: """Collection of runtime data that BranchCriteria evaluators can consume.""" - workflow_parameters: dict[str, Any] = Field(default_factory=dict) - block_outputs: dict[str, Any] = Field(default_factory=dict) - environment: dict[str, Any] | None = None - llm_results: dict[str, Any] | None = None + def __init__( + self, + *, + workflow_run_context: WorkflowRunContext | None = None, + block_label: str | None = None, + ) -> None: + self.workflow_run_context = workflow_run_context + self.block_label = block_label + + def build_template_data(self) -> dict[str, Any]: + """Build Jinja template data mirroring block parameter rendering context.""" + if self.workflow_run_context is None: + return { + "params": {}, + "outputs": {}, + "environment": {}, + "env": {}, + "llm": {}, + } + + ctx = self.workflow_run_context + template_data = ctx.values.copy() + if ctx.include_secrets_in_templates: + template_data.update(ctx.secrets) + + credential_params: list[tuple[str, dict[str, Any]]] = [] + for key, value in template_data.items(): + if isinstance(value, dict) and "context" in value and "username" in value and "password" in value: + credential_params.append((key, value)) + + for key, value in credential_params: + username_secret_id = value.get("username", "") + password_secret_id = value.get("password", "") + real_username = template_data.get(username_secret_id, "") + real_password = template_data.get(password_secret_id, "") + template_data[f"{key}_real_username"] = real_username + template_data[f"{key}_real_password"] = real_password + + if self.block_label: + block_reference_data: dict[str, Any] = ctx.get_block_metadata(self.block_label) + if self.block_label in template_data: + current_value = template_data[self.block_label] + if isinstance(current_value, dict): + block_reference_data.update(current_value) + template_data[self.block_label] = block_reference_data + + if "current_index" in block_reference_data: + template_data["current_index"] = block_reference_data["current_index"] + if "current_item" in block_reference_data: + template_data["current_item"] = block_reference_data["current_item"] + if "current_value" in block_reference_data: + template_data["current_value"] = block_reference_data["current_value"] + + template_data.setdefault("workflow_title", ctx.workflow_title) + template_data.setdefault("workflow_id", ctx.workflow_id) + template_data.setdefault("workflow_permanent_id", ctx.workflow_permanent_id) + template_data.setdefault("workflow_run_id", ctx.workflow_run_id) + + template_data.setdefault("params", template_data.get("params", {})) + template_data.setdefault("outputs", template_data.get("outputs", {})) + template_data.setdefault("environment", template_data.get("environment", {})) + template_data.setdefault("env", template_data.get("environment")) + template_data.setdefault("llm", template_data.get("llm", {})) + + return template_data class BranchCriteria(BaseModel, abc.ABC): """Abstract interface describing how a branch condition should be evaluated.""" criteria_type: str + expression: str description: str | None = None @abc.abstractmethod @@ -3844,26 +3906,61 @@ class BranchCriteria(BaseModel, abc.ABC): return False +class JinjaBranchCriteria(BranchCriteria): + """Jinja2-templated branch criteria (only supported criteria type for now).""" + + criteria_type: Literal["jinja2_template"] = "jinja2_template" + + async def evaluate(self, context: BranchEvaluationContext) -> bool: + # Build the template context explicitly to avoid surprises in templates. + template_data = context.build_template_data() + + try: + template = jinja_sandbox_env.from_string(self.expression) + except Exception as exc: + raise FailedToFormatJinjaStyleParameter( + template=self.expression, + msg=str(exc), + ) from exc + + if settings.WORKFLOW_TEMPLATING_STRICTNESS == "strict": + if missing := get_missing_variables(self.expression, template_data): + raise MissingJinjaVariables(template=self.expression, variables=missing) + + try: + rendered = template.render(template_data) + except Exception as exc: + raise FailedToFormatJinjaStyleParameter( + template=self.expression, + msg=str(exc), + ) from exc + + return bool(rendered) + + class BranchCondition(BaseModel): """Represents a single conditional branch edge within a ConditionalBlock.""" criteria: BranchCriteria | None = None next_block_label: str | None = None description: str | None = None - order: int = Field(ge=0) is_default: bool = False @model_validator(mode="after") def validate_condition(cls, condition: BranchCondition) -> BranchCondition: + if isinstance(condition.criteria, dict): + condition.criteria = JinjaBranchCriteria(**condition.criteria) if condition.criteria is None and not condition.is_default: raise ValueError("Branches without criteria must be marked as default.") + if condition.criteria is not None and not isinstance(condition.criteria, JinjaBranchCriteria): + raise ValueError("Only Jinja2 branch criteria are supported in this version.") if condition.criteria is not None and condition.is_default: raise ValueError("Default branches may not define criteria.") return condition class ConditionalBlock(Block): - """Branching block that selects the next block label based on ordered conditions.""" + """Branching block that selects the next block label based on list-ordered conditions.""" # There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error: # Parameter 1 of Literal[...] cannot be of type "Any" @@ -3876,15 +3973,10 @@ class ConditionalBlock(Block): if not block.branches: raise ValueError("Conditional blocks require at least one branch.") - orders = [branch.order for branch in block.branches] - if len(orders) != len(set(orders)): - raise ValueError("Branch order must be unique within a conditional block.") - default_branches = [branch for branch in block.branches if branch.is_default] if len(default_branches) > 1: raise ValueError("Only one default branch is permitted per conditional block.") - block.branches = sorted(block.branches, key=lambda branch: branch.order) return block def get_all_parameters( @@ -3912,7 +4004,7 @@ class ConditionalBlock(Block): @property def ordered_branches(self) -> list[BranchCondition]: - """Convenience accessor that returns branches sorted by order.""" + """Convenience accessor that returns branches in author-specified list order.""" return list(self.branches) def get_default_branch(self) -> BranchCondition | None: