# Reconstructed from git patch b64c7d4032722ea1d48720432643f1dac58582b4
# ("script to profile a workflow run (#4608)"), which adds
# scripts/profile_workflow_run.py.
"""
Script to profile a workflow run by collecting and displaying all key timestamps.

Usage:
    python scripts/profile_workflow_run.py <workflow_run_id>
"""

import asyncio
from collections import Counter
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Annotated, Any

import typer
from sqlalchemy import select

from skyvern.forge import app
from skyvern.forge.forge_app_initializer import start_forge_app
from skyvern.forge.sdk.db.models import (
    StepModel,
    TaskModel,
    WorkflowRunBlockModel,
    WorkflowRunModel,
)

# Timestamp columns read from each kind of row. A column that is missing on
# the model or is unset on a given row is skipped.
_WORKFLOW_RUN_FIELDS = ("created_at", "queued_at", "started_at", "finished_at")
_BLOCK_FIELDS = ("created_at", "queued_at", "started_at", "finished_at", "modified_at")
_TASK_FIELDS = ("created_at", "queued_at", "started_at", "finished_at")
_STEP_FIELDS = ("created_at", "finished_at")

# Width of the horizontal rules framing the report.
_RULE_WIDTH = 120


@dataclass
class TimestampEntry:
    """Represents a single timestamp entry for profiling."""

    timestamp: datetime
    entity_type: str  # "workflow_run", "workflow_run_block", "task", "step"
    entity_id: str
    field_name: str  # "created_at", "started_at", "finished_at", etc.
    label: str | None = None  # For blocks with labels
    status: str | None = None

    def __str__(self) -> str:
        """Return a one-line, pipe-separated rendering of the entry."""
        label_str = f" [{self.label}]" if self.label else ""
        status_str = f" ({self.status})" if self.status else ""
        return f"{self.timestamp.isoformat()} | {self.entity_type:20}{label_str} | {self.field_name:12} | {self.entity_id}{status_str}"


def _timestamp_entries(
    row: Any,
    *,
    entity_type: str,
    entity_id: str,
    fields: Iterable[str],
    label: str | None = None,
) -> list[TimestampEntry]:
    """Build one TimestampEntry per populated timestamp attribute of ``row``.

    Args:
        row: Object whose timestamp attributes are read; its ``status``
            attribute is copied onto every entry produced.
        entity_type: Value stored in ``TimestampEntry.entity_type``.
        entity_id: Value stored in ``TimestampEntry.entity_id``.
        fields: Attribute names to read, in output order.
        label: Optional label stored on every entry produced.

    Returns:
        Entries in the order of ``fields``; attributes that are missing or
        unset are omitted.
    """
    found: list[TimestampEntry] = []
    for field_name in fields:
        ts = getattr(row, field_name, None)
        if not ts:
            continue
        found.append(
            TimestampEntry(
                timestamp=ts,
                entity_type=entity_type,
                entity_id=entity_id,
                field_name=field_name,
                label=label,
                status=row.status,
            )
        )
    return found


async def collect_timestamps(workflow_run_id: str) -> list[TimestampEntry]:
    """Collect all timestamps from the workflow run and its children.

    Entries are appended in this order: the workflow run, its blocks, its
    tasks, then the steps of those tasks. The result is not globally sorted
    by timestamp.

    Args:
        workflow_run_id: ID of the workflow run to inspect.

    Returns:
        One entry per populated timestamp column found.

    Raises:
        ValueError: If no workflow run with the given ID exists.
    """
    entries: list[TimestampEntry] = []

    async with app.DATABASE.Session() as session:
        # 1. Fetch the workflow run
        workflow_run = (
            await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id))
        ).first()

        if not workflow_run:
            raise ValueError(f"Workflow run not found: {workflow_run_id}")

        entries.extend(
            _timestamp_entries(
                workflow_run,
                entity_type="workflow_run",
                entity_id=workflow_run_id,
                fields=_WORKFLOW_RUN_FIELDS,
            )
        )

        # 2. Fetch all workflow run blocks
        workflow_run_blocks = (
            await session.scalars(
                select(WorkflowRunBlockModel)
                .filter_by(workflow_run_id=workflow_run_id)
                .order_by(WorkflowRunBlockModel.created_at)
            )
        ).all()

        for block in workflow_run_blocks:
            entries.extend(
                _timestamp_entries(
                    block,
                    entity_type="workflow_run_block",
                    entity_id=block.workflow_run_block_id,
                    fields=_BLOCK_FIELDS,
                    label=block.label,
                )
            )

        # 3. Fetch all tasks for this workflow run
        tasks = (
            await session.scalars(
                select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at)
            )
        ).all()

        for task in tasks:
            entries.extend(
                _timestamp_entries(
                    task,
                    entity_type="task",
                    entity_id=task.task_id,
                    fields=_TASK_FIELDS,
                )
            )

        # 4. Fetch all steps for all tasks (skipped when the run has no tasks)
        task_ids = [task.task_id for task in tasks]
        if task_ids:
            steps = (
                await session.scalars(
                    select(StepModel).filter(StepModel.task_id.in_(task_ids)).order_by(StepModel.created_at)
                )
            ).all()

            for step in steps:
                entries.extend(
                    _timestamp_entries(
                        step,
                        entity_type="step",
                        entity_id=step.step_id,
                        fields=_STEP_FIELDS,
                    )
                )

    return entries


