fix anthropic llm stats (#2313)

This commit is contained in:
Shuchang Zheng
2025-05-08 14:30:00 -07:00
committed by GitHub
parent 6338a404a4
commit 1fbaf711b1
2 changed files with 62 additions and 25 deletions

View File

@@ -1430,6 +1430,7 @@ class ForgeAgent:
if not llm_caller.message_history: if not llm_caller.message_history:
llm_response = await llm_caller.call( llm_response = await llm_caller.call(
prompt=task.navigation_goal, prompt=task.navigation_goal,
step=step,
screenshots=scraped_page.screenshots, screenshots=scraped_page.screenshots,
use_message_history=True, use_message_history=True,
tools=tools, tools=tools,
@@ -1440,6 +1441,7 @@ class ForgeAgent:
) )
else: else:
llm_response = await llm_caller.call( llm_response = await llm_caller.call(
step=step,
screenshots=scraped_page.screenshots, screenshots=scraped_page.screenshots,
use_message_history=True, use_message_history=True,
tools=tools, tools=tools,

View File

@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
from jinja2 import Template from jinja2 import Template
from litellm.utils import CustomStreamWrapper, ModelResponse from litellm.utils import CustomStreamWrapper, ModelResponse
from pydantic import BaseModel
from skyvern.config import settings from skyvern.config import settings
from skyvern.exceptions import SkyvernContextWindowExceededError from skyvern.exceptions import SkyvernContextWindowExceededError
@@ -33,6 +34,14 @@ from skyvern.utils.image_resizer import Resolution, get_resize_target_dimension,
LOG = structlog.get_logger() LOG = structlog.get_logger()
class LLMCallStats(BaseModel):
    """Token usage and cost accounting for a single LLM call.

    All fields default to None so that stats a provider does not report
    are recorded as missing rather than a misleading zero.
    """

    # Tokens in the prompt (Anthropic: input_tokens; litellm: prompt_tokens).
    input_tokens: int | None = None
    # Tokens in the completion (Anthropic: output_tokens; litellm: completion_tokens).
    output_tokens: int | None = None
    # Reasoning/"thinking" tokens, when the provider reports them.
    reasoning_tokens: int | None = None
    # Prompt tokens served from cache (cache reads).
    cached_tokens: int | None = None
    # Estimated cost of the call (litellm.completion_cost, or hard-coded
    # per-token Anthropic pricing for raw Anthropic responses).
    llm_cost: float | None = None
class LLMAPIHandlerFactory: class LLMAPIHandlerFactory:
_custom_handlers: dict[str, LLMAPIHandler] = {} _custom_handlers: dict[str, LLMAPIHandler] = {}
@@ -624,41 +633,27 @@ class LLMCaller:
) )
if step or thought: if step or thought:
try: call_stats = await self.get_call_stats(response)
llm_cost = litellm.completion_cost(completion_response=response)
except Exception as e:
LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
llm_cost = 0
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
reasoning_tokens = 0
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
if completion_token_detail:
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
cached_tokens = 0
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
if cached_token_detail:
cached_tokens = cached_token_detail.cached_tokens or 0
if step: if step:
await app.DATABASE.update_step( await app.DATABASE.update_step(
task_id=step.task_id, task_id=step.task_id,
step_id=step.step_id, step_id=step.step_id,
organization_id=step.organization_id, organization_id=step.organization_id,
incremental_cost=llm_cost, incremental_cost=call_stats.llm_cost,
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None, incremental_input_tokens=call_stats.input_tokens,
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None, incremental_output_tokens=call_stats.output_tokens,
incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None, incremental_reasoning_tokens=call_stats.reasoning_tokens,
incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None, incremental_cached_tokens=call_stats.cached_tokens,
) )
if thought: if thought:
await app.DATABASE.update_thought( await app.DATABASE.update_thought(
thought_id=thought.observer_thought_id, thought_id=thought.observer_thought_id,
organization_id=thought.organization_id, organization_id=thought.organization_id,
input_token_count=prompt_tokens if prompt_tokens > 0 else None, input_token_count=call_stats.input_tokens,
output_token_count=completion_tokens if completion_tokens > 0 else None, output_token_count=call_stats.output_tokens,
reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None, reasoning_token_count=call_stats.reasoning_tokens,
cached_token_count=cached_tokens if cached_tokens > 0 else None, cached_token_count=call_stats.cached_tokens,
thought_cost=llm_cost, thought_cost=call_stats.llm_cost,
) )
# Track LLM API handler duration # Track LLM API handler duration
duration_seconds = time.perf_counter() - start_time duration_seconds = time.perf_counter() - start_time
@@ -757,6 +752,46 @@ class LLMCaller:
) )
return response return response
async def get_call_stats(self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage) -> LLMCallStats:
    """Extract token usage and cost statistics from an LLM response.

    Handles two response families:
      - raw Anthropic SDK messages, where cost is computed from
        hard-coded per-token prices, and
      - litellm responses, where litellm computes the cost itself.

    Returns an empty LLMCallStats (all fields None) for any other
    response type, so callers can always read the fields safely.
    """
    if isinstance(response, AnthropicMessage):
        # NOTE(review): prices are hard-coded Claude Sonnet rates in USD
        # per token -- confirm they match the model actually being used.
        input_cost_per_token = 3.0 / 1_000_000
        output_cost_per_token = 15.0 / 1_000_000
        cached_cost_per_token = 0.3 / 1_000_000

        usage = response.usage
        # cache_read_input_tokens is Optional in the Anthropic SDK;
        # treat a missing value as zero instead of raising a TypeError
        # when multiplying by the cached-token price.
        cached_tokens = usage.cache_read_input_tokens or 0
        llm_cost = (
            input_cost_per_token * usage.input_tokens
            + output_cost_per_token * usage.output_tokens
            + cached_cost_per_token * cached_tokens
        )
        return LLMCallStats(
            llm_cost=llm_cost,
            input_tokens=usage.input_tokens,
            output_tokens=usage.output_tokens,
            cached_tokens=cached_tokens,
            reasoning_tokens=0,
        )
    if isinstance(response, (ModelResponse, CustomStreamWrapper)):
        # litellm knows per-model pricing; fall back to 0 on failure so
        # stats collection never breaks the main call path.
        try:
            llm_cost = litellm.completion_cost(completion_response=response)
        except Exception as e:
            LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
            llm_cost = 0
        # Hoist the usage dict: the original fetched it once per field.
        usage = response.get("usage", {})
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)
        reasoning_tokens = 0
        completion_token_detail = usage.get("completion_tokens_details")
        if completion_token_detail:
            reasoning_tokens = completion_token_detail.reasoning_tokens or 0
        cached_tokens = 0
        cached_token_detail = usage.get("prompt_tokens_details")
        if cached_token_detail:
            cached_tokens = cached_token_detail.cached_tokens or 0
        return LLMCallStats(
            llm_cost=llm_cost,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cached_tokens=cached_tokens,
            reasoning_tokens=reasoning_tokens,
        )
    return LLMCallStats()
class LLMCallerManager: class LLMCallerManager:
_llm_callers: dict[str, LLMCaller] = {} _llm_callers: dict[str, LLMCaller] = {}