add observer thought token count and llm cost (#1595)

This commit is contained in:
Shuchang Zheng
2025-01-22 07:45:40 +08:00
committed by GitHub
parent 98a7d1ced5
commit 4578b6fe86
5 changed files with 93 additions and 18 deletions

View File

@@ -0,0 +1,35 @@
"""Add thought cost, input token count and output token count
Revision ID: 13e4af5c975c
Revises: 9adef4708ca8
Create Date: 2025-01-21 23:37:35.122761+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "13e4af5c975c"
down_revision: Union[str, None] = "9adef4708ca8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add per-thought LLM accounting columns to ``observer_thoughts``.

    Adds nullable ``input_token_count`` / ``output_token_count`` integer
    columns and a nullable ``thought_cost`` numeric column.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Token counts are plain integers; cost uses Numeric to avoid float rounding.
    added_columns = (
        sa.Column("input_token_count", sa.Integer(), nullable=True),
        sa.Column("output_token_count", sa.Integer(), nullable=True),
        sa.Column("thought_cost", sa.Numeric(), nullable=True),
    )
    for column in added_columns:
        op.add_column("observer_thoughts", column)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Remove the LLM accounting columns added by this revision.

    Drops ``thought_cost``, ``output_token_count`` and
    ``input_token_count`` from ``observer_thoughts``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop in reverse order of addition, mirroring upgrade().
    for column_name in ("thought_cost", "output_token_count", "input_token_count"):
        op.drop_column("observer_thoughts", column_name)
    # ### end Alembic commands ###

View File

@@ -146,7 +146,7 @@ class LLMAPIHandlerFactory:
observer_thought=observer_thought,
ai_suggestion=ai_suggestion,
)
if step:
if step or observer_thought:
try:
llm_cost = litellm.completion_cost(completion_response=response)
except Exception as e:
@@ -154,14 +154,24 @@ class LLMAPIHandlerFactory:
llm_cost = 0
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
await app.DATABASE.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
incremental_cost=llm_cost,
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
)
if step:
await app.DATABASE.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
incremental_cost=llm_cost,
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
)
if observer_thought:
await app.DATABASE.update_observer_thought(
observer_thought_id=observer_thought.observer_thought_id,
organization_id=observer_thought.organization_id,
input_token_count=prompt_tokens if prompt_tokens > 0 else None,
output_token_count=completion_tokens if completion_tokens > 0 else None,
thought_cost=llm_cost,
)
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
@@ -292,7 +302,7 @@ class LLMAPIHandlerFactory:
ai_suggestion=ai_suggestion,
)
if step:
if step or observer_thought:
try:
llm_cost = litellm.completion_cost(completion_response=response)
except Exception as e:
@@ -300,14 +310,23 @@ class LLMAPIHandlerFactory:
llm_cost = 0
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
await app.DATABASE.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
incremental_cost=llm_cost,
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
)
if step:
await app.DATABASE.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
incremental_cost=llm_cost,
incremental_input_tokens=prompt_tokens if prompt_tokens > 0 else None,
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
)
if observer_thought:
await app.DATABASE.update_observer_thought(
observer_thought_id=observer_thought.observer_thought_id,
organization_id=observer_thought.organization_id,
input_token_count=prompt_tokens if prompt_tokens > 0 else None,
output_token_count=completion_tokens if completion_tokens > 0 else None,
thought_cost=llm_cost,
)
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
await app.ARTIFACT_MANAGER.create_llm_artifact(
data=json.dumps(parsed_response, indent=2).encode("utf-8"),

View File

@@ -2039,6 +2039,9 @@ class AgentDB:
observer_thought_scenario: str | None = None,
observer_thought_type: str = ObserverThoughtType.plan,
output: dict[str, Any] | None = None,
input_token_count: int | None = None,
output_token_count: int | None = None,
thought_cost: float | None = None,
organization_id: str | None = None,
) -> ObserverThought:
async with self.Session() as session:
@@ -2055,6 +2058,9 @@ class AgentDB:
observer_thought_scenario=observer_thought_scenario,
observer_thought_type=observer_thought_type,
output=output,
input_token_count=input_token_count,
output_token_count=output_token_count,
thought_cost=thought_cost,
organization_id=organization_id,
)
session.add(new_observer_thought)
@@ -2073,6 +2079,9 @@ class AgentDB:
thought: str | None = None,
answer: str | None = None,
output: dict[str, Any] | None = None,
input_token_count: int | None = None,
output_token_count: int | None = None,
thought_cost: float | None = None,
organization_id: str | None = None,
) -> ObserverThought:
async with self.Session() as session:
@@ -2100,6 +2109,12 @@ class AgentDB:
observer_thought.answer = answer
if output:
observer_thought.output = output
if input_token_count:
observer_thought.input_token_count = input_token_count
if output_token_count:
observer_thought.output_token_count = output_token_count
if thought_cost:
observer_thought.thought_cost = thought_cost
await session.commit()
await session.refresh(observer_thought)
return ObserverThought.model_validate(observer_thought)

View File

@@ -574,6 +574,9 @@ class ObserverThoughtModel(Base):
observation = Column(String, nullable=True)
thought = Column(String, nullable=True)
answer = Column(String, nullable=True)
input_token_count = Column(Integer, nullable=True)
output_token_count = Column(Integer, nullable=True)
thought_cost = Column(Numeric, nullable=True)
observer_thought_type = Column(String, nullable=True, default=ObserverThoughtType.plan)
observer_thought_scenario = Column(String, nullable=True)

View File

@@ -85,6 +85,9 @@ class ObserverThought(BaseModel):
observer_thought_type: ObserverThoughtType | None = Field(alias="thought_type", default=ObserverThoughtType.plan)
observer_thought_scenario: ObserverThoughtScenario | None = Field(alias="thought_scenario", default=None)
output: dict[str, Any] | None = None
input_token_count: int | None = None
output_token_count: int | None = None
thought_cost: float | None = None
created_at: datetime
modified_at: datetime