diff --git a/scripts/profile_workflow_run.py b/scripts/profile_workflow_run.py index e90debf4..dd2fc7ab 100644 --- a/scripts/profile_workflow_run.py +++ b/scripts/profile_workflow_run.py @@ -3,6 +3,7 @@ Script to profile a workflow run by collecting and displaying all key timestamps Usage: python scripts/profile_workflow_run.py + python scripts/profile_workflow_run.py --include-actions """ import asyncio @@ -16,6 +17,7 @@ from sqlalchemy import select from skyvern.forge import app from skyvern.forge.forge_app_initializer import start_forge_app from skyvern.forge.sdk.db.models import ( + ActionModel, StepModel, TaskModel, WorkflowRunBlockModel, @@ -28,7 +30,7 @@ class TimestampEntry: """Represents a single timestamp entry for profiling.""" timestamp: datetime - entity_type: str # "workflow_run", "workflow_run_block", "task", "step" + entity_type: str # "workflow_run", "workflow_run_block", "task", "step", "action" entity_id: str field_name: str # "created_at", "started_at", "finished_at", etc. label: str | None = None # For blocks with labels @@ -40,7 +42,7 @@ class TimestampEntry: return f"{self.timestamp.isoformat()} | {self.entity_type:20}{label_str} | {self.field_name:12} | {self.entity_id}{status_str}" -async def collect_timestamps(workflow_run_id: str) -> list[TimestampEntry]: +async def collect_timestamps(workflow_run_id: str, include_actions: bool = False) -> list[TimestampEntry]: """Collect all timestamps from the workflow run and its children.""" entries: list[TimestampEntry] = [] @@ -136,6 +138,29 @@ async def collect_timestamps(workflow_run_id: str) -> list[TimestampEntry]: ) ) + # 5. Fetch all actions for all tasks (optional) + if include_actions and task_ids: + actions = ( + await session.scalars( + select(ActionModel).filter(ActionModel.task_id.in_(task_ids)).order_by(ActionModel.created_at) + ) + ).all() + + for action in actions: + for field in ["modified_at"]: + ts = getattr(action, field, None) + if ts: + entries.append( + TimestampEntry( + timestamp=ts, + entity_type="action", + entity_id=action.action_id, + field_name=field, + label=action.action_type, + status=action.status, + ) + ) + return entries @@ -186,20 +211,25 @@ def print_profile(entries: list[TimestampEntry]) -> None: print("=" * 120 + "\n") -async def profile_workflow_run(workflow_run_id: str) -> None: +async def profile_workflow_run(workflow_run_id: str, include_actions: bool = False) -> None: """Main function to profile a workflow run.""" print(f"Profiling workflow run: {workflow_run_id}") + if include_actions: + print("(including actions)") - entries = await collect_timestamps(workflow_run_id) + entries = await collect_timestamps(workflow_run_id, include_actions=include_actions) print_profile(entries) def main( workflow_run_id: Annotated[str, typer.Argument(help="The workflow run ID to profile")], + include_actions: Annotated[ + bool, typer.Option("--include-actions", "-a", help="Include action timestamps (can be noisy)") + ] = False, ) -> None: """Profile a workflow run by collecting and displaying all key timestamps.""" start_forge_app() - asyncio.run(profile_workflow_run(workflow_run_id)) + asyncio.run(profile_workflow_run(workflow_run_id, include_actions=include_actions)) if __name__ == "__main__":