393 lines
15 KiB
Python
393 lines
15 KiB
Python
import base64
|
|
from collections import deque
|
|
from typing import TYPE_CHECKING
|
|
|
|
import structlog
|
|
from cachetools import TTLCache
|
|
from jinja2.sandbox import SandboxedEnvironment
|
|
|
|
from skyvern.config import settings
|
|
from skyvern.core.script_generations.generate_script import ScriptBlockSource, generate_workflow_script_python_code
|
|
from skyvern.core.script_generations.transform_workflow_run import transform_workflow_run_to_code_gen_input
|
|
from skyvern.forge import app
|
|
from skyvern.forge.sdk.core import skyvern_context
|
|
from skyvern.forge.sdk.workflow.models.block import get_all_blocks
|
|
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun
|
|
from skyvern.schemas.scripts import FileEncoding, Script, ScriptFileCreate, ScriptStatus
|
|
from skyvern.services import script_service
|
|
|
|
if TYPE_CHECKING:
|
|
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
|
|
|
|
# Structured logger for this module.
LOG = structlog.get_logger()

# Sandboxed Jinja2 environment used to render the user-supplied workflow
# cache_key template; the sandbox restricts access to unsafe attributes.
jinja_sandbox_env = SandboxedEnvironment()

# Cache for workflow scripts - only stores non-None results
# (caching misses would hide scripts created later, within the TTL).
# Entries expire after one hour; at most 128 entries are kept.
_workflow_script_cache: TTLCache[tuple, "Script"] = TTLCache(maxsize=128, ttl=60 * 60)
|
|
|
|
|
|
def get_downstream_blocks(
    updated_labels: set[str],
    blocks: list["BlockTypeVar"],
) -> set[str]:
    """
    Return the labels of all blocks downstream of any updated block in the DAG.

    Performs a breadth-first traversal over next_block_label edges starting
    from the updated blocks. Callers use the result to invalidate cached
    blocks whose inputs may have come from an updated upstream block.

    Args:
        updated_labels: Labels of blocks that have been modified.
        blocks: Every block in the workflow definition.

    Returns:
        Labels reachable from the updated blocks; the updated blocks
        themselves are never included in the result.
    """
    # Edge map: label -> labels this block can transition to.
    edges: dict[str, set[str]] = {}
    for blk in blocks:
        if not blk.label:
            continue
        targets: set[str] = set()
        # Conditional blocks expose ordered_branches, each carrying its own
        # next_block_label target (duck-typed so tests can use stand-ins).
        if hasattr(blk, "ordered_branches"):
            for branch in blk.ordered_branches:
                if branch.next_block_label:
                    targets.add(branch.next_block_label)
        elif blk.next_block_label:
            targets.add(blk.next_block_label)
        edges[blk.label] = targets

    # BFS outward from every updated block at once.
    reachable: set[str] = set()
    seen: set[str] = set(updated_labels)
    frontier: deque[str] = deque(updated_labels)

    while frontier:
        label = frontier.popleft()
        for target in edges.get(label, set()):
            if target in seen:
                continue
            seen.add(target)
            reachable.add(target)
            frontier.append(target)

    return reachable
|
|
|
|
|
|
def _make_workflow_script_cache_key(
    organization_id: str,
    workflow_permanent_id: str,
    cache_key_value: str,
    workflow_run_id: str | None = None,
    cache_key: str | None = None,
    statuses: list[ScriptStatus] | None = None,
) -> tuple:
    """Build a hashable TTLCache key from the lookup arguments.

    The statuses list is normalized to a tuple (or None when empty/absent)
    because lists are unhashable and cannot be part of a cache key.
    """
    normalized_statuses = tuple(statuses) if statuses else None
    return (
        organization_id,
        workflow_permanent_id,
        cache_key_value,
        workflow_run_id,
        cache_key,
        normalized_statuses,
    )
