script to profile a workflow run (#4608)
This commit is contained in:
206
scripts/profile_workflow_run.py
Normal file
206
scripts/profile_workflow_run.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Script to profile a workflow run by collecting and displaying all key timestamps.
|
||||
|
||||
Usage:
|
||||
python scripts/profile_workflow_run.py <workflow_run_id>
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
from sqlalchemy import select
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.forge_app_initializer import start_forge_app
|
||||
from skyvern.forge.sdk.db.models import (
|
||||
StepModel,
|
||||
TaskModel,
|
||||
WorkflowRunBlockModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimestampEntry:
|
||||
"""Represents a single timestamp entry for profiling."""
|
||||
|
||||
timestamp: datetime
|
||||
entity_type: str # "workflow_run", "workflow_run_block", "task", "step"
|
||||
entity_id: str
|
||||
field_name: str # "created_at", "started_at", "finished_at", etc.
|
||||
label: str | None = None # For blocks with labels
|
||||
status: str | None = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
label_str = f" [{self.label}]" if self.label else ""
|
||||
status_str = f" ({self.status})" if self.status else ""
|
||||
return f"{self.timestamp.isoformat()} | {self.entity_type:20}{label_str} | {self.field_name:12} | {self.entity_id}{status_str}"
|
||||
|
||||
|
||||
def _collect_entity_timestamps(
    entity: object,
    entity_type: str,
    entity_id: str,
    fields: tuple[str, ...],
    *,
    label: str | None = None,
    status: str | None = None,
) -> list[TimestampEntry]:
    """Build one TimestampEntry for every timestamp field that is set on ``entity``.

    Args:
        entity: Object whose attributes named in ``fields`` are read.
        entity_type: Value stored in each entry's ``entity_type``.
        entity_id: Value stored in each entry's ``entity_id``.
        fields: Attribute names to read, in the order entries should be emitted.
        label: Optional label copied onto every entry.
        status: Optional status copied onto every entry.

    Returns:
        Entries for the fields that exist on ``entity`` and are not ``None``.
    """
    collected: list[TimestampEntry] = []
    for field_name in fields:
        # getattr with a default tolerates models that lack one of the fields.
        ts = getattr(entity, field_name, None)
        if ts is None:
            continue
        collected.append(
            TimestampEntry(
                timestamp=ts,
                entity_type=entity_type,
                entity_id=entity_id,
                field_name=field_name,
                label=label,
                status=status,
            )
        )
    return collected


async def collect_timestamps(workflow_run_id: str) -> list[TimestampEntry]:
    """Collect all timestamps from the workflow run and its children.

    Reads the workflow run, its workflow run blocks, its tasks, and the steps
    of those tasks inside a single ``app.DATABASE.Session()`` and turns every
    populated timestamp field into a TimestampEntry.

    Args:
        workflow_run_id: ID of the workflow run to profile.

    Returns:
        Entries in collection order: workflow run, then blocks, tasks and
        steps (each group ordered by ``created_at``). The list is not sorted
        by timestamp.

    Raises:
        ValueError: If no workflow run with ``workflow_run_id`` is found.
    """
    run_fields = ("created_at", "queued_at", "started_at", "finished_at")
    block_fields = ("created_at", "queued_at", "started_at", "finished_at", "modified_at")
    task_fields = ("created_at", "queued_at", "started_at", "finished_at")
    step_fields = ("created_at", "finished_at")

    entries: list[TimestampEntry] = []

    async with app.DATABASE.Session() as session:
        # 1. Fetch the workflow run
        workflow_run = (
            await session.scalars(select(WorkflowRunModel).filter_by(workflow_run_id=workflow_run_id))
        ).first()

        if workflow_run is None:
            raise ValueError(f"Workflow run not found: {workflow_run_id}")

        entries.extend(
            _collect_entity_timestamps(
                workflow_run,
                "workflow_run",
                workflow_run_id,
                run_fields,
                status=workflow_run.status,
            )
        )

        # 2. Fetch all workflow run blocks
        workflow_run_blocks = (
            await session.scalars(
                select(WorkflowRunBlockModel)
                .filter_by(workflow_run_id=workflow_run_id)
                .order_by(WorkflowRunBlockModel.created_at)
            )
        ).all()

        for block in workflow_run_blocks:
            entries.extend(
                _collect_entity_timestamps(
                    block,
                    "workflow_run_block",
                    block.workflow_run_block_id,
                    block_fields,
                    label=block.label,
                    status=block.status,
                )
            )

        # 3. Fetch all tasks for this workflow run
        tasks = (
            await session.scalars(
                select(TaskModel).filter_by(workflow_run_id=workflow_run_id).order_by(TaskModel.created_at)
            )
        ).all()

        for task in tasks:
            entries.extend(
                _collect_entity_timestamps(task, "task", task.task_id, task_fields, status=task.status)
            )

        # 4. Fetch all steps for all tasks
        task_ids = [task.task_id for task in tasks]
        if task_ids:  # skip the query entirely when the run has no tasks
            steps = (
                await session.scalars(
                    select(StepModel).filter(StepModel.task_id.in_(task_ids)).order_by(StepModel.created_at)
                )
            ).all()

            for step in steps:
                entries.extend(
                    _collect_entity_timestamps(step, "step", step.step_id, step_fields, status=step.status)
                )

    return entries
|
||||
|
||||
|
||||
def print_profile(entries: list[TimestampEntry]) -> None:
    """Print the profiling results sorted by timestamp.

    Writes to stdout a timeline with one row per entry (absolute timestamp,
    offset in seconds from the earliest entry, entity type and label, field
    name, entity id and status), then the total duration and a count of
    entities per type.

    Args:
        entries: Timestamp entries in any order. When empty, only
            "No timestamps found." is printed.
    """
    if not entries:
        print("No timestamps found.")
        return

    sorted_entries = sorted(entries, key=lambda e: e.timestamp)
    rule_width = 120

    print("\n" + "=" * rule_width)
    print("WORKFLOW RUN PROFILE")
    print("=" * rule_width)
    # The header widths mirror the row layout below: 30-char timestamp, a space,
    # the 12-char offset, then 25 + 15 chars for entity type + label, so the
    # "|" separators of the header and the rows line up.
    print(f"{'Timestamp':<30} {'Elapsed':>12} | {'Entity Type [Label]':<40} | {'Field':<12} | Entity ID (Status)")
    print("-" * rule_width)

    first_ts = sorted_entries[0].timestamp
    for entry in sorted_entries:
        # Offset from the earliest timestamp; never negative because of the sort.
        delta = entry.timestamp - first_ts
        delta_str = f"+{delta.total_seconds():>10.3f}s"

        label_str = f" [{entry.label}]" if entry.label else ""
        status_str = f" ({entry.status})" if entry.status else ""

        print(
            f"{entry.timestamp.isoformat():<30} {delta_str} | {entry.entity_type:<25}{label_str:<15} | {entry.field_name:<12} | {entry.entity_id[:36]}{status_str}"
        )

    print("-" * rule_width)

    # Print summary
    total_duration = sorted_entries[-1].timestamp - first_ts
    print(f"\nTotal Duration: {total_duration.total_seconds():.3f} seconds")

    # Count entities: only "created_at" rows are counted, so an entity that
    # contributes several timestamp fields is counted once.
    entity_counts: dict[str, int] = {}
    for entry in sorted_entries:
        if entry.field_name == "created_at":
            entity_counts[entry.entity_type] = entity_counts.get(entry.entity_type, 0) + 1

    print("\nEntity Counts:")
    for entity_type, count in entity_counts.items():
        print(f"  {entity_type}: {count}")

    print("=" * rule_width + "\n")
|
||||
|
||||
|
||||
async def profile_workflow_run(workflow_run_id: str) -> None:
    """Collect every timestamp for one workflow run and print its profile.

    Args:
        workflow_run_id: ID of the workflow run to profile.
    """
    # Announce the target before the database is queried.
    print(f"Profiling workflow run: {workflow_run_id}")

    print_profile(await collect_timestamps(workflow_run_id))
|
||||
|
||||
|
||||
def main(
    workflow_run_id: Annotated[str, typer.Argument(help="The workflow run ID to profile")],
) -> None:
    """Profile a workflow run by collecting and displaying all key timestamps."""
    # The forge app is started before any profiling code runs; the profiling
    # coroutine reads from app.DATABASE.
    start_forge_app()

    # Drive the async profiling routine to completion from this sync entry point.
    profiling_coro = profile_workflow_run(workflow_run_id)
    asyncio.run(profiling_coro)
|
||||
|
||||
|
||||
# Script entry point: when this file is executed directly (not imported),
# pass `main` to typer.run so it is invoked as the command-line entry point.
if __name__ == "__main__":
    typer.run(main)
|
||||
Reference in New Issue
Block a user