async def get_tasks_by_ids(
    self,
    task_ids: list[str],
    organization_id: str | None = None,
) -> list[Task]:
    """Return the tasks whose ``task_id`` is in ``task_ids`` for one organization.

    Args:
        task_ids: IDs of the tasks to load. An empty list yields an empty
            result without touching the database.
        organization_id: Organization the tasks must belong to. The value is
            passed straight to ``filter_by``, so ``None`` matches rows whose
            ``organization_id`` is NULL.

    Returns:
        The matching rows converted with ``convert_to_task``.

    Raises:
        SQLAlchemyError: Re-raised after being logged.
        Exception: Any other error is logged and re-raised.
    """
    # Nothing to look up: skip the session and the degenerate empty IN () query.
    if not task_ids:
        return []
    try:
        async with self.Session() as session:
            tasks = (
                await session.scalars(
                    select(TaskModel)
                    .filter(TaskModel.task_id.in_(task_ids))
                    .filter_by(organization_id=organization_id)
                )
            ).all()
            return [convert_to_task(task, debug_enabled=self.debug_enabled) for task in tasks]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise


async def get_workflow_run_blocks(
    self,
    workflow_run_id: str,
    organization_id: str | None = None,
) -> list[WorkflowRunBlock]:
    """Return every block of a workflow run, oldest first, with its task attached.

    Args:
        workflow_run_id: Workflow run whose blocks are loaded.
        organization_id: Organization the blocks must belong to (passed to
            ``filter_by`` as-is).

    Returns:
        The blocks ordered by ``created_at``, each converted with
        ``convert_to_workflow_run_block``. A block whose ``task_id`` has no
        matching task in the run is converted with ``task=None``.

    Raises:
        SQLAlchemyError: Re-raised after being logged.
        Exception: Any other error is logged and re-raised.
    """
    # Load the run's tasks before opening our own session so this method does
    # not keep one session open while waiting on a second database call.
    # The call stays outside the try block so that whatever error handling
    # get_tasks_by_workflow_run_id does is not duplicated here.
    tasks = await self.get_tasks_by_workflow_run_id(workflow_run_id)
    tasks_dict = {task.task_id: task for task in tasks}
    try:
        async with self.Session() as session:
            workflow_run_blocks = (
                await session.scalars(
                    select(WorkflowRunBlockModel)
                    .filter_by(workflow_run_id=workflow_run_id)
                    .filter_by(organization_id=organization_id)
                    .order_by(WorkflowRunBlockModel.created_at)
                )
            ).all()
            # Convert while the session is still open, as the sibling
            # query methods do.
            return [
                convert_to_workflow_run_block(workflow_run_block, task=tasks_dict.get(workflow_run_block.task_id))
                for workflow_run_block in workflow_run_blocks
            ]
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
# Both registrations serve the same handler so the path works with or
# without a trailing slash.
@base_router.get("/workflows/{workflow_id}/runs/{workflow_run_id}/timeline")
@base_router.get("/workflows/{workflow_id}/runs/{workflow_run_id}/timeline/")
async def get_workflow_run_timeline(
    workflow_id: str,
    workflow_run_id: str,
    observer_cruise_id: str | None = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1),
    current_org: Organization = Depends(org_auth_service.get_current_org),
) -> list[WorkflowRunTimeline]:
    # Returns the block timeline of a workflow run for the caller's
    # organization; when observer_cruise_id is supplied the observer thoughts
    # of that cruise are merged in. The combined list is ordered by created_at.
    # NOTE(review): workflow_id, page and page_size are accepted but not used
    # by this handler.
    org_id = current_org.organization_id

    # Block entries for the run, as built by the workflow service.
    timeline = await app.WORKFLOW_SERVICE.get_workflow_run_timeline(
        workflow_run_id=workflow_run_id,
        organization_id=org_id,
    )

    # Observer thoughts are only available when a cruise id was passed in.
    if observer_cruise_id:
        timeline.extend(
            await observer_service.get_observer_thought_timelines(
                observer_cruise_id=observer_cruise_id,
                organization_id=org_id,
            )
        )

    timeline.sort(key=lambda entry: entry.created_at)
    return timeline