|
|
|
|
|
|
async def generate_or_update_pending_workflow_script(
    workflow_run: WorkflowRun,
    workflow: Workflow,
) -> None:
    """Create (or reuse) a pending script for a workflow run and regenerate its code.

    If the current skyvern context already references a script, that script is
    reused; otherwise a new one is created and recorded on the context so that
    later work in this run attaches to the same script revision. The script
    source is then (re)generated in ``pending`` status.

    No-ops when there is no active skyvern context.
    """
    organization_id = workflow.organization_id
    context = skyvern_context.current()
    if not context:
        return

    script = None
    script_id = context.script_id
    if script_id:
        script = await app.DATABASE.get_script(script_id=script_id, organization_id=organization_id)

    if not script:
        script = await app.DATABASE.create_script(organization_id=organization_id, run_id=workflow_run.workflow_run_id)
        # `context` is guaranteed non-None here by the early return above,
        # so no re-check is needed before recording the new script on it.
        context.script_id = script.script_id
        context.script_revision_id = script.script_revision_id

    # Only the rendered cache key value is needed here; any script returned by
    # the lookup is intentionally ignored because we regenerate below.
    _, rendered_cache_key_value = await get_workflow_script(
        workflow=workflow,
        workflow_run=workflow_run,
        status=ScriptStatus.pending,
    )
    await generate_workflow_script(
        workflow_run=workflow_run,
        workflow=workflow,
        script=script,
        rendered_cache_key_value=rendered_cache_key_value,
        pending=True,
        cached_script=script,
    )
|
|
|
|
|
|
async def get_workflow_script(
    workflow: Workflow,
    workflow_run: WorkflowRun,
    block_labels: list[str] | None = None,
    status: ScriptStatus = ScriptStatus.published,
) -> tuple[Script | None, str]:
    """
    Check if there's a related workflow script that should be used instead of running the workflow.
    Returns the tuple of (script, rendered_cache_key_value).

    The workflow's ``cache_key`` is treated as a Jinja template and rendered
    with the run's parameter values; the rendered string is returned even when
    no script is found (or on error), since callers need it for generation.

    Args:
        workflow: Workflow whose cache_key template is rendered.
        workflow_run: Run whose parameters feed the template.
        block_labels: When provided, script lookup is skipped entirely.
        status: Script status to match when looking up a cached script.
    """
    cache_key = workflow.cache_key or ""
    rendered_cache_key_value = ""

    try:
        # Render the cache key from this run's parameter values.
        parameter_tuples = await app.DATABASE.get_workflow_run_parameters(
            workflow_run_id=workflow_run.workflow_run_id,
        )
        parameters = {wf_param.key: run_param.value for wf_param, run_param in parameter_tuples}

        # Sandboxed render: cache_key is user-supplied template text.
        rendered_cache_key_value = jinja_sandbox_env.from_string(cache_key).render(parameters)

        if block_labels:
            # Do not generate script or run script if block_labels is provided
            return None, rendered_cache_key_value

        # Check if there are existing cached scripts for this workflow + cache_key_value
        existing_script = await get_workflow_script_by_cache_key_value(
            organization_id=workflow.organization_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            cache_key_value=rendered_cache_key_value,
            statuses=[status],
            use_cache=True,
        )

        if existing_script:
            LOG.info(
                "Found cached script for workflow",
                workflow_id=workflow.workflow_id,
                cache_key_value=rendered_cache_key_value,
                workflow_run_id=workflow_run.workflow_run_id,
            )
            return existing_script, rendered_cache_key_value

        return None, rendered_cache_key_value

    except Exception as e:
        # Deliberate catch-all: script lookup is best-effort and must never
        # block normal workflow execution. rendered_cache_key_value may be ""
        # here if rendering itself failed.
        LOG.warning(
            "Failed to check for workflow script, proceeding with normal workflow execution",
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            error=str(e),
            exc_info=True,
        )
        return None, rendered_cache_key_value
