Mirror of https://github.com/bytedance/deer-flow.git, synced 2026-05-14 12:43:45 +00:00.
* feat: real-time subagent token usage display in header and per-turn Backend: - Persist subagent token usage to AIMessage.usage_metadata via TokenUsageMiddleware, so accumulateUsage() naturally includes subagent tokens without frontend state management - Cache subagent usage by tool_call_id in task_tool, write back to the dispatching AIMessage on next model response - Emit subagent token usage on all terminal task events (task_completed, task_failed, task_cancelled, task_timed_out) - Report subagent usage to parent RunJournal for API totals - Search backward from ToolMessage to find dispatching AIMessage for correct multi-tool-call attribution Frontend: - Remove subagentUsage state, custom event handling, and prop threading — subagent tokens are now embedded in message metadata - Simplify selectHeaderTokenUsage (no subagentUsage parameter) - Per-turn inline badges show turn-specific usage via message accumulation - Remove isLoading guard from MessageTokenUsageList for dynamic updates during streaming * fix: prevent header token double counting from baseline reset race onFinish, onError, and thread-switch useEffect all reset pendingUsageBaselineMessageIdsRef to an empty Set. If thread.isLoading is still true on the next render, all messages pass the getMessagesAfterBaseline filter and their tokens are added to backendUsage (which already includes them), causing the header to display up to 2× the actual token count. Capture current message IDs instead of using an empty Set so that getMessagesAfterBaseline correctly returns no pending messages even if thread.isLoading lags behind the stream end. * fix: write back subagent tokens for all concurrent task tool calls TokenUsageMiddleware only processed messages[-2], so when a single model response dispatched multiple task tool calls only the last ToolMessage had its cached subagent usage written back to the dispatch AIMessage.usage_metadata. 
Earlier tasks' usage stayed in _subagent_usage_cache indefinitely (leak) and never appeared in the per-turn inline token display. Walk backward through all consecutive ToolMessages before the new AIMessage, and accumulate updates targeting the same dispatch message into one state update so overlapping writes don't clobber each other. * fix: clean up subagent usage cache entry on task cancellation When a task_tool invocation is cancelled via CancelledError, any cached subagent usage entry leaked because the TokenUsageMiddleware writeback path never fires after cancellation. Pop the cache entry before re-raising to prevent unbounded growth of the module-level _subagent_usage_cache dict. * fix: address token usage review feedback * fix: handle missing config for subagent usage cache --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
282 lines · 10 KiB · Python
"""Tests for TokenUsageMiddleware attribution annotations."""
|
|
|
|
import importlib
|
|
import logging
|
|
from unittest.mock import MagicMock
|
|
|
|
from langchain_core.messages import AIMessage, ToolMessage
|
|
|
|
from deerflow.agents.middlewares.token_usage_middleware import (
|
|
TOKEN_USAGE_ATTRIBUTION_KEY,
|
|
TokenUsageMiddleware,
|
|
)
|
|
|
|
|
|
def _make_runtime():
|
|
runtime = MagicMock()
|
|
runtime.context = {"thread_id": "test-thread"}
|
|
return runtime
|
|
|
|
|
|
class TestTokenUsageMiddleware:
    """Tests for TokenUsageMiddleware's logging and attribution annotations."""

    def test_logs_cache_token_details(self, caplog):
        """When usage_metadata carries input/output token detail dicts, they are logged verbatim."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="Here is the final answer.",
            usage_metadata={
                "input_tokens": 350,
                "output_tokens": 240,
                "total_tokens": 590,
                "input_token_details": {
                    "audio": 10,
                    "cache_creation": 200,
                    "cache_read": 100,
                },
                "output_token_details": {
                    "audio": 10,
                    "reasoning": 200,
                },
            },
        )

        with caplog.at_level(
            logging.INFO,
            logger="deerflow.agents.middlewares.token_usage_middleware",
        ):
            result = middleware.after_model({"messages": [message]}, _make_runtime())

        assert result is not None
        # Both the headline counts and the detail dicts should appear in the log output.
        assert "LLM token usage: input=350 output=240 total=590" in caplog.text
        assert "input_token_details={'audio': 10, 'cache_creation': 200, 'cache_read': 100}" in caplog.text
        assert "output_token_details={'audio': 10, 'reasoning': 200}" in caplog.text

    def test_logs_basic_tokens_when_no_detail_fields_in_usage_metadata(self, caplog):
        """When usage_metadata has only totals (no input_token_details), log just the counts."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="Here is the final answer.",
            usage_metadata={
                "input_tokens": 350,
                "output_tokens": 240,
                "total_tokens": 590,
            },
        )

        with caplog.at_level(
            logging.INFO,
            logger="deerflow.agents.middlewares.token_usage_middleware",
        ):
            result = middleware.after_model({"messages": [message]}, _make_runtime())

        assert result is not None
        assert "LLM token usage: input=350 output=240 total=590" in caplog.text
        # No detail fields in the metadata, so none should be logged.
        assert "input_token_details" not in caplog.text

    def test_no_log_when_usage_metadata_is_missing(self, caplog):
        """When usage_metadata is absent, no token usage line is logged."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="Here is the final answer.",
            # Usage is only present under response_metadata, which the middleware
            # does not read — the canonical source is usage_metadata.
            response_metadata={
                "usage": {
                    "input_tokens": 350,
                    "output_tokens": 240,
                    "total_tokens": 590,
                }
            },
        )

        with caplog.at_level(
            logging.INFO,
            logger="deerflow.agents.middlewares.token_usage_middleware",
        ):
            result = middleware.after_model({"messages": [message]}, _make_runtime())

        assert result is not None
        assert "LLM token usage" not in caplog.text

    def test_annotates_todo_updates_with_structured_actions(self):
        """A write_todos tool call gets a tool_batch attribution with per-todo status-change actions."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="",
            tool_calls=[
                {
                    "id": "write_todos:1",
                    "name": "write_todos",
                    "args": {
                        "todos": [
                            {"content": "Inspect streaming path", "status": "completed"},
                            {"content": "Design token attribution schema", "status": "in_progress"},
                        ]
                    },
                }
            ],
            usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120},
        )

        # Prior todo state: the middleware diffs against this to derive actions
        # (in_progress -> completed = todo_complete; pending -> in_progress = todo_start).
        state = {
            "messages": [message],
            "todos": [
                {"content": "Inspect streaming path", "status": "in_progress"},
                {"content": "Design token attribution schema", "status": "pending"},
            ],
        }

        result = middleware.after_model(state, _make_runtime())

        assert result is not None
        updated_message = result["messages"][0]
        attribution = updated_message.additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
        assert attribution["kind"] == "tool_batch"
        assert attribution["shared_attribution"] is True
        assert attribution["tool_call_ids"] == ["write_todos:1"]
        assert attribution["actions"] == [
            {
                "kind": "todo_complete",
                "content": "Inspect streaming path",
                "tool_call_id": "write_todos:1",
            },
            {
                "kind": "todo_start",
                "content": "Design token attribution schema",
                "tool_call_id": "write_todos:1",
            },
        ]

    def test_annotates_subagent_and_search_steps(self):
        """task and web_search tool calls yield subagent and search actions in dispatch order."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="",
            tool_calls=[
                {
                    "id": "task:1",
                    "name": "task",
                    "args": {
                        "description": "spec-coder patch message grouping",
                        "subagent_type": "general-purpose",
                    },
                },
                {
                    "id": "web_search:1",
                    "name": "web_search",
                    "args": {"query": "LangGraph useStream messages tuple"},
                },
            ],
        )

        result = middleware.after_model({"messages": [message]}, _make_runtime())

        assert result is not None
        attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
        assert attribution["kind"] == "tool_batch"
        assert attribution["shared_attribution"] is True
        assert attribution["actions"] == [
            {
                "kind": "subagent",
                "description": "spec-coder patch message grouping",
                "subagent_type": "general-purpose",
                "tool_call_id": "task:1",
            },
            {
                "kind": "search",
                "tool_name": "web_search",
                "query": "LangGraph useStream messages tuple",
                "tool_call_id": "web_search:1",
            },
        ]

    def test_marks_final_answer_when_no_tools(self):
        """An AIMessage without tool calls is attributed as a final_answer with no actions."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(content="Here is the final answer.")

        result = middleware.after_model({"messages": [message]}, _make_runtime())

        assert result is not None
        attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
        assert attribution["kind"] == "final_answer"
        assert attribution["shared_attribution"] is False
        assert attribution["actions"] == []

    def test_annotates_removed_todos(self):
        """Emptying the todo list via write_todos produces todo_remove actions for dropped entries."""
        middleware = TokenUsageMiddleware()
        message = AIMessage(
            content="",
            tool_calls=[
                {
                    "id": "write_todos:remove",
                    "name": "write_todos",
                    "args": {
                        "todos": [],
                    },
                }
            ],
        )

        result = middleware.after_model(
            {
                "messages": [message],
                # Existing todo that disappears in the new (empty) list above.
                "todos": [
                    {"content": "Archive obsolete plan", "status": "pending"},
                ],
            },
            _make_runtime(),
        )

        assert result is not None
        attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
        assert attribution["kind"] == "todo_update"
        assert attribution["shared_attribution"] is False
        assert attribution["actions"] == [
            {
                "kind": "todo_remove",
                "content": "Archive obsolete plan",
                "tool_call_id": "write_todos:remove",
            }
        ]

    def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
        """Cached subagent usage is merged onto the correct dispatch AIMessage by position.

        Two dispatch AIMessages (no ids) each spawn task tool calls; only the second
        dispatch's two ToolMessages have cached usage. The middleware must walk back
        from the new AIMessage over the consecutive ToolMessages, pop each cached
        entry, and fold both into one usage update targeting the second dispatch.
        """
        middleware = TokenUsageMiddleware()
        first_dispatch = AIMessage(
            content="",
            tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
        )
        second_dispatch = AIMessage(
            content="",
            tool_calls=[
                {"id": "task:second-a", "name": "task", "args": {}},
                {"id": "task:second-b", "name": "task", "args": {}},
            ],
        )
        messages = [
            first_dispatch,
            ToolMessage(content="first", tool_call_id="task:first"),
            second_dispatch,
            ToolMessage(content="second-a", tool_call_id="task:second-a"),
            ToolMessage(content="second-b", tool_call_id="task:second-b"),
            AIMessage(content="done"),
        ]
        # Only the second dispatch's tasks have cached usage; "task:first" has none.
        cached_usage = {
            "task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
            "task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
        }

        # Replace the real cache lookup with a pop from the local dict so each
        # entry can be claimed at most once.
        task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
        monkeypatch.setattr(
            task_tool_module,
            "pop_cached_subagent_usage",
            lambda tool_call_id: cached_usage.pop(tool_call_id, None),
        )

        result = middleware.after_model({"messages": messages}, _make_runtime())

        assert result is not None
        # Exactly one message update carries usage_metadata: the second dispatch,
        # with both tasks' usage summed (10+20, 5+7, 15+27).
        usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)]
        assert len(usage_updates) == 1
        updated = usage_updates[0]
        assert updated.tool_calls == second_dispatch.tool_calls
        assert updated.usage_metadata == {
            "input_tokens": 30,
            "output_tokens": 12,
            "total_tokens": 42,
        }