fix anthropic llm stats (#2313)
This commit is contained in:
@@ -1430,6 +1430,7 @@ class ForgeAgent:
|
|||||||
if not llm_caller.message_history:
|
if not llm_caller.message_history:
|
||||||
llm_response = await llm_caller.call(
|
llm_response = await llm_caller.call(
|
||||||
prompt=task.navigation_goal,
|
prompt=task.navigation_goal,
|
||||||
|
step=step,
|
||||||
screenshots=scraped_page.screenshots,
|
screenshots=scraped_page.screenshots,
|
||||||
use_message_history=True,
|
use_message_history=True,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@@ -1440,6 +1441,7 @@ class ForgeAgent:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
llm_response = await llm_caller.call(
|
llm_response = await llm_caller.call(
|
||||||
|
step=step,
|
||||||
screenshots=scraped_page.screenshots,
|
screenshots=scraped_page.screenshots,
|
||||||
use_message_history=True,
|
use_message_history=True,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from anthropic import NOT_GIVEN
|
|||||||
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
|
from anthropic.types.beta.beta_message import BetaMessage as AnthropicMessage
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
from litellm.utils import CustomStreamWrapper, ModelResponse
|
from litellm.utils import CustomStreamWrapper, ModelResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.exceptions import SkyvernContextWindowExceededError
|
from skyvern.exceptions import SkyvernContextWindowExceededError
|
||||||
@@ -33,6 +34,14 @@ from skyvern.utils.image_resizer import Resolution, get_resize_target_dimension,
|
|||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class LLMCallStats(BaseModel):
    """Token usage and cost accounting for a single LLM API call.

    Every field defaults to None so an empty instance can represent a
    call whose provider response exposed no usage information at all.
    """

    # Tokens in the prompt/input sent to the model.
    input_tokens: int | None = None
    # Tokens generated by the model in its reply.
    output_tokens: int | None = None
    # Tokens attributed to hidden/extended reasoning, when the provider reports them.
    reasoning_tokens: int | None = None
    # Prompt tokens served from the provider's prompt cache.
    cached_tokens: int | None = None
    # Estimated dollar cost of the call.
    llm_cost: float | None = None
|
||||||
|
|
||||||
|
|
||||||
class LLMAPIHandlerFactory:
|
class LLMAPIHandlerFactory:
|
||||||
_custom_handlers: dict[str, LLMAPIHandler] = {}
|
_custom_handlers: dict[str, LLMAPIHandler] = {}
|
||||||
|
|
||||||
@@ -624,41 +633,27 @@ class LLMCaller:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if step or thought:
|
if step or thought:
|
||||||
try:
|
call_stats = await self.get_call_stats(response)
|
||||||
llm_cost = litellm.completion_cost(completion_response=response)
|
|
||||||
except Exception as e:
|
|
||||||
LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
|
|
||||||
llm_cost = 0
|
|
||||||
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
|
|
||||||
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
|
|
||||||
reasoning_tokens = 0
|
|
||||||
completion_token_detail = response.get("usage", {}).get("completion_tokens_details")
|
|
||||||
if completion_token_detail:
|
|
||||||
reasoning_tokens = completion_token_detail.reasoning_tokens or 0
|
|
||||||
cached_tokens = 0
|
|
||||||
cached_token_detail = response.get("usage", {}).get("prompt_tokens_details")
|
|
||||||
if cached_token_detail:
|
|
||||||
cached_tokens = cached_token_detail.cached_tokens or 0
|
|
||||||
if step:
|
if step:
|
||||||
await app.DATABASE.update_step(
|
await app.DATABASE.update_step(
|
||||||
task_id=step.task_id,
|
task_id=step.task_id,
|
||||||
step_id=step.step_id,
|
step_id=step.step_id,
|
||||||
organization_id=step.organization_id,
|
organization_id=step.organization_id,
|
||||||
incremental_cost=llm_cost,
|
incremental_cost=call_stats.llm_cost,
|
||||||
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
|
incremental_input_tokens=call_stats.input_tokens,
|
||||||
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
|
incremental_output_tokens=call_stats.output_tokens,
|
||||||
incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None,
|
incremental_reasoning_tokens=call_stats.reasoning_tokens,
|
||||||
incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None,
|
incremental_cached_tokens=call_stats.cached_tokens,
|
||||||
)
|
)
|
||||||
if thought:
|
if thought:
|
||||||
await app.DATABASE.update_thought(
|
await app.DATABASE.update_thought(
|
||||||
thought_id=thought.observer_thought_id,
|
thought_id=thought.observer_thought_id,
|
||||||
organization_id=thought.organization_id,
|
organization_id=thought.organization_id,
|
||||||
input_token_count=prompt_tokens if prompt_tokens > 0 else None,
|
input_token_count=call_stats.input_tokens,
|
||||||
output_token_count=completion_tokens if completion_tokens > 0 else None,
|
output_token_count=call_stats.output_tokens,
|
||||||
reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
|
reasoning_token_count=call_stats.reasoning_tokens,
|
||||||
cached_token_count=cached_tokens if cached_tokens > 0 else None,
|
cached_token_count=call_stats.cached_tokens,
|
||||||
thought_cost=llm_cost,
|
thought_cost=call_stats.llm_cost,
|
||||||
)
|
)
|
||||||
# Track LLM API handler duration
|
# Track LLM API handler duration
|
||||||
duration_seconds = time.perf_counter() - start_time
|
duration_seconds = time.perf_counter() - start_time
|
||||||
@@ -757,6 +752,46 @@ class LLMCaller:
|
|||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
async def get_call_stats(self, response: ModelResponse | CustomStreamWrapper | AnthropicMessage) -> LLMCallStats:
    """Extract token usage and an estimated cost from a provider response.

    Supports Anthropic beta messages (cost computed from per-token prices
    below) and litellm responses (cost computed via litellm.completion_cost).
    Returns an empty LLMCallStats for any other response type.
    """
    if isinstance(response, AnthropicMessage):
        usage = response.usage
        # cache_read_input_tokens is Optional in the Anthropic SDK and is
        # None when prompt caching was not used — guard before arithmetic.
        cached_tokens = usage.cache_read_input_tokens or 0
        # TODO(review): prices are hard-coded ($3/M input, $15/M output,
        # $0.30/M cache reads) — confirm they match the model actually called.
        input_token_cost = (3.0 / 1_000_000) * usage.input_tokens
        output_token_cost = (15.0 / 1_000_000) * usage.output_tokens
        cached_token_cost = (0.3 / 1_000_000) * cached_tokens
        return LLMCallStats(
            llm_cost=input_token_cost + output_token_cost + cached_token_cost,
            input_tokens=usage.input_tokens,
            output_tokens=usage.output_tokens,
            cached_tokens=cached_tokens,
            # Anthropic usage does not report reasoning tokens here.
            reasoning_tokens=0,
        )
    if isinstance(response, (ModelResponse, CustomStreamWrapper)):
        try:
            llm_cost = litellm.completion_cost(completion_response=response)
        except Exception as e:
            # Best effort: cost lookup can fail for unknown/custom models.
            LOG.debug("Failed to calculate LLM cost", error=str(e), exc_info=True)
            llm_cost = 0
        # Hoist the usage dict instead of re-fetching it per field.
        usage = response.get("usage", {})
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)
        reasoning_tokens = 0
        completion_token_detail = usage.get("completion_tokens_details")
        if completion_token_detail:
            reasoning_tokens = completion_token_detail.reasoning_tokens or 0
        cached_tokens = 0
        cached_token_detail = usage.get("prompt_tokens_details")
        if cached_token_detail:
            cached_tokens = cached_token_detail.cached_tokens or 0
        return LLMCallStats(
            llm_cost=llm_cost,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cached_tokens=cached_tokens,
            reasoning_tokens=reasoning_tokens,
        )
    # Unknown response type: no usage information available.
    return LLMCallStats()
|
||||||
|
|
||||||
|
|
||||||
class LLMCallerManager:
|
class LLMCallerManager:
|
||||||
_llm_callers: dict[str, LLMCaller] = {}
|
_llm_callers: dict[str, LLMCaller] = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user