def print_profile(entries: list[TimestampEntry]) -> None:
    """Print the profiling results sorted by timestamp.

    Each row shows the absolute timestamp, the offset in seconds from the
    earliest entry, the entity type with its optional label, the timestamp
    field name, and the entity ID (cut to 36 characters) with its optional
    status. A summary with the total duration and per-type entity counts
    follows the table.

    Args:
        entries: Entries to display; the list itself is not modified.
    """
    # sorted() is stable, so entries with equal timestamps keep input order.
    sorted_entries = sorted(entries, key=lambda e: e.timestamp)

    if not sorted_entries:
        print("No timestamps found.")
        return

    print("\n" + "=" * _RULE_WIDTH)
    print("WORKFLOW RUN PROFILE")
    print("=" * _RULE_WIDTH)
    # Header widths mirror the row layout below: the delta is 12 characters
    # ("+", a 10-wide number, "s") and type + label share a 25 + 15 column.
    print(f"{'Timestamp':<30} {'Delta':>12} | {'Entity Type [Label]':<40} | {'Field':<12} | {'Entity ID (Status)'}")
    print("-" * _RULE_WIDTH)

    first_ts = sorted_entries[0].timestamp
    for entry in sorted_entries:
        # Calculate relative time from first timestamp
        delta = entry.timestamp - first_ts
        delta_str = f"+{delta.total_seconds():>10.3f}s"

        label_str = f" [{entry.label}]" if entry.label else ""
        status_str = f" ({entry.status})" if entry.status else ""

        print(
            f"{entry.timestamp.isoformat():<30} {delta_str} | {entry.entity_type:<25}{label_str:<15} | {entry.field_name:<12} | {entry.entity_id[:36]}{status_str}"
        )

    print("-" * _RULE_WIDTH)

    # Print summary
    total_duration = sorted_entries[-1].timestamp - first_ts
    print(f"\nTotal Duration: {total_duration.total_seconds():.3f} seconds")

    # Count each entity once, via its "created_at" entry. Counter keeps
    # first-seen order, so types are listed in order of first appearance.
    entity_counts = Counter(entry.entity_type for entry in sorted_entries if entry.field_name == "created_at")

    print("\nEntity Counts:")
    for entity_type, count in entity_counts.items():
        print(f"  {entity_type}: {count}")

    print("=" * _RULE_WIDTH + "\n")


async def profile_workflow_run(workflow_run_id: str) -> None:
    """Main function to profile a workflow run.

    Args:
        workflow_run_id: ID of the workflow run to profile.

    Raises:
        ValueError: Propagated from collect_timestamps when the run is not found.
    """
    print(f"Profiling workflow run: {workflow_run_id}")

    entries = await collect_timestamps(workflow_run_id)
    print_profile(entries)


def main(
    workflow_run_id: Annotated[str, typer.Argument(help="The workflow run ID to profile")],
) -> None:
    """Profile a workflow run by collecting and displaying all key timestamps."""
    # The forge app is started before the coroutine runs because
    # collect_timestamps reads app.DATABASE.
    start_forge_app()
    asyncio.run(profile_workflow_run(workflow_run_id))


if __name__ == "__main__":
    typer.run(main)