From 1fbaf711b136c2f1f1f97ee919555fff1f1c6902 Mon Sep 17 00:00:00 2001 From: Shuchang Zheng Date: Thu, 8 May 2025 14:30:00 -0700 Subject: [PATCH] fix anthropic llm stats (#2313) --- skyvern/forge/agent.py | 2 + .../forge/sdk/api/llm/api_handler_factory.py | 85 +++++++++++++------ 2 files changed, 62 insertions(+), 25 deletions(-) diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 01e53ff3..3510e6e4 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -1430,6 +1430,7 @@ class ForgeAgent: if not llm_caller.message_history: llm_response = await llm_caller.call( prompt=task.navigation_goal, + step=step, screenshots=scraped_page.screenshots, use_message_history=True, tools=tools, @@ -1440,6 +1441,7 @@ class ForgeAgent: ) else: llm_response = await llm_caller.call( + step=step, screenshots=scraped_page.screenshots, use_message_history=True, tools=tools, diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index 6ef25dfb..13863091 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage from jinja2 import Template from litellm.utils import CustomStreamWrapper, ModelResponse +from pydantic import BaseModel from skyvern.config import settings from skyvern.exceptions import SkyvernContextWindowExceededError @@ -33,6 +34,14 @@ from skyvern.utils.image_resizer import Resolution, get_resize_target_dimension, LOG = structlog.get_logger() +class LLMCallStats(BaseModel): + input_tokens: int | None = None + output_tokens: int | None = None + reasoning_tokens: int | None = None + cached_tokens: int | None = None + llm_cost: float | None = None + + class LLMAPIHandlerFactory: _custom_handlers: dict[str, LLMAPIHandler] = {} @@ -624,41 +633,27 @@ class LLMCaller: ) if step or thought: - try: - llm_cost = litellm.completion_cost(completion_response=response) - except Exception as e: - LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True) - llm_cost = 0 - prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0) - completion_tokens = response.get("usage", {}).get("completion_tokens", 0) - reasoning_tokens = 0 - completion_token_detail = response.get("usage", {}).get("completion_tokens_details") - if completion_token_detail: - reasoning_tokens = completion_token_detail.reasoning_tokens or 0 - cached_tokens = 0 - cached_token_detail = response.get("usage", {}).get("prompt_tokens_details") - if cached_token_detail: - cached_tokens = cached_token_detail.cached_tokens or 0 + call_stats = await self.get_call_stats(response) if step: await app.DATABASE.update_step( task_id=step.task_id, step_id=step.step_id, organization_id=step.organization_id, - incremental_cost=llm_cost, - incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, - incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, - incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, - incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None, + incremental_cost=call_stats.llm_cost, + incremental_input_tokens=call_stats.input_tokens, + incremental_output_tokens=call_stats.output_tokens, + incremental_reasoning_tokens=call_stats.reasoning_tokens, + incremental_cached_tokens=call_stats.cached_tokens, ) if thought: await app.DATABASE.update_thought( thought_id=thought.observer_thought_id, organization_id=thought.organization_id, - input_token_count=prompt_tokens if prompt_tokens > 0 else None, - output_token_count=completion_tokens if completion_tokens > 0 else None, - reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None, - cached_token_count=cached_tokens if cached_tokens > 0 else None, - thought_cost=llm_cost, + input_token_count=call_stats.input_tokens, + output_token_count=call_stats.output_tokens, + reasoning_token_count=call_stats.reasoning_tokens, + cached_token_count=call_stats.cached_tokens, + thought_cost=call_stats.llm_cost, ) # Track LLM API handler duration duration_seconds = time.perf_counter() - start_time @@ -757,6 +752,46 @@ class LLMCaller: ) return response + async def get_call_stats(self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage) -> LLMCallStats: + empty_call_stats = LLMCallStats() + if isinstance(response, AnthropicMessage): + usage = response.usage + input_token_cost = (3.0 / 1000000) * usage.input_tokens + output_token_cost = (15.0 / 1000000) * usage.output_tokens + cached_token_cost = (0.3 / 1000000) * usage.cache_read_input_tokens + llm_cost = input_token_cost + output_token_cost + cached_token_cost + return LLMCallStats( + llm_cost=llm_cost, + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens, + cached_tokens=usage.cache_read_input_tokens, + reasoning_tokens=0, + ) + elif isinstance(response, (ModelResponse, CustomStreamWrapper)): + try: + llm_cost = litellm.completion_cost(completion_response=response) + except Exception as e: + LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True) + llm_cost = 0 + input_tokens = response.get("usage", {}).get("prompt_tokens", 0) + output_tokens = response.get("usage", {}).get("completion_tokens", 0) + reasoning_tokens = 0 + completion_token_detail = response.get("usage", {}).get("completion_tokens_details") + if completion_token_detail: + reasoning_tokens = completion_token_detail.reasoning_tokens or 0 + cached_tokens = 0 + cached_token_detail = response.get("usage", {}).get("prompt_tokens_details") + if cached_token_detail: + cached_tokens = cached_token_detail.cached_tokens or 0 + return LLMCallStats( + llm_cost=llm_cost, + input_tokens=input_tokens, + output_tokens=output_tokens, + cached_tokens=cached_tokens, + reasoning_tokens=reasoning_tokens, + ) + return empty_call_stats + class LLMCallerManager: _llm_callers: dict[str, LLMCaller] = {}