From 04c6e55848e90e790d1d811176ed2c04ee039457 Mon Sep 17 00:00:00 2001 From: Kerem Yilmaz Date: Mon, 3 Jun 2024 15:55:34 -0700 Subject: [PATCH] Keep track of token counts in steps table (#412) --- skyvern/forge/sdk/api/llm/api_handler_factory.py | 8 ++++++++ skyvern/forge/sdk/db/client.py | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index 0c5f68b1..99dea180 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -115,11 +115,15 @@ class LLMAPIHandlerFactory: data=response.model_dump_json(indent=2).encode("utf-8"), ) llm_cost = litellm.completion_cost(completion_response=response) + prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0) + completion_tokens = response.get("usage", {}).get("completion_tokens", 0) await app.DATABASE.update_step( task_id=step.task_id, step_id=step.step_id, organization_id=step.organization_id, incremental_cost=llm_cost, + incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, + incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, ) parsed_response = parse_api_response(response, llm_config.add_assistant_prefix) if step: @@ -206,11 +210,15 @@ class LLMAPIHandlerFactory: data=response.model_dump_json(indent=2).encode("utf-8"), ) llm_cost = litellm.completion_cost(completion_response=response) + prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0) + completion_tokens = response.get("usage", {}).get("completion_tokens", 0) await app.DATABASE.update_step( task_id=step.task_id, step_id=step.step_id, organization_id=step.organization_id, incremental_cost=llm_cost, + incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, + incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, ) parsed_response = parse_api_response(response, llm_config.add_assistant_prefix) if step: diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index b4880a75..52632fa1 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -293,6 +293,8 @@ class AgentDB: retry_index: int | None = None, organization_id: str | None = None, incremental_cost: float | None = None, + incremental_input_tokens: int | None = None, + incremental_output_tokens: int | None = None, ) -> Step: try: async with self.Session() as session: @@ -314,6 +316,10 @@ class AgentDB: step.retry_index = retry_index if incremental_cost is not None: step.step_cost = incremental_cost + float(step.step_cost or 0) + if incremental_input_tokens is not None: + step.input_token_count = incremental_input_tokens + (step.input_token_count or 0) + if incremental_output_tokens is not None: + step.output_token_count = incremental_output_tokens + (step.output_token_count or 0) await session.commit() updated_step = await self.get_step(task_id, step_id, organization_id)