workflow DAG execution (#4222)

This commit is contained in:
Shuchang Zheng
2025-12-07 12:37:00 -08:00
committed by GitHub
parent 45307cc2ba
commit 753a36ac2e
10 changed files with 332 additions and 21 deletions

View File

@@ -0,0 +1,37 @@
"""update migration script
Revision ID: 135afee6e7bc
Revises: 152354699b93
Create Date: 2025-12-07 20:27:07.352740+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "135afee6e7bc"
down_revision: Union[str, None] = "152354699b93"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("workflow_run_blocks", sa.Column("executed_branch_id", sa.String(), nullable=True))
op.add_column("workflow_run_blocks", sa.Column("executed_branch_expression", sa.String(), nullable=True))
op.add_column("workflow_run_blocks", sa.Column("executed_branch_result", sa.Boolean(), nullable=True))
op.add_column("workflow_run_blocks", sa.Column("executed_branch_next_block", sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("workflow_run_blocks", "executed_branch_next_block")
op.drop_column("workflow_run_blocks", "executed_branch_result")
op.drop_column("workflow_run_blocks", "executed_branch_expression")
op.drop_column("workflow_run_blocks", "executed_branch_id")
# ### end Alembic commands ###

View File

@@ -7,7 +7,7 @@ from ..core.pydantic_utilities import IS_PYDANTIC_V2, UniversalBaseModel
class BranchCriteriaYaml(UniversalBaseModel):
criteria_type: typing.Optional[typing.Literal["jinja2_template"]] = None
criteria_type: typing.Optional[typing.Literal["jinja2_template", "prompt"]] = None
expression: str
description: typing.Optional[str] = None

View File

@@ -0,0 +1,11 @@
You are evaluating conditional branches for a workflow. Return the results to tell me whether each natural language criterion is satisfied.
Criteria (order matters; align outputs to these indices):
{% for criterion in branch_criteria -%}
- {{ criterion.index }}: {{ criterion.expression }}
{% endfor %}
Respond with JSON exactly in this shape:
{
"branch_results": [true | false per criterion, in order]
}

View File

@@ -3911,6 +3911,11 @@ class AgentDB:
instructions: str | None = None,
positive_descriptor: str | None = None,
negative_descriptor: str | None = None,
# conditional block
executed_branch_id: str | None = None,
executed_branch_expression: str | None = None,
executed_branch_result: bool | None = None,
executed_branch_next_block: str | None = None,
) -> WorkflowRunBlock:
async with self.Session() as session:
workflow_run_block = (
@@ -3977,6 +3982,15 @@ class AgentDB:
workflow_run_block.positive_descriptor = positive_descriptor
if negative_descriptor:
workflow_run_block.negative_descriptor = negative_descriptor
# conditional block fields
if executed_branch_id:
workflow_run_block.executed_branch_id = executed_branch_id
if executed_branch_expression is not None:
workflow_run_block.executed_branch_expression = executed_branch_expression
if executed_branch_result is not None:
workflow_run_block.executed_branch_result = executed_branch_result
if executed_branch_next_block is not None:
workflow_run_block.executed_branch_next_block = executed_branch_next_block
await session.commit()
await session.refresh(workflow_run_block)
else:

View File

@@ -714,6 +714,12 @@ class WorkflowRunBlockModel(Base):
positive_descriptor = Column(String, nullable=True)
negative_descriptor = Column(String, nullable=True)
# conditional block
executed_branch_id = Column(String, nullable=True)
executed_branch_expression = Column(String, nullable=True)
executed_branch_result = Column(Boolean, nullable=True)
executed_branch_next_block = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)

View File

@@ -579,6 +579,10 @@ def convert_to_workflow_run_block(
instructions=workflow_run_block_model.instructions,
positive_descriptor=workflow_run_block_model.positive_descriptor,
negative_descriptor=workflow_run_block_model.negative_descriptor,
executed_branch_id=workflow_run_block_model.executed_branch_id,
executed_branch_expression=workflow_run_block_model.executed_branch_expression,
executed_branch_result=workflow_run_block_model.executed_branch_result,
executed_branch_next_block=workflow_run_block_model.executed_branch_next_block,
)
if task:
if task.finished_at and task.started_at:

View File

@@ -58,6 +58,12 @@ class WorkflowRunBlock(BaseModel):
positive_descriptor: str | None = None
negative_descriptor: str | None = None
# conditional block
executed_branch_id: str | None = None
executed_branch_expression: str | None = None
executed_branch_result: bool | None = None
executed_branch_next_block: str | None = None
class WorkflowRunTimelineType(StrEnum):
thought = "thought"

View File

@@ -183,6 +183,10 @@ class Block(BaseModel, abc.ABC):
status: BlockStatus | None = None,
workflow_run_block_id: str | None = None,
organization_id: str | None = None,
executed_branch_id: str | None = None,
executed_branch_expression: str | None = None,
executed_branch_result: bool | None = None,
executed_branch_next_block: str | None = None,
) -> BlockResult:
# TODO: update workflow run block status and failure reason
if isinstance(output_parameter_value, str):
@@ -195,6 +199,10 @@ class Block(BaseModel, abc.ABC):
status=status,
failure_reason=failure_reason,
organization_id=organization_id,
executed_branch_id=executed_branch_id,
executed_branch_expression=executed_branch_expression,
executed_branch_result=executed_branch_result,
executed_branch_next_block=executed_branch_next_block,
)
return BlockResult(
success=success,
@@ -4110,6 +4118,19 @@ class JinjaBranchCriteria(BranchCriteria):
return _evaluate_truthy_string(rendered)
class PromptBranchCriteria(BranchCriteria):
"""Natural language branch criteria."""
criteria_type: Literal["prompt"] = "prompt"
async def evaluate(self, context: BranchEvaluationContext) -> bool:
# Natural language criteria are evaluated in batch by ConditionalBlock.execute.
raise NotImplementedError("PromptBranchCriteria is evaluated in batch, not per-branch.")
def requires_llm(self) -> bool:
return True
class BranchCondition(BaseModel):
"""Represents a single conditional branch edge within a ConditionalBlock."""
@@ -4120,16 +4141,34 @@ class BranchCondition(BaseModel):
is_default: bool = False
@model_validator(mode="after")
def validate_condition(cls, condition: BranchCondition) -> BranchCondition:
if isinstance(condition.criteria, dict):
condition.criteria = JinjaBranchCriteria(**condition.criteria)
if condition.criteria is None and not condition.is_default:
def validate_condition(cls, condition_obj: BranchCondition) -> BranchCondition:
if isinstance(condition_obj.criteria, dict):
criteria_type = condition_obj.criteria.get("criteria_type")
if criteria_type is None:
# Infer criteria type from expression format
expression = condition_obj.criteria.get("expression", "")
if expression.startswith("{{") and expression.endswith("}}"):
criteria_type = "jinja2_template"
else:
criteria_type = "prompt"
if criteria_type == "prompt":
condition_obj.criteria = PromptBranchCriteria(**condition_obj.criteria)
else:
condition_obj.criteria = JinjaBranchCriteria(**condition_obj.criteria)
if condition_obj.criteria is None and not condition_obj.is_default:
raise ValueError("Branches without criteria must be marked as default.")
if condition.criteria is not None and not isinstance(condition.criteria, JinjaBranchCriteria):
raise ValueError("Only Jinja2 branch criteria are supported in this version.")
if condition.criteria is not None and condition.is_default:
if condition_obj.criteria is not None and condition_obj.is_default:
raise ValueError("Default branches may not define criteria.")
return condition
if condition_obj.criteria and isinstance(condition_obj.criteria, BranchCriteria):
expression = condition_obj.criteria.expression
criteria_dict = condition_obj.criteria.model_dump()
if expression and expression.startswith("{{") and expression.endswith("}}"):
criteria_dict["criteria_type"] = "jinja2_template"
condition_obj.criteria = JinjaBranchCriteria(**criteria_dict)
else:
criteria_dict["criteria_type"] = "prompt"
condition_obj.criteria = PromptBranchCriteria(**criteria_dict)
return condition_obj
class ConditionalBlock(Block):
@@ -4159,6 +4198,117 @@ class ConditionalBlock(Block):
# BranchCriteria subclasses will surface their parameter dependencies once implemented.
return []
async def _evaluate_prompt_branches(
self,
*,
branches: list[BranchCondition],
evaluation_context: BranchEvaluationContext,
workflow_run_id: str,
workflow_run_block_id: str,
organization_id: str | None = None,
) -> list[bool]:
if organization_id is None:
raise ValueError("organization_id is required to evaluate natural language branches")
workflow_run_context = evaluation_context.workflow_run_context
branch_criteria_payload = [
{"index": idx, "expression": branch.criteria.expression if branch.criteria else ""}
for idx, branch in enumerate(branches)
]
extraction_goal = prompt_engine.load_prompt(
"conditional-prompt-branch-evaluation",
branch_criteria=branch_criteria_payload,
)
data_schema = {
"type": "object",
"properties": {
"branch_results": {
"type": "array",
"description": "Boolean results for each natural language branch in order.",
"items": {"type": "boolean"},
}
},
"required": ["branch_results"],
}
output_param = OutputParameter(
output_parameter_id=str(uuid.uuid4()),
key=f"prompt_branch_eval_{generate_random_string()}",
workflow_id=self.output_parameter.workflow_id,
created_at=datetime.now(),
modified_at=datetime.now(),
parameter_type=ParameterType.OUTPUT,
description="Prompt branch evaluation result",
)
extraction_block = ExtractionBlock(
label=f"prompt_branch_eval_{generate_random_string()}",
data_extraction_goal=extraction_goal,
data_schema=data_schema,
output_parameter=output_param,
)
extraction_result = await extraction_block.execute(
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
if not extraction_result.success:
raise ValueError(f"Prompt branch evaluation failed: {extraction_result.failure_reason}")
output_value = extraction_result.output_parameter_value
if workflow_run_context:
try:
await extraction_block.record_output_parameter_value(
workflow_run_context=workflow_run_context,
workflow_run_id=workflow_run_id,
value=output_value,
)
except Exception:
LOG.warning(
"Failed to record prompt branch evaluation output",
workflow_run_id=workflow_run_id,
block_label=self.label,
exc_info=True,
)
extracted_info: Any | None = None
if isinstance(output_value, dict):
extracted_info = output_value.get("extracted_information")
if isinstance(extracted_info, list) and len(extracted_info) == 1:
extracted_info = extracted_info[0]
if not isinstance(extracted_info, dict):
raise ValueError("Prompt branch evaluation returned no extracted_information payload")
branch_results_raw = extracted_info.get("branch_results")
if not isinstance(branch_results_raw, list):
raise ValueError("Prompt branch evaluation did not return branch_results list")
branch_results: list[bool] = []
for result in branch_results_raw:
if isinstance(result, bool):
branch_results.append(result)
else:
evaluated_result = _evaluate_truthy_string(str(result))
LOG.warning(
"Prompt branch evaluation returned non-boolean result",
result=result,
evaluated_result=evaluated_result,
)
branch_results.append(evaluated_result)
if len(branch_results) != len(branches):
raise ValueError(
f"Prompt branch evaluation returned {len(branch_results)} results for {len(branches)} branches"
)
return branch_results
async def execute( # noqa: D401
self,
workflow_run_id: str,
@@ -4181,9 +4331,59 @@ class ConditionalBlock(Block):
matched_branch = None
failure_reason: str | None = None
natural_language_branches = [
branch for branch in self.ordered_branches if isinstance(branch.criteria, PromptBranchCriteria)
]
prompt_results_by_id: dict[str, bool] = {}
if natural_language_branches:
try:
prompt_results = await self._evaluate_prompt_branches(
branches=natural_language_branches,
evaluation_context=evaluation_context,
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
)
prompt_results_by_id = {
branch.id: result for branch, result in zip(natural_language_branches, prompt_results, strict=False)
}
except Exception as exc:
failure_reason = f"Failed to evaluate natural language branches: {str(exc)}"
LOG.error(
"Failed to evaluate natural language branches",
block_label=self.label,
error=str(exc),
exc_info=True,
)
for idx, branch in enumerate(self.ordered_branches):
if branch.criteria is None:
continue
if branch.criteria.criteria_type == "prompt":
if failure_reason:
break
prompt_result = prompt_results_by_id.get(branch.id)
if prompt_result is None:
failure_reason = "Missing result for natural language branch evaluation"
LOG.error(
"Missing prompt evaluation result",
block_label=self.label,
branch_index=idx,
branch_id=branch.id,
)
break
if prompt_result:
matched_branch = branch
LOG.info(
"Conditional natural language branch matched",
block_label=self.label,
branch_index=idx,
next_block_label=branch.next_block_label,
)
break
continue
try:
if await branch.criteria.evaluate(evaluation_context):
matched_branch = branch
@@ -4210,10 +4410,28 @@ class ConditionalBlock(Block):
matched_index = self.ordered_branches.index(matched_branch) if matched_branch in self.ordered_branches else None
next_block_label = matched_branch.next_block_label if matched_branch else None
executed_branch_id = matched_branch.id if matched_branch else None
# Extract execution details for frontend display
executed_branch_expression: str | None = None
executed_branch_result: bool | None = None
executed_branch_next_block: str | None = None
if matched_branch:
executed_branch_next_block = matched_branch.next_block_label
if matched_branch.is_default:
# Default/else branch - no expression to evaluate
executed_branch_expression = None
executed_branch_result = None
elif matched_branch.criteria:
# Regular condition branch - it matched
executed_branch_expression = matched_branch.criteria.expression
executed_branch_result = True
branch_metadata: BlockMetadata = {
"branch_taken": next_block_label,
"branch_index": matched_index,
"branch_id": executed_branch_id,
"branch_description": matched_branch.description if matched_branch else None,
"criteria_type": matched_branch.criteria.criteria_type
if matched_branch and matched_branch.criteria
@@ -4258,6 +4476,10 @@ class ConditionalBlock(Block):
status=status,
workflow_run_block_id=workflow_run_block_id,
organization_id=organization_id,
executed_branch_id=executed_branch_id,
executed_branch_expression=executed_branch_expression,
executed_branch_result=executed_branch_result,
executed_branch_next_block=executed_branch_next_block,
)
return block_result
@@ -4318,5 +4540,5 @@ BlockSubclasses = Union[
BlockTypeVar = Annotated[BlockSubclasses, Field(discriminator="block_type")]
BranchCriteriaSubclasses = Union[JinjaBranchCriteria]
BranchCriteriaSubclasses = Union[JinjaBranchCriteria, PromptBranchCriteria]
BranchCriteriaTypeVar = Annotated[BranchCriteriaSubclasses, Field(discriminator="criteria_type")]

View File

@@ -75,6 +75,7 @@ from skyvern.forge.sdk.workflow.models.block import (
LoginBlock,
NavigationBlock,
PDFParserBlock,
PromptBranchCriteria,
SendEmailBlock,
TaskBlock,
TaskV2Block,
@@ -946,6 +947,11 @@ class WorkflowService:
while current_label:
block = label_to_block.get(current_label)
if not block:
LOG.error(
"Unable to find block with label in workflow graph",
workflow_run_id=workflow_run.workflow_run_id,
current_label=current_label,
)
workflow_run = await self.mark_workflow_run_as_failed(
workflow_run_id=workflow_run.workflow_run_id,
failure_reason=f"Unable to find block with label {current_label}",
@@ -977,7 +983,7 @@ class WorkflowService:
break
next_label = None
if isinstance(block, ConditionalBlock):
if block.block_type == BlockType.CONDITIONAL:
next_label = (branch_metadata or {}).get("next_block_label")
else:
next_label = default_next_map.get(block.label)
@@ -3345,15 +3351,20 @@ class WorkflowService:
elif block_yaml.block_type == BlockType.CONDITIONAL:
branch_conditions = []
for branch in block_yaml.branch_conditions:
branch_criteria = (
JinjaBranchCriteria(
criteria_type=branch.criteria.criteria_type,
expression=branch.criteria.expression,
description=branch.criteria.description,
)
if branch.criteria
else None
)
branch_criteria = None
if branch.criteria:
if branch.criteria.criteria_type == "prompt":
branch_criteria = PromptBranchCriteria(
criteria_type=branch.criteria.criteria_type,
expression=branch.criteria.expression,
description=branch.criteria.description,
)
else:
branch_criteria = JinjaBranchCriteria(
criteria_type=branch.criteria.criteria_type,
expression=branch.criteria.expression,
description=branch.criteria.description,
)
branch_conditions.append(
BranchCondition(

View File

@@ -262,7 +262,7 @@ class ForLoopBlockYAML(BlockYAML):
class BranchCriteriaYAML(BaseModel):
criteria_type: Literal["jinja2_template"] = "jinja2_template"
criteria_type: Literal["jinja2_template", "prompt"] = "jinja2_template"
expression: str
description: str | None = None