fix(persistence): reuse token usage model grouping expression (#2910)

This commit is contained in:
Eilen Shin 2026-05-13 15:49:34 +08:00 committed by GitHub
parent e9deb6c2f2
commit 2a1ac06bf4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 2 deletions

View File

@ -223,10 +223,11 @@ class RunRepository(RunStore):
"""Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error"))
_thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown")
stmt = (
select(
func.coalesce(RunRow.model_name, "unknown").label("model"),
model_name.label("model"),
func.count().label("runs"),
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
@ -236,7 +237,7 @@ class RunRepository(RunStore):
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
)
.where(_thread, _completed)
.group_by(func.coalesce(RunRow.model_name, "unknown"))
.group_by(model_name)
)
async with self._sf() as session:

View File

@ -3,7 +3,10 @@
Uses a temp SQLite DB to test ORM-backed CRUD operations.
"""
import re
import pytest
from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository
@ -278,3 +281,48 @@ class TestRunRepository:
assert row4["model_name"] is None
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
captured = []
class FakeResult:
def all(self):
return []
class FakeSession:
async def execute(self, stmt):
captured.append(stmt)
return FakeResult()
class FakeSessionContext:
async def __aenter__(self):
return FakeSession()
async def __aexit__(self, exc_type, exc, tb):
return None
repo = RunRepository(lambda: FakeSessionContext())
agg = await repo.aggregate_tokens_by_thread("t1")
assert agg == {
"total_tokens": 0,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_runs": 0,
"by_model": {},
"by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
}
assert len(captured) == 1
stmt = captured[0]
compiled_sql = str(stmt.compile(dialect=postgresql.dialect()))
select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1)
model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"
select_match = re.search(model_expr_pattern + r" AS model", select_sql)
group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip())
assert select_match is not None
assert group_by_match is not None
assert select_match.group(1) == group_by_match.group(1)