deer-flow/backend/tests/test_token_usage_middleware.py
YuJitang eab7ae3d62
feat: stream subagent token usage to header via terminal task events (#2882)
* feat: real-time subagent token usage display in the header and per turn

Backend:
- Persist subagent token usage to AIMessage.usage_metadata via
  TokenUsageMiddleware, so accumulateUsage() naturally includes
  subagent tokens without frontend state management
- Cache subagent usage by tool_call_id in task_tool and write it
  back to the dispatching AIMessage on the next model response
  (see the sketch after this list)
- Emit subagent token usage on all terminal task events
  (task_completed, task_failed, task_cancelled, task_timed_out)
- Report subagent usage to parent RunJournal for API totals
- Search backward from each ToolMessage to its dispatching
  AIMessage for correct multi-tool-call attribution
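
As a rough sketch of the cache handoff (pop_cached_subagent_usage
and the module-level _subagent_usage_cache dict are named in this
change; cache_subagent_usage is a hypothetical write helper shown
only for illustration):

    # deerflow/tools/builtins/task_tool.py -- sketch, not the
    # actual implementation.
    _subagent_usage_cache: dict[str, dict] = {}

    def cache_subagent_usage(tool_call_id: str, usage: dict) -> None:
        # Hypothetical helper: store a finished subagent's usage,
        # keyed by the tool call that dispatched it, until the
        # parent model's next response arrives.
        _subagent_usage_cache[tool_call_id] = usage

    def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
        # Consumed by TokenUsageMiddleware during write-back;
        # popping (not reading) prevents stale entries from leaking.
        return _subagent_usage_cache.pop(tool_call_id, None)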

Frontend:
- Remove subagentUsage state, custom event handling, and prop
  threading — subagent tokens are now embedded in message metadata
- Simplify selectHeaderTokenUsage (no subagentUsage parameter)
- Per-turn inline badges show turn-specific usage via message
  accumulation
- Remove isLoading guard from MessageTokenUsageList for dynamic
  updates during streaming

* fix: prevent header token double counting from baseline reset race

onFinish, onError, and the thread-switch useEffect all reset
pendingUsageBaselineMessageIdsRef to an empty Set. If
thread.isLoading is still true on the next render, every message
passes the getMessagesAfterBaseline filter and its tokens are
added to backendUsage (which already includes them), causing the
header to display up to 2× the actual token count.

Capture the current message IDs instead of using an empty Set so
that getMessagesAfterBaseline correctly returns no pending
messages even if thread.isLoading lags behind the stream end.

* fix: write back subagent tokens for all concurrent task tool calls

TokenUsageMiddleware only processed messages[-2], so when a
single model response dispatched multiple task tool calls, only
the last ToolMessage had its cached subagent usage written back
to the dispatching AIMessage.usage_metadata. Earlier tasks' usage
stayed in _subagent_usage_cache indefinitely (a leak) and never
appeared in the per-turn inline token display.

Walk backward through all consecutive ToolMessages before the
new AIMessage, and accumulate updates targeting the same
dispatch message into one state update so overlapping writes
don't clobber each other.
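
A minimal sketch of that walk, assuming after_model sees the full
message list (the merge loop is illustrative, not the actual
implementation):

    # Skip the new AIMessage at messages[-1], then walk backward
    # over the consecutive ToolMessages that precede it.
    idx = len(messages) - 2
    merged: dict[str, int] = {}
    while idx >= 0 and isinstance(messages[idx], ToolMessage):
        usage = pop_cached_subagent_usage(messages[idx].tool_call_id)
        if usage:
            for key in ("input_tokens", "output_tokens", "total_tokens"):
                merged[key] = merged.get(key, 0) + usage.get(key, 0)
        idx -= 1
    # messages[idx] is now the dispatching AIMessage: emit one
    # state update that adds `merged` to its usage_metadata, so
    # concurrent task results cannot clobber each other's writes.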

* fix: clean up subagent usage cache entry on task cancellation

When a task_tool invocation is cancelled via CancelledError, its
cached subagent usage entry would leak: the TokenUsageMiddleware
write-back path never fires after cancellation. Pop the cache
entry before re-raising to prevent unbounded growth of the
module-level _subagent_usage_cache dict.
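
A rough sketch of the cleanup path (_run_subagent stands in for
whatever actually awaits the subagent and is hypothetical):

    import asyncio

    async def _run_task(tool_call_id: str) -> str:
        try:
            return await _run_subagent(tool_call_id)  # hypothetical
        except asyncio.CancelledError:
            # The middleware write-back never fires after a
            # cancellation, so drop the cache entry before
            # re-raising to keep _subagent_usage_cache bounded.
            _subagent_usage_cache.pop(tool_call_id, None)
            raise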

* fix: address token usage review feedback

* fix: handle missing config for subagent usage cache

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-13 23:52:19 +08:00

"""Tests for TokenUsageMiddleware attribution annotations."""
import importlib
import logging
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, ToolMessage
from deerflow.agents.middlewares.token_usage_middleware import (
TOKEN_USAGE_ATTRIBUTION_KEY,
TokenUsageMiddleware,
)
def _make_runtime():
runtime = MagicMock()
runtime.context = {"thread_id": "test-thread"}
return runtime


class TestTokenUsageMiddleware:
    def test_logs_cache_token_details(self, caplog):
        """Detail breakdowns in usage_metadata are included in the log line."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="Here is the final answer.",
            usage_metadata={
                "input_tokens": 350,
                "output_tokens": 240,
                "total_tokens": 590,
                "input_token_details": {
                    "audio": 10,
                    "cache_creation": 200,
                    "cache_read": 100,
                },
                "output_token_details": {
                    "audio": 10,
                    "reasoning": 200,
                },
            },
        )
        with caplog.at_level(
            logging.INFO,
            logger="deerflow.agents.middlewares.token_usage_middleware",
        ):
            result = middleware.after_model({"messages": [message]}, _make_runtime())
        assert result is not None
        assert "LLM token usage: input=350 output=240 total=590" in caplog.text
        assert "input_token_details={'audio': 10, 'cache_creation': 200, 'cache_read': 100}" in caplog.text
        assert "output_token_details={'audio': 10, 'reasoning': 200}" in caplog.text

    def test_logs_basic_tokens_when_no_detail_fields_in_usage_metadata(self, caplog):
        """When usage_metadata has only totals (no input_token_details), log just the counts."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="Here is the final answer.",
            usage_metadata={
                "input_tokens": 350,
                "output_tokens": 240,
                "total_tokens": 590,
            },
        )
        with caplog.at_level(
            logging.INFO,
            logger="deerflow.agents.middlewares.token_usage_middleware",
        ):
            result = middleware.after_model({"messages": [message]}, _make_runtime())
        assert result is not None
        assert "LLM token usage: input=350 output=240 total=590" in caplog.text
        assert "input_token_details" not in caplog.text

    def test_no_log_when_usage_metadata_is_missing(self, caplog):
        """When usage_metadata is absent, no token usage line is logged."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="Here is the final answer.",
            response_metadata={
                "usage": {
                    "input_tokens": 350,
                    "output_tokens": 240,
                    "total_tokens": 590,
                }
            },
        )
        with caplog.at_level(
            logging.INFO,
            logger="deerflow.agents.middlewares.token_usage_middleware",
        ):
            result = middleware.after_model({"messages": [message]}, _make_runtime())
        assert result is not None
        assert "LLM token usage" not in caplog.text

    def test_annotates_todo_updates_with_structured_actions(self):
        """write_todos status changes are annotated as structured todo actions."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="",
            tool_calls=[
                {
                    "id": "write_todos:1",
                    "name": "write_todos",
                    "args": {
                        "todos": [
                            {"content": "Inspect streaming path", "status": "completed"},
                            {"content": "Design token attribution schema", "status": "in_progress"},
                        ]
                    },
                }
            ],
            usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120},
        )
        state = {
            "messages": [message],
            "todos": [
                {"content": "Inspect streaming path", "status": "in_progress"},
                {"content": "Design token attribution schema", "status": "pending"},
            ],
        }
        result = middleware.after_model(state, _make_runtime())
        assert result is not None
        updated_message = result["messages"][0]
        attribution = updated_message.additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
        assert attribution["kind"] == "tool_batch"
        assert attribution["shared_attribution"] is True
        assert attribution["tool_call_ids"] == ["write_todos:1"]
        assert attribution["actions"] == [
            {
                "kind": "todo_complete",
                "content": "Inspect streaming path",
                "tool_call_id": "write_todos:1",
            },
            {
                "kind": "todo_start",
                "content": "Design token attribution schema",
                "tool_call_id": "write_todos:1",
            },
        ]

    def test_annotates_subagent_and_search_steps(self):
        """task and web_search tool calls become subagent and search actions."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="",
            tool_calls=[
                {
                    "id": "task:1",
                    "name": "task",
                    "args": {
                        "description": "spec-coder patch message grouping",
                        "subagent_type": "general-purpose",
                    },
                },
                {
                    "id": "web_search:1",
                    "name": "web_search",
                    "args": {"query": "LangGraph useStream messages tuple"},
                },
            ],
        )
        result = middleware.after_model({"messages": [message]}, _make_runtime())
        assert result is not None
        attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
        assert attribution["kind"] == "tool_batch"
        assert attribution["shared_attribution"] is True
        assert attribution["actions"] == [
            {
                "kind": "subagent",
                "description": "spec-coder patch message grouping",
                "subagent_type": "general-purpose",
                "tool_call_id": "task:1",
            },
            {
                "kind": "search",
                "tool_name": "web_search",
                "query": "LangGraph useStream messages tuple",
                "tool_call_id": "web_search:1",
            },
        ]

    def test_marks_final_answer_when_no_tools(self):
        """An AIMessage without tool calls is attributed as the final answer."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(content="Here is the final answer.")
        result = middleware.after_model({"messages": [message]}, _make_runtime())
        assert result is not None
        attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
        assert attribution["kind"] == "final_answer"
        assert attribution["shared_attribution"] is False
        assert attribution["actions"] == []

    def test_annotates_removed_todos(self):
        """Todos dropped from the write_todos payload are annotated as removals."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="",
            tool_calls=[
                {
                    "id": "write_todos:remove",
                    "name": "write_todos",
                    "args": {
                        "todos": [],
                    },
                }
            ],
        )
        result = middleware.after_model(
            {
                "messages": [message],
                "todos": [
                    {"content": "Archive obsolete plan", "status": "pending"},
                ],
            },
            _make_runtime(),
        )
        assert result is not None
        attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
        assert attribution["kind"] == "todo_update"
        assert attribution["shared_attribution"] is False
        assert attribution["actions"] == [
            {
                "kind": "todo_remove",
                "content": "Archive obsolete plan",
                "tool_call_id": "write_todos:remove",
            }
        ]

    def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
        """Concurrent tasks' cached usage merges into one update on the dispatching AIMessage."""
        middleware = TokenUsageMiddleware()
        first_dispatch = AIMessage(
            content="",
            tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
        )
        second_dispatch = AIMessage(
            content="",
            tool_calls=[
                {"id": "task:second-a", "name": "task", "args": {}},
                {"id": "task:second-b", "name": "task", "args": {}},
            ],
        )
        messages = [
            first_dispatch,
            ToolMessage(content="first", tool_call_id="task:first"),
            second_dispatch,
            ToolMessage(content="second-a", tool_call_id="task:second-a"),
            ToolMessage(content="second-b", tool_call_id="task:second-b"),
            AIMessage(content="done"),
        ]
        cached_usage = {
            "task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
            "task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
        }
        task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
        monkeypatch.setattr(
            task_tool_module,
            "pop_cached_subagent_usage",
            lambda tool_call_id: cached_usage.pop(tool_call_id, None),
        )
        result = middleware.after_model({"messages": messages}, _make_runtime())
        assert result is not None
        usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)]
        assert len(usage_updates) == 1
        updated = usage_updates[0]
        assert updated.tool_calls == second_dispatch.tool_calls
        assert updated.usage_metadata == {
            "input_tokens": 30,
            "output_tokens": 12,
            "total_tokens": 42,
        }