mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-14 20:53:41 +00:00
fix(persistence): reuse token usage model grouping expression (#2910)
This commit is contained in:
parent
e9deb6c2f2
commit
2a1ac06bf4
@ -223,10 +223,11 @@ class RunRepository(RunStore):
|
||||
"""Aggregate token usage via a single SQL GROUP BY query."""
|
||||
_completed = RunRow.status.in_(("success", "error"))
|
||||
_thread = RunRow.thread_id == thread_id
|
||||
model_name = func.coalesce(RunRow.model_name, "unknown")
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
func.coalesce(RunRow.model_name, "unknown").label("model"),
|
||||
model_name.label("model"),
|
||||
func.count().label("runs"),
|
||||
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
|
||||
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
|
||||
@ -236,7 +237,7 @@ class RunRepository(RunStore):
|
||||
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
|
||||
)
|
||||
.where(_thread, _completed)
|
||||
.group_by(func.coalesce(RunRow.model_name, "unknown"))
|
||||
.group_by(model_name)
|
||||
)
|
||||
|
||||
async with self._sf() as session:
|
||||
|
||||
@ -3,7 +3,10 @@
|
||||
Uses a temp SQLite DB to test ORM-backed CRUD operations.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from deerflow.persistence.run import RunRepository
|
||||
|
||||
@ -278,3 +281,48 @@ class TestRunRepository:
|
||||
assert row4["model_name"] is None
|
||||
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
|
||||
captured = []
|
||||
|
||||
class FakeResult:
|
||||
def all(self):
|
||||
return []
|
||||
|
||||
class FakeSession:
|
||||
async def execute(self, stmt):
|
||||
captured.append(stmt)
|
||||
return FakeResult()
|
||||
|
||||
class FakeSessionContext:
|
||||
async def __aenter__(self):
|
||||
return FakeSession()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
repo = RunRepository(lambda: FakeSessionContext())
|
||||
|
||||
agg = await repo.aggregate_tokens_by_thread("t1")
|
||||
assert agg == {
|
||||
"total_tokens": 0,
|
||||
"total_input_tokens": 0,
|
||||
"total_output_tokens": 0,
|
||||
"total_runs": 0,
|
||||
"by_model": {},
|
||||
"by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
|
||||
}
|
||||
assert len(captured) == 1
|
||||
|
||||
stmt = captured[0]
|
||||
compiled_sql = str(stmt.compile(dialect=postgresql.dialect()))
|
||||
select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1)
|
||||
model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"
|
||||
|
||||
select_match = re.search(model_expr_pattern + r" AS model", select_sql)
|
||||
group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip())
|
||||
|
||||
assert select_match is not None
|
||||
assert group_by_match is not None
|
||||
assert select_match.group(1) == group_by_match.group(1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user