workflow DAG execution (#4222)
This commit is contained in:
@@ -183,6 +183,10 @@ class Block(BaseModel, abc.ABC):
|
||||
status: BlockStatus | None = None,
|
||||
workflow_run_block_id: str | None = None,
|
||||
organization_id: str | None = None,
|
||||
executed_branch_id: str | None = None,
|
||||
executed_branch_expression: str | None = None,
|
||||
executed_branch_result: bool | None = None,
|
||||
executed_branch_next_block: str | None = None,
|
||||
) -> BlockResult:
|
||||
# TODO: update workflow run block status and failure reason
|
||||
if isinstance(output_parameter_value, str):
|
||||
@@ -195,6 +199,10 @@ class Block(BaseModel, abc.ABC):
|
||||
status=status,
|
||||
failure_reason=failure_reason,
|
||||
organization_id=organization_id,
|
||||
executed_branch_id=executed_branch_id,
|
||||
executed_branch_expression=executed_branch_expression,
|
||||
executed_branch_result=executed_branch_result,
|
||||
executed_branch_next_block=executed_branch_next_block,
|
||||
)
|
||||
return BlockResult(
|
||||
success=success,
|
||||
@@ -4110,6 +4118,19 @@ class JinjaBranchCriteria(BranchCriteria):
|
||||
return _evaluate_truthy_string(rendered)
|
||||
|
||||
|
||||
class PromptBranchCriteria(BranchCriteria):
|
||||
"""Natural language branch criteria."""
|
||||
|
||||
criteria_type: Literal["prompt"] = "prompt"
|
||||
|
||||
async def evaluate(self, context: BranchEvaluationContext) -> bool:
|
||||
# Natural language criteria are evaluated in batch by ConditionalBlock.execute.
|
||||
raise NotImplementedError("PromptBranchCriteria is evaluated in batch, not per-branch.")
|
||||
|
||||
def requires_llm(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class BranchCondition(BaseModel):
|
||||
"""Represents a single conditional branch edge within a ConditionalBlock."""
|
||||
|
||||
@@ -4120,16 +4141,34 @@ class BranchCondition(BaseModel):
|
||||
is_default: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_condition(cls, condition: BranchCondition) -> BranchCondition:
|
||||
if isinstance(condition.criteria, dict):
|
||||
condition.criteria = JinjaBranchCriteria(**condition.criteria)
|
||||
if condition.criteria is None and not condition.is_default:
|
||||
def validate_condition(cls, condition_obj: BranchCondition) -> BranchCondition:
|
||||
if isinstance(condition_obj.criteria, dict):
|
||||
criteria_type = condition_obj.criteria.get("criteria_type")
|
||||
if criteria_type is None:
|
||||
# Infer criteria type from expression format
|
||||
expression = condition_obj.criteria.get("expression", "")
|
||||
if expression.startswith("{{") and expression.endswith("}}"):
|
||||
criteria_type = "jinja2_template"
|
||||
else:
|
||||
criteria_type = "prompt"
|
||||
if criteria_type == "prompt":
|
||||
condition_obj.criteria = PromptBranchCriteria(**condition_obj.criteria)
|
||||
else:
|
||||
condition_obj.criteria = JinjaBranchCriteria(**condition_obj.criteria)
|
||||
if condition_obj.criteria is None and not condition_obj.is_default:
|
||||
raise ValueError("Branches without criteria must be marked as default.")
|
||||
if condition.criteria is not None and not isinstance(condition.criteria, JinjaBranchCriteria):
|
||||
raise ValueError("Only Jinja2 branch criteria are supported in this version.")
|
||||
if condition.criteria is not None and condition.is_default:
|
||||
if condition_obj.criteria is not None and condition_obj.is_default:
|
||||
raise ValueError("Default branches may not define criteria.")
|
||||
return condition
|
||||
if condition_obj.criteria and isinstance(condition_obj.criteria, BranchCriteria):
|
||||
expression = condition_obj.criteria.expression
|
||||
criteria_dict = condition_obj.criteria.model_dump()
|
||||
if expression and expression.startswith("{{") and expression.endswith("}}"):
|
||||
criteria_dict["criteria_type"] = "jinja2_template"
|
||||
condition_obj.criteria = JinjaBranchCriteria(**criteria_dict)
|
||||
else:
|
||||
criteria_dict["criteria_type"] = "prompt"
|
||||
condition_obj.criteria = PromptBranchCriteria(**criteria_dict)
|
||||
return condition_obj
|
||||
|
||||
|
||||
class ConditionalBlock(Block):
|
||||
@@ -4159,6 +4198,117 @@ class ConditionalBlock(Block):
|
||||
# BranchCriteria subclasses will surface their parameter dependencies once implemented.
|
||||
return []
|
||||
|
||||
async def _evaluate_prompt_branches(
|
||||
self,
|
||||
*,
|
||||
branches: list[BranchCondition],
|
||||
evaluation_context: BranchEvaluationContext,
|
||||
workflow_run_id: str,
|
||||
workflow_run_block_id: str,
|
||||
organization_id: str | None = None,
|
||||
) -> list[bool]:
|
||||
if organization_id is None:
|
||||
raise ValueError("organization_id is required to evaluate natural language branches")
|
||||
|
||||
workflow_run_context = evaluation_context.workflow_run_context
|
||||
branch_criteria_payload = [
|
||||
{"index": idx, "expression": branch.criteria.expression if branch.criteria else ""}
|
||||
for idx, branch in enumerate(branches)
|
||||
]
|
||||
|
||||
extraction_goal = prompt_engine.load_prompt(
|
||||
"conditional-prompt-branch-evaluation",
|
||||
branch_criteria=branch_criteria_payload,
|
||||
)
|
||||
|
||||
data_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"branch_results": {
|
||||
"type": "array",
|
||||
"description": "Boolean results for each natural language branch in order.",
|
||||
"items": {"type": "boolean"},
|
||||
}
|
||||
},
|
||||
"required": ["branch_results"],
|
||||
}
|
||||
|
||||
output_param = OutputParameter(
|
||||
output_parameter_id=str(uuid.uuid4()),
|
||||
key=f"prompt_branch_eval_{generate_random_string()}",
|
||||
workflow_id=self.output_parameter.workflow_id,
|
||||
created_at=datetime.now(),
|
||||
modified_at=datetime.now(),
|
||||
parameter_type=ParameterType.OUTPUT,
|
||||
description="Prompt branch evaluation result",
|
||||
)
|
||||
|
||||
extraction_block = ExtractionBlock(
|
||||
label=f"prompt_branch_eval_{generate_random_string()}",
|
||||
data_extraction_goal=extraction_goal,
|
||||
data_schema=data_schema,
|
||||
output_parameter=output_param,
|
||||
)
|
||||
|
||||
extraction_result = await extraction_block.execute(
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not extraction_result.success:
|
||||
raise ValueError(f"Prompt branch evaluation failed: {extraction_result.failure_reason}")
|
||||
|
||||
output_value = extraction_result.output_parameter_value
|
||||
if workflow_run_context:
|
||||
try:
|
||||
await extraction_block.record_output_parameter_value(
|
||||
workflow_run_context=workflow_run_context,
|
||||
workflow_run_id=workflow_run_id,
|
||||
value=output_value,
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
"Failed to record prompt branch evaluation output",
|
||||
workflow_run_id=workflow_run_id,
|
||||
block_label=self.label,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
extracted_info: Any | None = None
|
||||
if isinstance(output_value, dict):
|
||||
extracted_info = output_value.get("extracted_information")
|
||||
|
||||
if isinstance(extracted_info, list) and len(extracted_info) == 1:
|
||||
extracted_info = extracted_info[0]
|
||||
|
||||
if not isinstance(extracted_info, dict):
|
||||
raise ValueError("Prompt branch evaluation returned no extracted_information payload")
|
||||
|
||||
branch_results_raw = extracted_info.get("branch_results")
|
||||
if not isinstance(branch_results_raw, list):
|
||||
raise ValueError("Prompt branch evaluation did not return branch_results list")
|
||||
|
||||
branch_results: list[bool] = []
|
||||
for result in branch_results_raw:
|
||||
if isinstance(result, bool):
|
||||
branch_results.append(result)
|
||||
else:
|
||||
evaluated_result = _evaluate_truthy_string(str(result))
|
||||
LOG.warning(
|
||||
"Prompt branch evaluation returned non-boolean result",
|
||||
result=result,
|
||||
evaluated_result=evaluated_result,
|
||||
)
|
||||
branch_results.append(evaluated_result)
|
||||
|
||||
if len(branch_results) != len(branches):
|
||||
raise ValueError(
|
||||
f"Prompt branch evaluation returned {len(branch_results)} results for {len(branches)} branches"
|
||||
)
|
||||
|
||||
return branch_results
|
||||
|
||||
async def execute( # noqa: D401
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
@@ -4181,9 +4331,59 @@ class ConditionalBlock(Block):
|
||||
matched_branch = None
|
||||
failure_reason: str | None = None
|
||||
|
||||
natural_language_branches = [
|
||||
branch for branch in self.ordered_branches if isinstance(branch.criteria, PromptBranchCriteria)
|
||||
]
|
||||
prompt_results_by_id: dict[str, bool] = {}
|
||||
if natural_language_branches:
|
||||
try:
|
||||
prompt_results = await self._evaluate_prompt_branches(
|
||||
branches=natural_language_branches,
|
||||
evaluation_context=evaluation_context,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
prompt_results_by_id = {
|
||||
branch.id: result for branch, result in zip(natural_language_branches, prompt_results, strict=False)
|
||||
}
|
||||
except Exception as exc:
|
||||
failure_reason = f"Failed to evaluate natural language branches: {str(exc)}"
|
||||
LOG.error(
|
||||
"Failed to evaluate natural language branches",
|
||||
block_label=self.label,
|
||||
error=str(exc),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
for idx, branch in enumerate(self.ordered_branches):
|
||||
if branch.criteria is None:
|
||||
continue
|
||||
|
||||
if branch.criteria.criteria_type == "prompt":
|
||||
if failure_reason:
|
||||
break
|
||||
prompt_result = prompt_results_by_id.get(branch.id)
|
||||
if prompt_result is None:
|
||||
failure_reason = "Missing result for natural language branch evaluation"
|
||||
LOG.error(
|
||||
"Missing prompt evaluation result",
|
||||
block_label=self.label,
|
||||
branch_index=idx,
|
||||
branch_id=branch.id,
|
||||
)
|
||||
break
|
||||
if prompt_result:
|
||||
matched_branch = branch
|
||||
LOG.info(
|
||||
"Conditional natural language branch matched",
|
||||
block_label=self.label,
|
||||
branch_index=idx,
|
||||
next_block_label=branch.next_block_label,
|
||||
)
|
||||
break
|
||||
continue
|
||||
|
||||
try:
|
||||
if await branch.criteria.evaluate(evaluation_context):
|
||||
matched_branch = branch
|
||||
@@ -4210,10 +4410,28 @@ class ConditionalBlock(Block):
|
||||
|
||||
matched_index = self.ordered_branches.index(matched_branch) if matched_branch in self.ordered_branches else None
|
||||
next_block_label = matched_branch.next_block_label if matched_branch else None
|
||||
executed_branch_id = matched_branch.id if matched_branch else None
|
||||
|
||||
# Extract execution details for frontend display
|
||||
executed_branch_expression: str | None = None
|
||||
executed_branch_result: bool | None = None
|
||||
executed_branch_next_block: str | None = None
|
||||
|
||||
if matched_branch:
|
||||
executed_branch_next_block = matched_branch.next_block_label
|
||||
if matched_branch.is_default:
|
||||
# Default/else branch - no expression to evaluate
|
||||
executed_branch_expression = None
|
||||
executed_branch_result = None
|
||||
elif matched_branch.criteria:
|
||||
# Regular condition branch - it matched
|
||||
executed_branch_expression = matched_branch.criteria.expression
|
||||
executed_branch_result = True
|
||||
|
||||
branch_metadata: BlockMetadata = {
|
||||
"branch_taken": next_block_label,
|
||||
"branch_index": matched_index,
|
||||
"branch_id": executed_branch_id,
|
||||
"branch_description": matched_branch.description if matched_branch else None,
|
||||
"criteria_type": matched_branch.criteria.criteria_type
|
||||
if matched_branch and matched_branch.criteria
|
||||
@@ -4258,6 +4476,10 @@ class ConditionalBlock(Block):
|
||||
status=status,
|
||||
workflow_run_block_id=workflow_run_block_id,
|
||||
organization_id=organization_id,
|
||||
executed_branch_id=executed_branch_id,
|
||||
executed_branch_expression=executed_branch_expression,
|
||||
executed_branch_result=executed_branch_result,
|
||||
executed_branch_next_block=executed_branch_next_block,
|
||||
)
|
||||
return block_result
|
||||
|
||||
@@ -4318,5 +4540,5 @@ BlockSubclasses = Union[
|
||||
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
|
||||
|
||||
|
||||
BranchCriteriaSubclasses = Union[JinjaBranchCriteria]
|
||||
BranchCriteriaSubclasses = Union[JinjaBranchCriteria, PromptBranchCriteria]
|
||||
BranchCriteriaTypeVar = Annotated[BranchCriteriaSubclasses, Field(discriminator="criteria_type")]
|
||||
|
||||
Reference in New Issue
Block a user