script to profile a workflow run (#4608)

This commit is contained in:
Shuchang Zheng
2026-02-02 21:15:34 -08:00
committed by GitHub
parent 327e6efa1a
commit b64c7d4032

View File

@@ -0,0 +1,206 @@
"""
Script to profile a workflow run by collecting and displaying all key timestamps.
Usage:
python scripts/profile_workflow_run.py <workflow_run_id>
"""
import asyncio
from dataclasses import dataclass
from datetime import datetime
from typing import Annotated
import typer
from sqlalchemy import select
from skyvern.forge import app
from skyvern.forge.forge_app_initializer import start_forge_app
from skyvern.forge.sdk.db.models import (
StepModel,
TaskModel,
WorkflowRunBlockModel,
WorkflowRunModel,
)
@dataclass
class TimestampEntry:
"""Represents a single timestamp entry for profiling."""
timestamp: datetime
entity_type: str # "workflow_run", "workflow_run_block", "task", "step"
entity_id: str
field_name: str # "created_at", "started_at", "finished_at", etc.
label: str | None = None # For blocks with labels
status: str | None = None
def __str__(self) -> str:
label_str = f" [{self.label}]" if self.label else ""
status_str = f" ({self.status})" if self.status else ""
return f"{self.timestamp.isoformat()} | {self.entity_type:20}{label_str} | {self.field_name:12} | {self.entity_id}{status_str}"
async def collect_timestamps(workflow_run_id: str) -> list[TimestampEntry]:
"""Collect all timestamps from the workflow run and its children."""
entries: list[TimestampEntry] = []
async with app.DATABASE.Session() as session:
# 1. Fetch the workflow run
workflow_run = (
await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id))
).first()
if not workflow_run:
raise ValueError(f"Workflow run not found: {workflow_run_id}")
# Add workflow run timestamps
for field in ["created_at", "queued_at", "started_at", "finished_at"]:
ts = getattr(workflow_run, field, None)
if ts:
entries.append(
TimestampEntry(
timestamp=ts,
entity_type="workflow_run",
entity_id=workflow_run_id,
field_name=field,
status=workflow_run.status,
)
)
# 2. Fetch all workflow run blocks
workflow_run_blocks = (
await session.scalars(
select(WorkflowRunBlockModel)
.filter_by(workflow_run_id=workflow_run_id)
.order_by(WorkflowRunBlockModel.created_at)
)
).all()
for block in workflow_run_blocks:
for field in ["created_at", "queued_at", "started_at", "finished_at", "modified_at"]:
ts = getattr(block, field, None)
if ts:
entries.append(
TimestampEntry(
timestamp=ts,
entity_type="workflow_run_block",
entity_id=block.workflow_run_block_id,
field_name=field,
label=block.label,
status=block.status,
)
)
# 3. Fetch all tasks for this workflow run
tasks = (
await session.scalars(
select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at)
)
).all()
task_ids = []
for task in tasks:
task_ids.append(task.task_id)
for field in ["created_at", "queued_at", "started_at", "finished_at"]:
ts = getattr(task, field, None)
if ts:
entries.append(
TimestampEntry(
timestamp=ts,
entity_type="task",
entity_id=task.task_id,
field_name=field,
status=task.status,
)
)
# 4. Fetch all steps for all tasks
if task_ids:
steps = (
await session.scalars(
select(StepModel).filter(StepModel.task_id.in_(task_ids)).order_by(StepModel.created_at)
)
).all()
for step in steps:
for field in ["created_at", "finished_at"]:
ts = getattr(step, field, None)
if ts:
entries.append(
TimestampEntry(
timestamp=ts,
entity_type="step",
entity_id=step.step_id,
field_name=field,
status=step.status,
)
)
return entries
def print_profile(entries: list[TimestampEntry]) -> None:
"""Print the profiling results sorted by timestamp."""
# Sort by timestamp
sorted_entries = sorted(entries, key=lambda e: e.timestamp)
if not sorted_entries:
print("No timestamps found.")
return
print("\n" + "=" * 120)
print("WORKFLOW RUN PROFILE")
print("=" * 120)
print(f"{'Timestamp':<30} | {'Entity Type':<25} | {'Field':<12} | {'Entity ID'}")
print("-" * 120)
first_ts = sorted_entries[0].timestamp
for entry in sorted_entries:
# Calculate relative time from first timestamp
delta = entry.timestamp - first_ts
delta_str = f"+{delta.total_seconds():>10.3f}s"
label_str = f" [{entry.label}]" if entry.label else ""
status_str = f" ({entry.status})" if entry.status else ""
print(
f"{entry.timestamp.isoformat():<30} {delta_str} | {entry.entity_type:<25}{label_str:<15} | {entry.field_name:<12} | {entry.entity_id[:36]}{status_str}"
)
print("-" * 120)
# Print summary
total_duration = sorted_entries[-1].timestamp - sorted_entries[0].timestamp
print(f"\nTotal Duration: {total_duration.total_seconds():.3f} seconds")
# Count entities
entity_counts: dict[str, int] = {}
for entry in sorted_entries:
if entry.field_name == "created_at":
entity_counts[entry.entity_type] = entity_counts.get(entry.entity_type, 0) + 1
print("\nEntity Counts:")
for entity_type, count in entity_counts.items():
print(f" {entity_type}: {count}")
print("=" * 120 + "\n")
async def profile_workflow_run(workflow_run_id: str) -> None:
"""Main function to profile a workflow run."""
print(f"Profiling workflow run: {workflow_run_id}")
entries = await collect_timestamps(workflow_run_id)
print_profile(entries)
def main(
workflow_run_id: Annotated[str, typer.Argument(help="The workflow run ID to profile")],
) -> None:
"""Profile a workflow run by collecting and displaying all key timestamps."""
start_forge_app()
asyncio.run(profile_workflow_run(workflow_run_id))
if __name__ == "__main__":
typer.run(main)