cache invalidation update - only delete cached code for impacted block and blocks after the impacted (#3908)

This commit is contained in:
Shuchang Zheng
2025-11-05 15:26:11 +08:00
committed by GitHub
parent 3c3b5c2db9
commit 02fc0d9dda
2 changed files with 248 additions and 2 deletions

View File

@@ -4612,6 +4612,7 @@ class AgentDB:
run_signature: str | None = None, run_signature: str | None = None,
workflow_run_id: str | None = None, workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None, workflow_run_block_id: str | None = None,
clear_run_signature: bool = False,
) -> ScriptBlock: ) -> ScriptBlock:
async with self.Session() as session: async with self.Session() as session:
script_block = ( script_block = (
@@ -4624,7 +4625,9 @@ class AgentDB:
if script_block: if script_block:
if script_file_id is not None: if script_file_id is not None:
script_block.script_file_id = script_file_id script_block.script_file_id = script_file_id
if run_signature is not None: if clear_run_signature:
script_block.run_signature = None
elif run_signature is not None:
script_block.run_signature = run_signature script_block.run_signature = run_signature
if workflow_run_id is not None: if workflow_run_id is not None:
script_block.workflow_run_id = workflow_run_id script_block.workflow_run_id = workflow_run_id

View File

@@ -5,6 +5,8 @@ import os
import textwrap import textwrap
import uuid import uuid
from collections import deque from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass, field
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any, Literal, cast from typing import Any, Literal, cast
@@ -104,7 +106,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunStatus, WorkflowRunStatus,
) )
from skyvern.schemas.runs import ProxyLocation, RunStatus, RunType, WorkflowRunRequest, WorkflowRunResponse from skyvern.schemas.runs import ProxyLocation, RunStatus, RunType, WorkflowRunRequest, WorkflowRunResponse
from skyvern.schemas.scripts import ScriptStatus, WorkflowScript from skyvern.schemas.scripts import Script, ScriptBlock, ScriptStatus, WorkflowScript
from skyvern.schemas.workflows import ( from skyvern.schemas.workflows import (
BLOCK_YAML_TYPES, BLOCK_YAML_TYPES,
BlockResult, BlockResult,
@@ -123,6 +125,28 @@ LOG = structlog.get_logger()
DEFAULT_FIRST_BLOCK_LABEL = "block_1" DEFAULT_FIRST_BLOCK_LABEL = "block_1"
DEFAULT_WORKFLOW_TITLE = "New Workflow" DEFAULT_WORKFLOW_TITLE = "New Workflow"
CacheInvalidationReason = Literal["updated_block", "new_block", "removed_block"]
@dataclass
class CacheInvalidationPlan:
reason: CacheInvalidationReason | None = None
label: str | None = None
previous_index: int | None = None
new_index: int | None = None
block_labels_to_disable: list[str] = field(default_factory=list)
@property
def has_targets(self) -> bool:
return bool(self.block_labels_to_disable)
@dataclass
class CachedScriptBlocks:
workflow_script: WorkflowScript
script: Script
blocks_to_clear: list[ScriptBlock]
def _get_workflow_definition_core_data(workflow_definition: WorkflowDefinition) -> dict[str, Any]: def _get_workflow_definition_core_data(workflow_definition: WorkflowDefinition) -> dict[str, Any]:
""" """
@@ -179,6 +203,134 @@ def _get_workflow_definition_core_data(workflow_definition: WorkflowDefinition)
class WorkflowService: class WorkflowService:
@staticmethod
def _determine_cache_invalidation(
previous_blocks: list[dict[str, Any]],
new_blocks: list[dict[str, Any]],
) -> CacheInvalidationPlan:
"""Return which block index triggered the change and the labels that need cache invalidation."""
plan = CacheInvalidationPlan()
prev_labels: list[str] = []
for blocks in previous_blocks:
label = blocks.get("label")
if label and isinstance(label, str):
prev_labels.append(label)
new_labels: list[str] = []
for blocks in new_blocks:
label = blocks.get("label")
if label and isinstance(label, str):
new_labels.append(label)
for idx, (prev_block, new_block) in enumerate(zip(previous_blocks, new_blocks)):
prev_label = prev_block.get("label")
new_label = new_block.get("label")
if prev_label and prev_label == new_label and prev_block != new_block:
plan.reason = "updated_block"
plan.label = new_label
plan.previous_index = idx
break
if plan.reason is None:
previous_label_set = set(prev_labels)
for idx, label in enumerate(new_labels):
if label and label not in previous_label_set:
plan.reason = "new_block"
plan.label = label
plan.new_index = idx
plan.previous_index = min(idx, len(prev_labels))
break
if plan.reason is None:
new_label_set = set(new_labels)
for idx, label in enumerate(prev_labels):
if label not in new_label_set:
plan.reason = "removed_block"
plan.label = label
plan.previous_index = idx
break
if plan.reason == "removed_block":
new_label_set = set(new_labels)
plan.block_labels_to_disable = [label for label in prev_labels if label and label not in new_label_set]
elif plan.previous_index is not None:
plan.block_labels_to_disable = prev_labels[plan.previous_index :]
return plan
async def _partition_cached_blocks(
self,
*,
organization_id: str,
candidates: Sequence[WorkflowScript],
block_labels_to_disable: Sequence[str],
) -> tuple[list[CachedScriptBlocks], list[CachedScriptBlocks]]:
"""Split cached scripts into published vs draft buckets and collect blocks that should be cleared."""
cached_groups: list[CachedScriptBlocks] = []
published_groups: list[CachedScriptBlocks] = []
target_labels = set(block_labels_to_disable)
for candidate in candidates:
script = await app.DATABASE.get_script(
script_id=candidate.script_id,
organization_id=organization_id,
)
if not script:
continue
script_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
script_revision_id=script.script_revision_id,
organization_id=organization_id,
)
blocks_to_clear = [
block for block in script_blocks if block.script_block_label in target_labels and block.run_signature
]
if not blocks_to_clear:
continue
group = CachedScriptBlocks(workflow_script=candidate, script=script, blocks_to_clear=blocks_to_clear)
if candidate.status == ScriptStatus.published:
published_groups.append(group)
else:
cached_groups.append(group)
return cached_groups, published_groups
async def _clear_cached_block_groups(
self,
*,
organization_id: str,
workflow: Workflow,
previous_workflow: Workflow,
plan: CacheInvalidationPlan,
groups: Sequence[CachedScriptBlocks],
) -> None:
"""Remove cached run signatures for the supplied block groups to force regeneration."""
for group in groups:
for block in group.blocks_to_clear:
await app.DATABASE.update_script_block(
script_block_id=block.script_block_id,
organization_id=organization_id,
clear_run_signature=True,
)
LOG.info(
"Cleared cached script blocks after workflow block change",
workflow_id=workflow.workflow_id,
workflow_permanent_id=previous_workflow.workflow_permanent_id,
organization_id=organization_id,
previous_version=previous_workflow.version,
new_version=workflow.version,
invalidate_reason=plan.reason,
invalidate_label=plan.label,
invalidate_index_prev=plan.previous_index,
invalidate_index_new=plan.new_index,
script_id=group.script.script_id,
script_revision_id=group.script.script_revision_id,
cleared_block_labels=[block.script_block_label for block in group.blocks_to_clear],
cleared_block_count=len(group.blocks_to_clear),
)
@staticmethod @staticmethod
def _collect_extracted_information(value: Any) -> list[Any]: def _collect_extracted_information(value: Any) -> list[Any]:
"""Recursively collect extracted_information values from nested outputs.""" """Recursively collect extracted_information values from nested outputs."""
@@ -1271,6 +1423,8 @@ class WorkflowService:
ignore_version=workflow.version, ignore_version=workflow.version,
) )
current_definition: dict[str, Any] = {}
new_definition: dict[str, Any] = {}
if previous_valid_workflow: if previous_valid_workflow:
current_definition = _get_workflow_definition_core_data(previous_valid_workflow.workflow_definition) current_definition = _get_workflow_definition_core_data(previous_valid_workflow.workflow_definition)
new_definition = _get_workflow_definition_core_data(workflow_definition) new_definition = _get_workflow_definition_core_data(workflow_definition)
@@ -1279,11 +1433,100 @@ class WorkflowService:
has_changes = False has_changes = False
if previous_valid_workflow and has_changes and delete_script: if previous_valid_workflow and has_changes and delete_script:
plan = self._determine_cache_invalidation(
previous_blocks=current_definition.get("blocks", []),
new_blocks=new_definition.get("blocks", []),
)
candidates = await app.DATABASE.get_workflow_scripts_by_permanent_id( candidates = await app.DATABASE.get_workflow_scripts_by_permanent_id(
organization_id=organization_id, organization_id=organization_id,
workflow_permanent_id=previous_valid_workflow.workflow_permanent_id, workflow_permanent_id=previous_valid_workflow.workflow_permanent_id,
) )
if plan.has_targets:
cached_groups, published_groups = await self._partition_cached_blocks(
organization_id=organization_id,
candidates=candidates,
block_labels_to_disable=plan.block_labels_to_disable,
)
if not cached_groups and not published_groups:
LOG.info(
"Workflow definition changed, no cached script blocks found after workflow block change",
workflow_id=workflow.workflow_id,
workflow_permanent_id=previous_valid_workflow.workflow_permanent_id,
organization_id=organization_id,
previous_version=previous_valid_workflow.version,
new_version=workflow.version,
invalidate_reason=plan.reason,
invalidate_label=plan.label,
invalidate_index_prev=plan.previous_index,
invalidate_index_new=plan.new_index,
block_labels_to_disable=plan.block_labels_to_disable,
)
return
if published_groups and not delete_code_cache_is_ok:
LOG.info(
"Workflow definition changed, asking user if clearing published cached blocks is ok",
workflow_id=workflow.workflow_id,
workflow_permanent_id=previous_valid_workflow.workflow_permanent_id,
organization_id=organization_id,
previous_version=previous_valid_workflow.version,
new_version=workflow.version,
invalidate_reason=plan.reason,
invalidate_label=plan.label,
invalidate_index_prev=plan.previous_index,
invalidate_index_new=plan.new_index,
block_labels_to_disable=plan.block_labels_to_disable,
to_clear_published_cnt=len(published_groups),
to_clear_non_published_cnt=len(cached_groups),
)
raise CannotUpdateWorkflowDueToCodeCache(
workflow_permanent_id=previous_valid_workflow.workflow_permanent_id,
)
try:
groups_to_clear = [*cached_groups, *published_groups]
await self._clear_cached_block_groups(
organization_id=organization_id,
workflow=workflow,
previous_workflow=previous_valid_workflow,
plan=plan,
groups=groups_to_clear,
)
except Exception as e:
LOG.error(
"Failed to clear cached script blocks after workflow block change",
workflow_id=workflow.workflow_id,
workflow_permanent_id=previous_valid_workflow.workflow_permanent_id,
organization_id=organization_id,
previous_version=previous_valid_workflow.version,
new_version=workflow.version,
invalidate_reason=plan.reason,
invalidate_label=plan.label,
invalidate_index_prev=plan.previous_index,
invalidate_index_new=plan.new_index,
error=str(e),
)
return
if plan.previous_index is not None:
LOG.info(
"Workflow definition changed, no cached script blocks exist to clear for workflow block change",
workflow_id=workflow.workflow_id,
workflow_permanent_id=previous_valid_workflow.workflow_permanent_id,
organization_id=organization_id,
previous_version=previous_valid_workflow.version,
new_version=workflow.version,
invalidate_reason=plan.reason,
invalidate_label=plan.label,
invalidate_index_prev=plan.previous_index,
invalidate_index_new=plan.new_index,
)
return
to_delete_published = [script for script in candidates if script.status == ScriptStatus.published] to_delete_published = [script for script in candidates if script.status == ScriptStatus.published]
to_delete = [script for script in candidates if script.status != ScriptStatus.published] to_delete = [script for script in candidates if script.status != ScriptStatus.published]