track reasoning token and cached token (#1985)

This commit is contained in:
Shuchang Zheng
2025-03-20 16:42:57 -07:00
committed by GitHub
parent 185464f8ec
commit eb3eb4eede
9 changed files with 112 additions and 16 deletions

View File

@@ -163,12 +163,11 @@ class LLMAPIHandlerFactory:
LOG.exception("Failed to calculate LLM cost", error=str(e))
llm_cost = 0
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
# TODO (suchintan): Properly support reasoning tokens
reasoning_tokens = response.get("usage", {}).get("reasoning_tokens", 0)
LOG.debug("Reasoning tokens", reasoning_tokens=reasoning_tokens)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0) + reasoning_tokens
reasoning_tokens = (
response.get("usage", {}).get("completion_tokens_details", {}).get("reasoning_tokens", 0)
)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
cached_tokens = response.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0)
if step:
await app.DATABASE.update_step(
@@ -178,6 +177,8 @@ class LLMAPIHandlerFactory:
incremental_cost=llm_cost,
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None,
incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None,
)
if thought:
await app.DATABASE.update_thought(
@@ -186,6 +187,8 @@ class LLMAPIHandlerFactory:
input_token_count=prompt_tokens if prompt_tokens > 0 else None,
output_token_count=completion_tokens if completion_tokens > 0 else None,
thought_cost=llm_cost,
reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
cached_token_count=cached_tokens if cached_tokens > 0 else None,
)
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
await app.ARTIFACT_MANAGER.create_llm_artifact(
@@ -348,6 +351,10 @@ class LLMAPIHandlerFactory:
llm_cost = 0
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
reasoning_tokens = (
response.get("usage", {}).get("completion_tokens_details", {}).get("reasoning_tokens", 0)
)
cached_tokens = response.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0)
if step:
await app.DATABASE.update_step(
task_id=step.task_id,
@@ -356,6 +363,8 @@ class LLMAPIHandlerFactory:
incremental_cost=llm_cost,
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
incremental_reasoning_tokens=reasoning_tokens if reasoning_tokens > 0 else None,
incremental_cached_tokens=cached_tokens if cached_tokens > 0 else None,
)
if thought:
await app.DATABASE.update_thought(
@@ -363,6 +372,8 @@ class LLMAPIHandlerFactory:
organization_id=thought.organization_id,
input_token_count=prompt_tokens if prompt_tokens > 0 else None,
output_token_count=completion_tokens if completion_tokens > 0 else None,
reasoning_token_count=reasoning_tokens if reasoning_tokens > 0 else None,
cached_token_count=cached_tokens if cached_tokens > 0 else None,
thought_cost=llm_cost,
)
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)

View File

@@ -492,6 +492,8 @@ class AgentDB:
incremental_cost: float | None = None,
incremental_input_tokens: int | None = None,
incremental_output_tokens: int | None = None,
incremental_reasoning_tokens: int | None = None,
incremental_cached_tokens: int | None = None,
) -> Step:
try:
async with self.Session() as session:
@@ -517,6 +519,10 @@ class AgentDB:
step.input_token_count = incremental_input_tokens + (step.input_token_count or 0)
if incremental_output_tokens is not None:
step.output_token_count = incremental_output_tokens + (step.output_token_count or 0)
if incremental_reasoning_tokens is not None:
step.reasoning_token_count = incremental_reasoning_tokens + (step.reasoning_token_count or 0)
if incremental_cached_tokens is not None:
step.cached_token_count = incremental_cached_tokens + (step.cached_token_count or 0)
await session.commit()
updated_step = await self.get_step(task_id, step_id, organization_id)
@@ -2290,6 +2296,8 @@ class AgentDB:
output: dict[str, Any] | None = None,
input_token_count: int | None = None,
output_token_count: int | None = None,
reasoning_token_count: int | None = None,
cached_token_count: int | None = None,
thought_cost: float | None = None,
organization_id: str | None = None,
) -> Thought:
@@ -2309,6 +2317,8 @@ class AgentDB:
output=output,
input_token_count=input_token_count,
output_token_count=output_token_count,
reasoning_token_count=reasoning_token_count,
cached_token_count=cached_token_count,
thought_cost=thought_cost,
organization_id=organization_id,
)
@@ -2330,6 +2340,8 @@ class AgentDB:
output: dict[str, Any] | None = None,
input_token_count: int | None = None,
output_token_count: int | None = None,
reasoning_token_count: int | None = None,
cached_token_count: int | None = None,
thought_cost: float | None = None,
organization_id: str | None = None,
) -> Thought:
@@ -2362,6 +2374,10 @@ class AgentDB:
thought_obj.input_token_count = input_token_count
if output_token_count:
thought_obj.output_token_count = output_token_count
if reasoning_token_count:
thought_obj.reasoning_token_count = reasoning_token_count
if cached_token_count:
thought_obj.cached_token_count = cached_token_count
if thought_cost:
thought_obj.thought_cost = thought_cost
await session.commit()

View File

@@ -117,6 +117,8 @@ class StepModel(Base):
)
input_token_count = Column(Integer, default=0)
output_token_count = Column(Integer, default=0)
reasoning_token_count = Column(Integer, default=0)
cached_token_count = Column(Integer, default=0)
step_cost = Column(Numeric, default=0)
@@ -612,6 +614,8 @@ class ThoughtModel(Base):
answer = Column(String, nullable=True)
input_token_count = Column(Integer, nullable=True)
output_token_count = Column(Integer, nullable=True)
reasoning_token_count = Column(Integer, nullable=True)
cached_token_count = Column(Integer, nullable=True)
thought_cost = Column(Numeric, nullable=True)
observer_thought_type = Column(String, nullable=True, default=ThoughtType.plan)

View File

@@ -108,6 +108,8 @@ def convert_to_step(step_model: StepModel, debug_enabled: bool = False) -> Step:
organization_id=step_model.organization_id,
input_token_count=step_model.input_token_count,
output_token_count=step_model.output_token_count,
reasoning_token_count=step_model.reasoning_token_count,
cached_token_count=step_model.cached_token_count,
step_cost=step_model.step_cost,
)

View File

@@ -52,6 +52,8 @@ class Step(BaseModel):
organization_id: str | None = None
input_token_count: int = 0
output_token_count: int = 0
reasoning_token_count: int = 0
cached_token_count: int = 0
step_cost: float = 0
def validate_update(

View File

@@ -92,6 +92,8 @@ class Thought(BaseModel):
output: dict[str, Any] | None = None
input_token_count: int | None = None
output_token_count: int | None = None
reasoning_token_count: int | None = None
cached_token_count: int | None = None
thought_cost: float | None = None
created_at: datetime