deer-flow/backend/tests/test_subagent_token_collector.py
YuJitang 9892a7d468
fix: bucket subagent token usage into parent run totals (#2838)
* fix: bucket subagent token usage into RunRow.subagent_tokens

Add caller-bucketed token tracking to RunJournal so subagent and
middleware LLM calls are written to the correct RunRow columns instead
of all falling into lead_agent_tokens (default 0).

- RunJournal: accumulate _lead_agent_tokens / _subagent_tokens /
  _middleware_tokens in on_llm_end, deduped by langchain run_id.
  Add record_external_llm_usage_records() for external sources
  (respects track_token_usage flag). Return caller buckets from
  get_completion_data().
- SubagentTokenCollector: new lightweight callback handler that
  collects LLM usage within subagent execution.
- SubagentExecutor: wire collector into subagent run_config and sync
  records to SubagentResult on every chunk (timeout/cancel safe).
- SubagentResult: add token_usage_records and usage_reported fields.
- task_tool: report subagent usage to parent RunJournal on every
  terminal status (COMPLETED/FAILED/CANCELLED/TIMED_OUT), including
  the CancelledError path, guarded against double-reporting.

No DB migration needed — RunRow columns already exist.

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix: address token usage review feedback

* Address review follow-ups

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-10 22:47:30 +08:00

162 lines
6.5 KiB
Python

"""Tests for SubagentTokenCollector callback handler."""
from unittest.mock import MagicMock
from uuid import uuid4
from deerflow.subagents.token_collector import SubagentTokenCollector
def _make_llm_response(content="Hello", usage=None):
"""Create a mock LLM response with a message."""
msg = MagicMock()
msg.content = content
msg.usage_metadata = usage
gen = MagicMock()
gen.message = msg
response = MagicMock()
response.generations = [[gen]]
return response
def _make_llm_response_from_usages(usages):
"""Create a mock LLM response with one generation per usage entry."""
generations = []
for usage in usages:
msg = MagicMock()
msg.content = "chunk"
msg.usage_metadata = usage
gen = MagicMock()
gen.message = msg
generations.append([gen])
response = MagicMock()
response.generations = generations
return response
class TestSubagentTokenCollector:
    """Unit tests for ``SubagentTokenCollector.on_llm_end`` and ``snapshot_records``."""

    # Caller label every collector in these tests is created with.
    CALLER = "subagent:test"

    def _new_collector(self):
        """Create a collector labelled with the shared test caller."""
        return SubagentTokenCollector(caller=self.CALLER)

    def _records_after_single_call(self, usage, content="Hi"):
        """Feed one single-generation response (fresh run id) to a new collector; return its snapshot."""
        collector = self._new_collector()
        collector.on_llm_end(_make_llm_response(content, usage=usage), run_id=uuid4())
        return collector.snapshot_records()

    def test_collects_usage_from_response(self):
        """Reported token counts land in one record tagged with the caller and a source run id."""
        records = self._records_after_single_call(
            {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}
        )
        assert len(records) == 1
        record = records[0]
        assert record["caller"] == self.CALLER
        for key, expected in (
            ("input_tokens", 100),
            ("output_tokens", 50),
            ("total_tokens", 150),
        ):
            assert record[key] == expected
        assert "source_run_id" in record

    def test_total_tokens_zero_uses_input_plus_output(self):
        """A zero ``total_tokens`` is replaced by input + output."""
        records = self._records_after_single_call(
            {"input_tokens": 200, "output_tokens": 100, "total_tokens": 0}
        )
        assert len(records) == 1
        assert records[0]["total_tokens"] == 300

    def test_total_tokens_missing_uses_input_plus_output(self):
        """A missing ``total_tokens`` key falls back to input + output."""
        records = self._records_after_single_call({"input_tokens": 30, "output_tokens": 20})
        assert len(records) == 1
        assert records[0]["total_tokens"] == 50

    def test_dedup_same_run_id(self):
        """Two callbacks sharing one run id produce a single record."""
        collector = self._new_collector()
        shared_run_id = uuid4()
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        for _ in range(2):
            collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=shared_run_id)
        assert len(collector.snapshot_records()) == 1

    def test_no_usage_no_record(self):
        """A message whose usage metadata is ``None`` is not recorded."""
        assert len(self._records_after_single_call(None)) == 0

    def test_zero_usage_no_record(self):
        """All-zero token counts are not recorded."""
        records = self._records_after_single_call(
            {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
        )
        assert len(records) == 0

    def test_skips_empty_generation_and_records_later_usage(self):
        """A usage-less generation does not prevent a later generation's usage from being recorded."""
        collector = self._new_collector()
        response = _make_llm_response_from_usages(
            [None, {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}]
        )
        collector.on_llm_end(response, run_id=uuid4())
        records = collector.snapshot_records()
        assert len(records) == 1
        assert records[0]["total_tokens"] == 30

    def test_snapshot_returns_copy(self):
        """Snapshots are equal but distinct lists; appending to one leaves the collector unchanged."""
        collector = self._new_collector()
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
        first = collector.snapshot_records()
        second = collector.snapshot_records()
        assert first == second
        assert first is not second
        # Only the outer list is mutated here, so this checks list-level isolation.
        first.append({"source_run_id": "fake"})
        assert len(collector.snapshot_records()) == 1

    def test_multiple_calls_accumulate(self):
        """Each callback with its own run id contributes its own record."""
        collector = self._new_collector()
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        for content in ("A", "B"):
            collector.on_llm_end(_make_llm_response(content, usage=usage), run_id=uuid4())
        assert len(collector.snapshot_records()) == 2

    def test_different_run_ids_accumulate_separately(self):
        """Records keep per-call totals and appear in callback order."""
        collector = self._new_collector()
        calls = (
            ("A", {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
            ("B", {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
        )
        for content, usage in calls:
            collector.on_llm_end(_make_llm_response(content, usage=usage), run_id=uuid4())
        records = collector.snapshot_records()
        assert len(records) == 2
        assert [record["total_tokens"] for record in records] == [15, 30]

    def test_message_without_usage_metadata_skipped(self):
        """A response where message has no usage_metadata attribute must be skipped."""
        collector = self._new_collector()
        bare_message = MagicMock(spec=[])  # spec=[] -> no usage_metadata attribute at all
        generation = MagicMock()
        generation.message = bare_message
        response = MagicMock()
        response.generations = [[generation]]
        collector.on_llm_end(response, run_id=uuid4())
        assert len(collector.snapshot_records()) == 0

    def test_generation_without_message_skipped(self):
        """A generation without a message attribute must be skipped."""
        collector = self._new_collector()
        bare_generation = MagicMock(spec=[])  # spec=[] -> no message attribute at all
        response = MagicMock()
        response.generations = [[bare_generation]]
        collector.on_llm_end(response, run_id=uuid4())
        assert len(collector.snapshot_records()) == 0