Add BranchCriteriaTypeVar and ConditionalBlockYAML (#4173)
This commit is contained in:
@@ -3969,7 +3969,7 @@ class JinjaBranchCriteria(BranchCriteria):
|
|||||||
class BranchCondition(BaseModel):
|
class BranchCondition(BaseModel):
|
||||||
"""Represents a single conditional branch edge within a ConditionalBlock."""
|
"""Represents a single conditional branch edge within a ConditionalBlock."""
|
||||||
|
|
||||||
criteria: BranchCriteria | None = None
|
criteria: BranchCriteriaTypeVar | None = None
|
||||||
next_block_label: str | None = None
|
next_block_label: str | None = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
is_default: bool = False
|
is_default: bool = False
|
||||||
@@ -4085,3 +4085,7 @@ BlockSubclasses = Union[
|
|||||||
HttpRequestBlock,
|
HttpRequestBlock,
|
||||||
]
|
]
|
||||||
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
||||||
|
|
||||||
|
|
||||||
|
BranchCriteriaSubclasses = Union[JinjaBranchCriteria]
|
||||||
|
BranchCriteriaTypeVar = Annotated[BranchCriteriaSubclasses, Field(discriminator="criteria_type")]
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType, WorkflowParameterType
|
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType, WorkflowParameterType
|
||||||
@@ -258,6 +258,44 @@ class ForLoopBlockYAML(BlockYAML):
|
|||||||
complete_if_empty: bool = False
|
complete_if_empty: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class BranchCriteriaYAML(BaseModel):
|
||||||
|
criteria_type: Literal["jinja2_template"] = "jinja2_template"
|
||||||
|
expression: str
|
||||||
|
description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class BranchConditionYAML(BaseModel):
|
||||||
|
criteria: BranchCriteriaYAML | None = None
|
||||||
|
next_block_label: str | None = None
|
||||||
|
description: str | None = None
|
||||||
|
is_default: bool = False
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_condition(cls, condition: "BranchConditionYAML") -> "BranchConditionYAML":
|
||||||
|
if condition.criteria is None and not condition.is_default:
|
||||||
|
raise ValueError("Branches without criteria must be marked as default.")
|
||||||
|
if condition.criteria is not None and condition.is_default:
|
||||||
|
raise ValueError("Default branches may not define criteria.")
|
||||||
|
return condition
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionalBlockYAML(BlockYAML):
|
||||||
|
block_type: Literal[BlockType.CONDITIONAL] = BlockType.CONDITIONAL # type: ignore
|
||||||
|
|
||||||
|
branch_conditions: list[BranchConditionYAML] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_branches(cls, block: "ConditionalBlockYAML") -> "ConditionalBlockYAML":
|
||||||
|
if not block.branch_conditions:
|
||||||
|
raise ValueError("Conditional blocks require at least one branch.")
|
||||||
|
|
||||||
|
default_branches = [branch for branch in block.branch_conditions if branch.is_default]
|
||||||
|
if len(default_branches) > 1:
|
||||||
|
raise ValueError("Only one default branch is permitted per conditional block.")
|
||||||
|
|
||||||
|
return block
|
||||||
|
|
||||||
|
|
||||||
class CodeBlockYAML(BlockYAML):
|
class CodeBlockYAML(BlockYAML):
|
||||||
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
|
# There is a mypy bug with Literal. Without the type: ignore, mypy will raise an error:
|
||||||
# Parameter 1 of Literal[...] cannot be of type "Any"
|
# Parameter 1 of Literal[...] cannot be of type "Any"
|
||||||
@@ -538,6 +576,7 @@ BLOCK_YAML_SUBCLASSES = (
|
|||||||
| PDFParserBlockYAML
|
| PDFParserBlockYAML
|
||||||
| TaskV2BlockYAML
|
| TaskV2BlockYAML
|
||||||
| HttpRequestBlockYAML
|
| HttpRequestBlockYAML
|
||||||
|
| ConditionalBlockYAML
|
||||||
)
|
)
|
||||||
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]
|
BLOCK_YAML_TYPES = Annotated[BLOCK_YAML_SUBCLASSES, Field(discriminator="block_type")]
|
||||||
|
|
||||||
@@ -547,6 +586,20 @@ class WorkflowDefinitionYAML(BaseModel):
|
|||||||
parameters: list[PARAMETER_YAML_TYPES]
|
parameters: list[PARAMETER_YAML_TYPES]
|
||||||
blocks: list[BLOCK_YAML_TYPES]
|
blocks: list[BLOCK_YAML_TYPES]
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_unique_block_labels(cls, workflow: "WorkflowDefinitionYAML") -> "WorkflowDefinitionYAML":
|
||||||
|
labels = [block.label for block in workflow.blocks]
|
||||||
|
duplicates = [label for label in labels if labels.count(label) > 1]
|
||||||
|
|
||||||
|
if duplicates:
|
||||||
|
unique_duplicates = sorted(set(duplicates))
|
||||||
|
raise ValueError(
|
||||||
|
f"Block labels must be unique within a workflow. "
|
||||||
|
f"Found duplicate label(s): {', '.join(unique_duplicates)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
|
||||||
class WorkflowCreateYAMLRequest(BaseModel):
|
class WorkflowCreateYAMLRequest(BaseModel):
|
||||||
title: str
|
title: str
|
||||||
|
|||||||
Reference in New Issue
Block a user