# Kinds of entry that can appear in a workflow run timeline.
class WorkflowRunTimelineType(StrEnum):
    thought = "thought"
    block = "block"


# One entry of a workflow run timeline: either a workflow run block or an
# observer thought, selected by `type`.
class WorkflowRunTimeline(BaseModel):
    # Which of `block` / `thought` this entry carries.
    type: WorkflowRunTimelineType
    # The workflow run block, populated for entries of type `block`.
    block: WorkflowRunBlock | None = None
    # The observer thought, populated for entries of type `thought`.
    thought: ObserverThought | None = None
    # Nested entries; the workflow service appends the timelines of child
    # blocks (those with a parent_workflow_run_block_id) to their parent here.
    children: list[WorkflowRunTimeline] = []
    # Timestamps copied from the underlying block or thought; the timeline
    # route sorts entries by created_at.
    created_at: datetime
    modified_at: datetime
async def get_workflow_run_timeline(
    self,
    workflow_run_id: str,
    organization_id: str | None = None,
) -> list[WorkflowRunTimeline]:
    """Build the tree structure of the workflow run timeline.

    Every workflow run block becomes one ``WorkflowRunTimeline`` entry of type
    ``block``. Blocks that name a ``parent_workflow_run_block_id`` are nested
    under their parent's ``children``; all other blocks are returned at the
    top level. The actions of each task-backed block are appended to that
    block's ``actions`` list before the tree is built.

    Args:
        workflow_run_id: Workflow run whose timeline is built.
        organization_id: Organization used to scope the block and action
            lookups.

    Returns:
        The top-level timeline entries, in the order the blocks were returned
        by ``get_workflow_run_blocks``. Children keep that same relative
        order under their parent.
    """
    workflow_run_blocks = await app.DATABASE.get_workflow_run_blocks(
        workflow_run_id=workflow_run_id,
        organization_id=organization_id,
    )

    # Attach every action to the block that owns its task.
    task_ids = [block.task_id for block in workflow_run_blocks if block.task_id]
    task_id_to_block: dict[str, WorkflowRunBlock] = {
        block.task_id: block for block in workflow_run_blocks if block.task_id
    }
    # No task-backed blocks means no actions to fetch.
    actions = (
        await app.DATABASE.get_tasks_actions(task_ids=task_ids, organization_id=organization_id)
        if task_ids
        else []
    )
    for action in actions:
        if not action.task_id:
            continue
        task_block = task_id_to_block.get(action.task_id)
        if task_block is not None:
            task_block.actions.append(action)

    # Pass 1: index a timeline entry for every block by its block id, so a
    # parent can be found regardless of where it sits in the list.
    timeline_by_block_id: dict[str, WorkflowRunTimeline] = {
        block.workflow_run_block_id: WorkflowRunTimeline(
            type=WorkflowRunTimelineType.block,
            block=block,
            created_at=block.created_at,
            modified_at=block.modified_at,
        )
        for block in workflow_run_blocks
    }

    # Pass 2: link each entry to its parent. The previous re-queueing loop
    # never recorded processed blocks, so any block with a parent was put
    # back on the queue forever; this single pass always terminates.
    result: list[WorkflowRunTimeline] = []
    for block in workflow_run_blocks:
        entry = timeline_by_block_id[block.workflow_run_block_id]
        parent_id = block.parent_workflow_run_block_id
        parent = timeline_by_block_id.get(parent_id) if parent_id else None
        if parent is not None and parent is not entry:
            parent.children.append(entry)
        else:
            # Root block, or a block whose parent is not part of this result
            # set (or that names itself): surface it at the top level rather
            # than dropping it.
            result.append(entry)

    return result