fix anthropic llm stats (#2313)
This commit is contained in:
@@ -1430,6 +1430,7 @@ class ForgeAgent:
|
||||
if not llm_caller.message_history:
|
||||
llm_response = await llm_caller.call(
|
||||
prompt=task.navigation_goal,
|
||||
step=step,
|
||||
screenshots=scraped_page.screenshots,
|
||||
use_message_history=True,
|
||||
tools=tools,
|
||||
@@ -1440,6 +1441,7 @@ class ForgeAgent:
|
||||
)
|
||||
else:
|
||||
llm_response = await llm_caller.call(
|
||||
step=step,
|
||||
screenshots=scraped_page.screenshots,
|
||||
use_message_history=True,
|
||||
tools=tools,
|
||||
|
||||
@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
|
||||
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
|
||||
from jinja2 import Template
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import SkyvernContextWindowExceededError
|
||||
@@ -33,6 +34,14 @@ from skyvern.utils.image_resizer import Resolution, get_resize_target_dimension,
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
||||
class LLMCallStats(BaseModel):
|
||||
input_tokens: int | None = None
|
||||
output_tokens: int | None = None
|
||||
reasoning_tokens: int | None = None
|
||||
cached_tokens: int | None = None
|
||||
llm_cost: float | None = None
|
||||
|
||||
|
||||
class LLMAPIHandlerFactory:
|
||||
_custom_handlers: dict[str, LLMAPIHandler] = {}
|
||||
|
||||
@@ -624,41 +633,27 @@ class LLMCaller:
|
||||
)
|
||||
|
||||
if step or thought:
|
||||
try:
|
||||
llm_cost = litellm.completion_cost(completion_response=response)
|
||||
except Exception as e:
|
||||
LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
|
||||
llm_cost = 0
|
||||
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
|
||||
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
|
||||
reasoning_tokens = 0
|
||||
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
|
||||
if completion_token_detail:
|
||||
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
|
||||
cached_tokens = 0
|
||||
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
|
||||
if cached_token_detail:
|
||||
cached_tokens = cached_token_detail.cached_tokens or 0
|
||||
call_stats = await self.get_call_stats(response)
|
||||
if step:
|
||||
await app.DATABASE.update_step(
|
||||
task_id=step.task_id,
|
||||
step_id=step.step_id,
|
||||
organization_id=step.organization_id,
|
||||
incremental_cost=llm_cost,
|
||||
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
|
||||
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
|
||||
incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None,
|
||||
incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None,
|
||||
incremental_cost=call_stats.llm_cost,
|
||||
incremental_input_tokens=call_stats.input_tokens,
|
||||
incremental_output_tokens=call_stats.output_tokens,
|
||||
incremental_reasoning_tokens=call_stats.reasoning_tokens,
|
||||
incremental_cached_tokens=call_stats.cached_tokens,
|
||||
)
|
||||
if thought:
|
||||
await app.DATABASE.update_thought(
|
||||
thought_id=thought.observer_thought_id,
|
||||
organization_id=thought.organization_id,
|
||||
input_token_count=prompt_tokens if prompt_tokens > 0 else None,
|
||||
output_token_count=completion_tokens if completion_tokens > 0 else None,
|
||||
reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
|
||||
cached_token_count=cached_tokens if cached_tokens > 0 else None,
|
||||
thought_cost=llm_cost,
|
||||
input_token_count=call_stats.input_tokens,
|
||||
output_token_count=call_stats.output_tokens,
|
||||
reasoning_token_count=call_stats.reasoning_tokens,
|
||||
cached_token_count=call_stats.cached_tokens,
|
||||
thought_cost=call_stats.llm_cost,
|
||||
)
|
||||
# Track LLM API handler duration
|
||||
duration_seconds = time.perf_counter() - start_time
|
||||
@@ -757,6 +752,46 @@ class LLMCaller:
|
||||
)
|
||||
return response
|
||||
|
||||
async def get_call_stats(self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage) -> LLMCallStats:
|
||||
empty_call_stats = LLMCallStats()
|
||||
if isinstance(response, AnthropicMessage):
|
||||
usage = response.usage
|
||||
input_token_cost = (3.0 / 1000000) * usage.input_tokens
|
||||
output_token_cost = (15.0 / 1000000) * usage.output_tokens
|
||||
cached_token_cost = (0.3 / 1000000) * usage.cache_read_input_tokens
|
||||
llm_cost = input_token_cost + output_token_cost + cached_token_cost
|
||||
return LLMCallStats(
|
||||
llm_cost=llm_cost,
|
||||
input_tokens=usage.input_tokens,
|
||||
output_tokens=usage.output_tokens,
|
||||
cached_tokens=usage.cache_read_input_tokens,
|
||||
reasoning_tokens=0,
|
||||
)
|
||||
elif isinstance(response, (ModelResponse, CustomStreamWrapper)):
|
||||
try:
|
||||
llm_cost = litellm.completion_cost(completion_response=response)
|
||||
except Exception as e:
|
||||
LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
|
||||
llm_cost = 0
|
||||
input_tokens = response.get("usage", {}).get("prompt_tokens", 0)
|
||||
output_tokens = response.get("usage", {}).get("completion_tokens", 0)
|
||||
reasoning_tokens = 0
|
||||
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
|
||||
if completion_token_detail:
|
||||
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
|
||||
cached_tokens = 0
|
||||
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
|
||||
if cached_token_detail:
|
||||
cached_tokens = cached_token_detail.cached_tokens or 0
|
||||
return LLMCallStats(
|
||||
llm_cost=llm_cost,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cached_tokens=cached_tokens,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
)
|
||||
return empty_call_stats
|
||||
|
||||
|
||||
class LLMCallerManager:
|
||||
_llm_callers: dict[str, LLMCaller] = {}
|
||||
|
||||
Reference in New Issue
Block a user