This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
import base64
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from cachetools import TTLCache
|
||||
@@ -14,11 +12,9 @@ from skyvern.forge.sdk.core import skyvern_context
|
||||
from skyvern.forge.sdk.workflow.models.block import get_all_blocks
|
||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun
|
||||
from skyvern.schemas.scripts import FileEncoding, Script, ScriptFileCreate, ScriptStatus
|
||||
from skyvern.schemas.workflows import BlockType
|
||||
from skyvern.services import script_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
jinja_sandbox_env = SandboxedEnvironment()
|
||||
|
||||
@@ -26,55 +22,6 @@ jinja_sandbox_env = SandboxedEnvironment()
|
||||
_workflow_script_cache: TTLCache[tuple, "Script"] = TTLCache(maxsize=128, ttl=60 * 60)
|
||||
|
||||
|
||||
def get_downstream_blocks(
    updated_labels: set[str],
    blocks: list["BlockTypeVar"],
) -> set[str]:
    """
    Collect the labels reachable from the updated blocks via next-block edges.

    The workflow is treated as a directed graph whose edges come from each
    block's ``next_block_label`` (or, for blocks exposing ``ordered_branches``,
    from every branch's ``next_block_label``). A breadth-first walk is started
    from all updated labels at once.

    Args:
        updated_labels: Labels of the blocks that have been updated; these are
            the starting points of the walk and are never part of the result.
        blocks: Blocks of the workflow definition. Blocks with a falsy label
            are ignored.

    Returns:
        Labels reachable from at least one updated label, excluding the
        updated labels themselves.
    """
    # Map each labelled block to the labels it can hand control to.
    successors_by_label: dict[str, set[str]] = {}
    for node in blocks:
        label = node.label
        if not label:
            continue

        # Duck-typed check: anything exposing ``ordered_branches`` fans out
        # through its branches and its own ``next_block_label`` is not consulted.
        if hasattr(node, "ordered_branches"):
            targets = {
                branch.next_block_label
                for branch in node.ordered_branches
                if branch.next_block_label
            }
        elif node.next_block_label:
            targets = {node.next_block_label}
        else:
            targets = set()

        # A later block with the same label overwrites the earlier entry.
        successors_by_label[label] = targets

    # Breadth-first expansion; the seeds are pre-marked as seen so they can
    # never be reported as downstream, even when a cycle leads back to them.
    seen: set[str] = set(updated_labels)
    reached: set[str] = set()
    frontier: deque[str] = deque(updated_labels)

    while frontier:
        label = frontier.popleft()
        fresh = successors_by_label.get(label, set()) - seen
        seen |= fresh
        reached |= fresh
        frontier.extend(fresh)

    return reached
|
||||
|
||||
|
||||
def _make_workflow_script_cache_key(
|
||||
organization_id: str,
|
||||
workflow_permanent_id: str,
|
||||
@@ -285,6 +232,27 @@ async def generate_workflow_script(
|
||||
cached_script: Script | None = None,
|
||||
updated_block_labels: set[str] | None = None,
|
||||
) -> None:
|
||||
# Disable script generation for workflows containing conditional blocks to avoid caching divergent paths.
|
||||
try:
|
||||
all_blocks = get_all_blocks(workflow.workflow_definition.blocks)
|
||||
has_conditional = any(block.block_type == BlockType.CONDITIONAL for block in all_blocks)
|
||||
except Exception:
|
||||
has_conditional = False
|
||||
LOG.warning(
|
||||
"Failed to inspect workflow blocks for conditional types; continuing with script generation",
|
||||
workflow_id=workflow.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if has_conditional:
|
||||
LOG.info(
|
||||
"Skipping script generation for workflow containing conditional blocks",
|
||||
workflow_id=workflow.workflow_id,
|
||||
workflow_run_id=workflow_run.workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
LOG.info(
|
||||
"Generating script for workflow",
|
||||
@@ -313,18 +281,6 @@ async def generate_workflow_script(
|
||||
updated_block_labels.update(missing_labels)
|
||||
updated_block_labels.add(settings.WORKFLOW_START_BLOCK_LABEL)
|
||||
|
||||
# Build set of all block labels from the current workflow definition.
|
||||
# This is used to filter out stale cached blocks that no longer exist
|
||||
# (e.g., after a user deletes or renames blocks in the workflow).
|
||||
all_definition_blocks = get_all_blocks(workflow.workflow_definition.blocks)
|
||||
workflow_definition_labels = {block.label for block in all_definition_blocks if block.label}
|
||||
|
||||
# Compute blocks downstream of any updated block. These should not be preserved
|
||||
# from cache because they may depend on data from the updated upstream blocks.
|
||||
# For example, if Block A extracts data used by Block B, and A is modified,
|
||||
# then B's cached code may be stale even if B itself wasn't modified.
|
||||
downstream_of_updated = get_downstream_blocks(updated_block_labels, all_definition_blocks)
|
||||
|
||||
python_src = await generate_workflow_script_python_code(
|
||||
file_name=codegen_input.file_name,
|
||||
workflow_run_request=codegen_input.workflow_run,
|
||||
@@ -338,8 +294,6 @@ async def generate_workflow_script(
|
||||
pending=pending,
|
||||
cached_blocks=cached_block_sources,
|
||||
updated_block_labels=updated_block_labels,
|
||||
workflow_definition_labels=workflow_definition_labels,
|
||||
downstream_of_updated=downstream_of_updated,
|
||||
)
|
||||
except Exception:
|
||||
LOG.error("Failed to generate workflow script source", exc_info=True)
|
||||
|
||||
Reference in New Issue
Block a user