|
|
|
|
|
|
async def get_workflow_script_by_cache_key_value(
    organization_id: str,
    workflow_permanent_id: str,
    cache_key_value: str,
    workflow_run_id: str | None = None,
    cache_key: str | None = None,
    statuses: list[ScriptStatus] | None = None,
    use_cache: bool = False,
) -> Script | None:
    """Look up a workflow script by its rendered cache key value.

    Args:
        organization_id: Owning organization.
        workflow_permanent_id: Permanent id of the workflow.
        cache_key_value: Rendered cache key to match.
        workflow_run_id: Optional run id filter.
        cache_key: Optional raw cache key template filter.
        statuses: Optional script status filter.
        use_cache: When True, consult the in-process TTL cache before the
            database and store non-None results back into it.

    Returns:
        The matching Script, or None if nothing matches.
    """
    cache_key_tuple: tuple | None = None
    if use_cache:
        cache_key_tuple = _make_workflow_script_cache_key(
            organization_id=organization_id,
            workflow_permanent_id=workflow_permanent_id,
            cache_key_value=cache_key_value,
            workflow_run_id=workflow_run_id,
            cache_key=cache_key,
            statuses=statuses,
        )
        if cache_key_tuple in _workflow_script_cache:
            return _workflow_script_cache[cache_key_tuple]

    # Single DB lookup shared by the cached and uncached paths (the original
    # duplicated this call verbatim in both branches).
    result = await app.DATABASE.get_workflow_script_by_cache_key_value(
        organization_id=organization_id,
        workflow_permanent_id=workflow_permanent_id,
        cache_key_value=cache_key_value,
        workflow_run_id=workflow_run_id,
        cache_key=cache_key,
        statuses=statuses,
    )

    # Only cache non-None results: caching a miss would suppress lookups for
    # scripts created later, for the duration of the TTL.
    if cache_key_tuple is not None and result is not None:
        _workflow_script_cache[cache_key_tuple] = result

    return result
|
|
|
|
|
|
async def _load_cached_script_block_sources(
    script: Script,
    organization_id: str,
) -> dict[str, ScriptBlockSource]:
    """
    Load existing script block sources (code + metadata) for a script revision so they can be reused.

    Blocks without a label, without a resolvable file/artifact, or whose file
    content is empty are skipped.
    """
    sources: dict[str, ScriptBlockSource] = {}

    revision_blocks = await app.DATABASE.get_script_blocks_by_script_revision_id(
        script_revision_id=script.script_revision_id,
        organization_id=organization_id,
    )

    for block_row in revision_blocks:
        label = block_row.script_block_label
        if not label:
            continue

        # Resolve the block's source code via file -> artifact -> content.
        code: str | None = None
        if block_row.script_file_id:
            file_row = await app.DATABASE.get_script_file_by_id(
                script_revision_id=script.script_revision_id,
                file_id=block_row.script_file_id,
                organization_id=organization_id,
            )
            artifact = None
            if file_row and file_row.artifact_id:
                artifact = await app.DATABASE.get_artifact_by_id(file_row.artifact_id, organization_id)
            if artifact:
                payload = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
                if isinstance(payload, bytes):
                    code = payload.decode("utf-8")
                elif isinstance(payload, str):
                    code = payload

        if not code:
            continue

        sources[label] = ScriptBlockSource(
            label=label,
            code=code,
            run_signature=block_row.run_signature,
            workflow_run_id=block_row.workflow_run_id,
            workflow_run_block_id=block_row.workflow_run_block_id,
            input_fields=block_row.input_fields,
        )

    return sources
