From 171aef6bf7ec8c9fe7845bb6b084069b7774f15d Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Tue, 31 Dec 2024 11:24:09 -0800 Subject: [PATCH] add cost info to the workflow run repsonse (#1456) --- skyvern/forge/sdk/db/client.py | 15 +++++++++++++++ skyvern/forge/sdk/routes/agent_protocol.py | 1 + skyvern/forge/sdk/workflow/models/workflow.py | 2 ++ skyvern/forge/sdk/workflow/service.py | 18 +++++++++++++++++- 4 files changed, 35 insertions(+), 1 deletion(-) diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index bffbb268..ab1bdcda 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -318,6 +318,21 @@ class AgentDB: LOG.error("UnexpectedError", exc_info=True) raise + async def get_steps_by_task_ids(self, task_ids: list[str], organization_id: str | None = None) -> list[Step]: + try: + async with self.Session() as session: + steps = ( + await session.scalars( + select(StepModel) + .filter(StepModel.task_id.in_(task_ids)) + .filter_by(organization_id=organization_id) + ) + ).all() + return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps] + except SQLAlchemyError: + LOG.error("SQLAlchemyError", exc_info=True) + raise + async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> Sequence[StepModel]: try: async with self.Session() as session: diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index b66a3e9b..75fdc79a 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -724,6 +724,7 @@ async def get_workflow_run( workflow_permanent_id=workflow_id, workflow_run_id=workflow_run_id, organization_id=current_org.organization_id, + include_cost=True, ) diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index 184de1ef..b4dbee20 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -134,3 +134,5 @@ class WorkflowRunStatusResponse(BaseModel): recording_url: str | None = None downloaded_file_urls: list[str] | None = None outputs: dict[str, Any] | None = None + total_steps: int | None = None + total_cost: float | None = None diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 1a2e9afe..fdb822eb 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -22,7 +22,7 @@ from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.db.enums import TaskType -from skyvern.forge.sdk.models import Step +from skyvern.forge.sdk.models import Step, StepStatus from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock, WorkflowRunTimeline, WorkflowRunTimelineType @@ -741,6 +741,7 @@ class WorkflowService: self, workflow_run_id: str, organization_id: str, + include_cost: bool = False, ) -> WorkflowRunStatusResponse: workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id) if workflow_run is None: @@ -751,6 +752,7 @@ class WorkflowService: workflow_permanent_id=workflow_permanent_id, workflow_run_id=workflow_run_id, organization_id=organization_id, + include_cost=include_cost, ) async def build_workflow_run_status_response( @@ -758,6 +760,7 @@ class WorkflowService: workflow_permanent_id: str, workflow_run_id: str, organization_id: str, + include_cost: bool = False, ) -> WorkflowRunStatusResponse: workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id, organization_id=organization_id) if workflow is None: @@ -824,6 +827,17 @@ class WorkflowService: if output_parameter_tuples: outputs = {output_parameter.key: output.value for output_parameter, output in output_parameter_tuples} + total_steps = None + total_cost = None + if include_cost: + workflow_run_steps = await app.DATABASE.get_steps_by_task_ids( + task_ids=[task.task_id for task in workflow_run_tasks], organization_id=organization_id + ) + total_steps = len(workflow_run_steps) + # TODO: This is a temporary cost calculation. We need to implement a more accurate cost calculation. + # successful steps are the ones that have a status of completed and the total count of unique step.order + successful_steps = set(step.order for step in workflow_run_steps if step.status == StepStatus.completed) + total_cost = 0.1 * len(successful_steps) return WorkflowRunStatusResponse( workflow_id=workflow.workflow_permanent_id, workflow_run_id=workflow_run_id, @@ -840,6 +854,8 @@ class WorkflowService: recording_url=recording_url, downloaded_file_urls=downloaded_file_urls, outputs=outputs, + total_steps=total_steps, + total_cost=total_cost, ) async def clean_up_workflow(