fix anthropic llm stats (#2313)

This commit is contained in:
Shuchang Zheng
2025-05-08 14:30:00 -07:00
committed by GitHub
parent 6338a404a4
commit 1fbaf711b1
2 changed files with 62 additions and 25 deletions

View File

@@ -1430,6 +1430,7 @@ class ForgeAgent:
if not llm_caller.message_history:
llm_response = await llm_caller.call(
prompt=task.navigation_goal,
step=step,
screenshots=scraped_page.screenshots,
use_message_history=True,
tools=tools,
@@ -1440,6 +1441,7 @@ class ForgeAgent:
)
else:
llm_response = await llm_caller.call(
step=step,
screenshots=scraped_page.screenshots,
use_message_history=True,
tools=tools,

View File

@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
from jinja2 import Template
from litellm.utils import CustomStreamWrapper, ModelResponse
from pydantic import BaseModel
from skyvern.config import settings
from skyvern.exceptions import SkyvernContextWindowExceededError
@@ -33,6 +34,14 @@ from skyvern.utils.image_resizer import Resolution, get_resize_target_dimension,
LOG = structlog.get_logger()
class LLMCallStats(BaseModel):
    """Token usage and cost accounting for a single LLM API call.

    Every field defaults to None, meaning the provider did not report that
    figure; callers pass these values straight through to the step/thought
    update helpers (presumably None means "leave unchanged" there — verify
    against the DB layer).
    """

    # Prompt tokens consumed by the call.
    input_tokens: int | None = None
    # Completion tokens produced by the call.
    output_tokens: int | None = None
    # Reasoning/"thinking" tokens, when the provider reports them separately.
    reasoning_tokens: int | None = None
    # Prompt tokens served from the provider's prompt cache.
    cached_tokens: int | None = None
    # Total call cost in USD, when it could be computed.
    llm_cost: float | None = None
class LLMAPIHandlerFactory:
_custom_handlers: dict[str, LLMAPIHandler] = {}
@@ -624,41 +633,27 @@ class LLMCaller:
)
if step or thought:
try:
llm_cost = litellm.completion_cost(completion_response=response)
except Exception as e:
LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
llm_cost = 0
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
reasoning_tokens = 0
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
if completion_token_detail:
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
cached_tokens = 0
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
if cached_token_detail:
cached_tokens = cached_token_detail.cached_tokens or 0
call_stats = await self.get_call_stats(response)
if step:
await app.DATABASE.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
incremental_cost=llm_cost,
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None,
incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None,
incremental_cost=call_stats.llm_cost,
incremental_input_tokens=call_stats.input_tokens,
incremental_output_tokens=call_stats.output_tokens,
incremental_reasoning_tokens=call_stats.reasoning_tokens,
incremental_cached_tokens=call_stats.cached_tokens,
)
if thought:
await app.DATABASE.update_thought(
thought_id=thought.observer_thought_id,
organization_id=thought.organization_id,
input_token_count=prompt_tokens if prompt_tokens > 0 else None,
output_token_count=completion_tokens if completion_tokens > 0 else None,
reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
cached_token_count=cached_tokens if cached_tokens > 0 else None,
thought_cost=llm_cost,
input_token_count=call_stats.input_tokens,
output_token_count=call_stats.output_tokens,
reasoning_token_count=call_stats.reasoning_tokens,
cached_token_count=call_stats.cached_tokens,
thought_cost=call_stats.llm_cost,
)
# Track LLM API handler duration
duration_seconds = time.perf_counter() - start_time
@@ -757,6 +752,46 @@ class LLMCaller:
)
return response
async def get_call_stats(self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage) -> LLMCallStats:
    """Extract token usage and cost statistics from an LLM response.

    Raw Anthropic SDK responses are priced manually (litellm cannot cost
    them); litellm responses are costed via ``litellm.completion_cost`` and
    their token counts read from the ``usage`` payload.

    Args:
        response: The provider response object.

    Returns:
        An ``LLMCallStats``; all-None stats for unrecognized response types.
    """
    if isinstance(response, AnthropicMessage):
        # Hard-coded Claude Sonnet pricing, USD per token.
        # NOTE(review): consider sourcing these from config so other
        # Anthropic models are priced correctly.
        input_cost_per_token = 3.0 / 1_000_000
        output_cost_per_token = 15.0 / 1_000_000
        cached_cost_per_token = 0.3 / 1_000_000

        usage = response.usage
        # cache_read_input_tokens may be None on some SDK versions; guard
        # before arithmetic to avoid a TypeError.
        cached_tokens = usage.cache_read_input_tokens or 0
        llm_cost = (
            input_cost_per_token * usage.input_tokens
            + output_cost_per_token * usage.output_tokens
            + cached_cost_per_token * cached_tokens
        )
        return LLMCallStats(
            llm_cost=llm_cost,
            input_tokens=usage.input_tokens,
            output_tokens=usage.output_tokens,
            cached_tokens=cached_tokens,
            reasoning_tokens=0,
        )
    elif isinstance(response, (ModelResponse, CustomStreamWrapper)):
        try:
            llm_cost = litellm.completion_cost(completion_response=response)
        except Exception as e:
            # Best-effort: cost lookup can fail for unknown models; record
            # zero cost but keep the token counts.
            LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
            llm_cost = 0.0
        usage = response.get("usage", {})
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)
        reasoning_tokens = 0
        completion_token_detail = usage.get("completion_tokens_details")
        if completion_token_detail:
            reasoning_tokens = completion_token_detail.reasoning_tokens or 0
        cached_tokens = 0
        cached_token_detail = usage.get("prompt_tokens_details")
        if cached_token_detail:
            cached_tokens = cached_token_detail.cached_tokens or 0
        return LLMCallStats(
            llm_cost=llm_cost,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cached_tokens=cached_tokens,
            reasoning_tokens=reasoning_tokens,
        )
    # Unknown response type: report nothing rather than guessing.
    return LLMCallStats()
class LLMCallerManager:
_llm_callers: dict[str, LLMCaller] = {}