diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py
index 430fbe4f6..5331451e3 100644
--- a/backend/packages/harness/deerflow/persistence/run/sql.py
+++ b/backend/packages/harness/deerflow/persistence/run/sql.py
@@ -223,10 +223,11 @@ class RunRepository(RunStore):
         """Aggregate token usage via a single SQL GROUP BY query."""
         _completed = RunRow.status.in_(("success", "error"))
         _thread = RunRow.thread_id == thread_id
+        model_name = func.coalesce(RunRow.model_name, "unknown")
 
         stmt = (
             select(
-                func.coalesce(RunRow.model_name, "unknown").label("model"),
+                model_name.label("model"),
                 func.count().label("runs"),
                 func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
                 func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
@@ -236,7 +237,7 @@ class RunRepository(RunStore):
                 func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
             )
             .where(_thread, _completed)
-            .group_by(func.coalesce(RunRow.model_name, "unknown"))
+            .group_by(model_name)
         )
 
         async with self._sf() as session:
diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py
index 6fd534829..5e230e790 100644
--- a/backend/tests/test_run_repository.py
+++ b/backend/tests/test_run_repository.py
@@ -3,7 +3,10 @@
 Uses a temp SQLite DB to test ORM-backed CRUD operations.
 """
 
+import re
+
 import pytest
+from sqlalchemy.dialects import postgresql
 
 from deerflow.persistence.run import RunRepository
 
@@ -278,3 +281,58 @@ class TestRunRepository:
         assert row4["model_name"] is None
 
         await _cleanup()
+
+    @pytest.mark.anyio
+    async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
+        """Check that SELECT and GROUP BY render the same ``model_name`` expression.
+
+        The statement is captured through a fake session (no database is
+        touched) and compiled for the PostgreSQL dialect; the bind-parameter
+        name inside ``coalesce(...)`` must be identical in both clauses.
+        """
+        captured = []
+
+        # Minimal stand-ins for result / session / context manager that only record the statement.
+        class FakeResult:
+            def all(self):
+                return []
+
+        class FakeSession:
+            async def execute(self, stmt):
+                captured.append(stmt)
+                return FakeResult()
+
+        class FakeSessionContext:
+            async def __aenter__(self):
+                return FakeSession()
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return None
+
+        repo = RunRepository(lambda: FakeSessionContext())
+
+        agg = await repo.aggregate_tokens_by_thread("t1")
+        # The fake result yields no rows, so every aggregate is expected to be zero/empty.
+        assert agg == {
+            "total_tokens": 0,
+            "total_input_tokens": 0,
+            "total_output_tokens": 0,
+            "total_runs": 0,
+            "by_model": {},
+            "by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
+        }
+        assert len(captured) == 1
+
+        stmt = captured[0]
+        # Compile for PostgreSQL so bind parameters are rendered as %(name)s placeholders.
+        compiled_sql = str(stmt.compile(dialect=postgresql.dialect()))
+        select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1)
+        model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"
+
+        select_match = re.search(model_expr_pattern + r" AS model", select_sql)
+        group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip())
+
+        assert select_match is not None
+        assert group_by_match is not None
+        # Group 1 is the bind-parameter name; it must match between the two clauses.
+        assert select_match.group(1) == group_by_match.group(1)