|
|
|
|
|
|
async def generate_workflow_script(
    workflow_run: WorkflowRun,
    workflow: Workflow,
    script: Script,
    rendered_cache_key_value: str,
    pending: bool = False,
    cached_script: Script | None = None,
    updated_block_labels: set[str] | None = None,
) -> None:
    """Generate Python source for a workflow run and persist it as a script.

    Reuses cached block sources from ``cached_script`` where possible, treats
    any block missing from the cache (plus everything downstream of an updated
    block) as needing regeneration, writes the generated file as a script
    artifact, and records the workflow->script mapping.

    Args:
        workflow_run: The run whose execution drives code generation.
        workflow: The workflow being scripted.
        script: Script (and revision) the generated file is attached to.
        rendered_cache_key_value: Rendered cache key used for the mapping row.
        pending: When True, the script is recorded with ``pending`` status
            instead of ``published``.
        cached_script: Optional previous script whose block sources may be
            reused verbatim for unchanged blocks.
        updated_block_labels: Labels forced to regenerate; when None, every
            block in the run is treated as updated.

    Generation errors are logged and swallowed (the function returns early);
    persistence errors after successful generation propagate to the caller.
    """
    try:
        LOG.info(
            "Generating script for workflow",
            workflow_run_id=workflow_run.workflow_run_id,
            workflow_id=workflow.workflow_id,
            workflow_name=workflow.title,
            cache_key_value=rendered_cache_key_value,
        )
        cached_block_sources: dict[str, ScriptBlockSource] = {}
        if cached_script:
            cached_block_sources = await _load_cached_script_block_sources(cached_script, workflow.organization_id)

        codegen_input = await transform_workflow_run_to_code_gen_input(
            workflow_run_id=workflow_run.workflow_run_id,
            organization_id=workflow.organization_id,
        )

        block_labels = [block.get("label") for block in codegen_input.workflow_blocks if block.get("label")]

        # Default: regenerate everything. Otherwise copy the caller's set so
        # the additions below don't mutate the caller's argument.
        if updated_block_labels is None:
            updated_block_labels = {label for label in block_labels if label}
        else:
            updated_block_labels = set(updated_block_labels)

        # Blocks with no cached source must be regenerated regardless of what
        # the caller marked as updated; the start block is always regenerated.
        missing_labels = {label for label in block_labels if label and label not in cached_block_sources}
        updated_block_labels.update(missing_labels)
        updated_block_labels.add(settings.WORKFLOW_START_BLOCK_LABEL)

        # Build set of all block labels from the current workflow definition.
        # This is used to filter out stale cached blocks that no longer exist
        # (e.g., after a user deletes or renames blocks in the workflow).
        all_definition_blocks = get_all_blocks(workflow.workflow_definition.blocks)
        workflow_definition_labels = {block.label for block in all_definition_blocks if block.label}

        # Compute blocks downstream of any updated block. These should not be preserved
        # from cache because they may depend on data from the updated upstream blocks.
        # For example, if Block A extracts data used by Block B, and A is modified,
        # then B's cached code may be stale even if B itself wasn't modified.
        downstream_of_updated = get_downstream_blocks(updated_block_labels, all_definition_blocks)

        python_src = await generate_workflow_script_python_code(
            file_name=codegen_input.file_name,
            workflow_run_request=codegen_input.workflow_run,
            workflow=codegen_input.workflow,
            blocks=codegen_input.workflow_blocks,
            actions_by_task=codegen_input.actions_by_task,
            task_v2_child_blocks=codegen_input.task_v2_child_blocks,
            organization_id=workflow.organization_id,
            script_id=script.script_id,
            script_revision_id=script.script_revision_id,
            pending=pending,
            cached_blocks=cached_block_sources,
            updated_block_labels=updated_block_labels,
            workflow_definition_labels=workflow_definition_labels,
            downstream_of_updated=downstream_of_updated,
        )
    except Exception:
        # Generation is best-effort: log and bail without persisting anything.
        LOG.error("Failed to generate workflow script source", exc_info=True)
        return

    # Persist the generated source as a base64-encoded script file, then
    # record the workflow->script mapping.
    content_bytes = python_src.encode("utf-8")
    content_b64 = base64.b64encode(content_bytes).decode("utf-8")
    files = [
        ScriptFileCreate(
            path="main.py",
            content=content_b64,
            encoding=FileEncoding.BASE64,
            mime_type="text/x-python",
        )
    ]

    # Upload script file(s) as artifacts and create rows
    await script_service.build_file_tree(
        files=files,
        organization_id=workflow.organization_id,
        script_id=script.script_id,
        script_version=script.version,
        script_revision_id=script.script_revision_id,
        pending=pending,
    )

    # Check whether a pending (draft) workflow script already exists for this
    # run; if so, skip creating a duplicate mapping row.
    existing_pending_workflow_script = None
    status = ScriptStatus.published
    if pending:
        status = ScriptStatus.pending
        existing_pending_workflow_script = await app.DATABASE.get_workflow_script(
            organization_id=workflow.organization_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            workflow_run_id=workflow_run.workflow_run_id,
            statuses=[status],
        )
    if not existing_pending_workflow_script:
        # Record the workflow->script mapping for cache lookup
        await app.DATABASE.create_workflow_script(
            organization_id=workflow.organization_id,
            script_id=script.script_id,
            workflow_permanent_id=workflow.workflow_permanent_id,
            cache_key=workflow.cache_key or "",
            cache_key_value=rendered_cache_key_value,
            workflow_id=workflow.workflow_id,
            workflow_run_id=workflow_run.workflow_run_id,
            status=status,
        )