diff --git a/README.md b/README.md
index 9ff1d501b..8248e8fe4 100644
--- a/README.md
+++ b/README.md
@@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
-The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output.
+The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, each completed sub-agent's token usage is attributed back to the step that dispatched it.
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
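+To make that token-usage attribution concrete, here is a minimal, hypothetical sketch of the bookkeeping (the function and field names are illustrative, not DeerFlow's actual API):
+
+```python
+# Illustrative only: fold a completed sub-agent's token usage back into the
+# usage recorded for the lead-agent step that dispatched it.
+def merge_subagent_usage(step_usage: dict, subagent_usage: dict) -> dict:
+    merged = dict(step_usage)
+    for key in ("input_tokens", "output_tokens", "total_tokens"):
+        merged[key] = merged.get(key, 0) + subagent_usage.get(key, 0)
+    return merged
+
+
+lead_step = {"input_tokens": 1200, "output_tokens": 300, "total_tokens": 1500}
+subagent = {"input_tokens": 8000, "output_tokens": 2000, "total_tokens": 10000}
+print(merge_subagent_usage(lead_step, subagent))
+# {'input_tokens': 9200, 'output_tokens': 2300, 'total_tokens': 11500}
+```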
diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md
index 67ee9cc7e..5e0aebfdb 100644
--- a/backend/CLAUDE.md
+++ b/backend/CLAUDE.md
@@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
-11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
+11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); while token usage is enabled, completed subagent usage is cached by `tool_call_id` and merged back into the dispatching AIMessage by message position rather than by message id (sketched after this list)
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
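+A simplified sketch of the position-based merge described in item 11 (the real logic lives in `token_usage_middleware.py`; this version skips caching and the merging of multiple concurrent task calls):
+
+```python
+# Simplified illustration; not the production implementation.
+from langchain_core.messages import AIMessage, ToolMessage
+
+
+def find_dispatch_index(messages: list, tool_call_id: str) -> int | None:
+    """Walk backward to the AIMessage whose tool_calls include tool_call_id."""
+    for idx in range(len(messages) - 1, -1, -1):
+        msg = messages[idx]
+        if isinstance(msg, AIMessage) and any(
+            (tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)) == tool_call_id
+            for tc in (msg.tool_calls or [])
+        ):
+            return idx
+    return None
+
+
+messages = [
+    AIMessage(content="", tool_calls=[{"id": "task:1", "name": "task", "args": {}}]),
+    ToolMessage(content="done", tool_call_id="task:1"),
+]
+assert find_dispatch_index(messages, "task:1") == 0
+```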
diff --git a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py
index f59e7f2b7..0d3607faf 100644
--- a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py
+++ b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py
@@ -9,7 +9,7 @@ from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.todo import Todo
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
@@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
return "thinking"
+def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
+    """Return True if the AIMessage contains a tool_call with the given id."""
+    for tc in message.tool_calls or []:
+        if isinstance(tc, dict):
+            if tc.get("id") == tool_call_id:
+                return True
+        elif hasattr(tc, "id") and tc.id == tool_call_id:
+            return True
+    return False
+
+
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
tool_calls = getattr(message, "tool_calls", None) or []
actions: list[dict[str, Any]] = []
@@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware):
if not messages:
return None
+        # Annotate subagent token usage onto the AIMessage that dispatched it.
+        # When a task tool completes, its usage is cached by tool_call_id; detect
+        # each ToolMessage, search backward for the AIMessage that issued the
+        # matching tool call, and merge the cached usage into it.
+        # Walk backward through the consecutive ToolMessages that precede the
+        # latest message so that multiple concurrent task tool calls all get
+        # their subagent tokens written back to the same dispatch message,
+        # merged into a single update.
+        state_updates: dict[int, AIMessage] = {}
+        if len(messages) >= 2:
+            from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
+
+            idx = len(messages) - 2
+            while idx >= 0:
+                tool_msg = messages[idx]
+                if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
+                    break
+
+                subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
+                if subagent_usage:
+                    # Search backward from the ToolMessage to find the AIMessage
+                    # that dispatched it. A single model response can dispatch
+                    # multiple task tool calls, so we can't assume a fixed offset.
+                    dispatch_idx = idx - 1
+                    while dispatch_idx >= 0:
+                        candidate = messages[dispatch_idx]
+                        if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
+                            # Accumulate into an existing update for the same
+                            # AIMessage (multiple task calls in one response),
+                            # or merge fresh from the original message.
+                            existing_update = state_updates.get(dispatch_idx)
+                            prev = (
+                                existing_update.usage_metadata
+                                if existing_update
+                                else (getattr(candidate, "usage_metadata", None) or {})
+                            )
+                            merged = {
+                                **prev,
+                                "input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
+                                "output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
+                                "total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
+                            }
+                            state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
+                            break
+                        dispatch_idx -= 1
+                idx -= 1
+
last = messages[-1]
if not isinstance(last, AIMessage):
+            if state_updates:
+                return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
return None
usage = getattr(last, "usage_metadata", None)
@@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware):
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
- return None
+            return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
- return {"messages": [updated_msg]}
+        state_updates[len(messages) - 1] = updated_msg
+        return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py
index 861c45b45..cf9281ff4 100644
--- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py
+++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py
@@ -26,6 +26,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can
+# write it back to the triggering AIMessage's usage_metadata.
+_subagent_usage_cache: dict[str, dict[str, int]] = {}
+
+
+def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
+    if app_config is None:
+        try:
+            app_config = get_app_config()
+        except FileNotFoundError:
+            return False
+    return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
+
+
+def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None:
+    if enabled and usage:
+        _subagent_usage_cache[tool_call_id] = usage
+
+
+def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
+    return _subagent_usage_cache.pop(tool_call_id, None)
+
def _is_subagent_terminal(result: Any) -> bool:
"""Return whether a background subagent result is safe to clean up."""
@@ -92,6 +114,17 @@ def _find_usage_recorder(runtime: Any) -> Any | None:
return None
+def _summarize_usage(records: list[dict] | None) -> dict | None:
+    """Summarize token usage records into a compact dict for SSE events."""
+    if not records:
+        return None
+    return {
+        "input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
+        "output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
+        "total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
+    }
+
+
def _report_subagent_usage(runtime: Any, result: Any) -> None:
"""Report subagent token usage to the parent RunJournal, if available.
@@ -177,6 +210,7 @@ async def task_tool(
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
"""
runtime_app_config = _get_runtime_app_config(runtime)
+ cache_token_usage = _token_usage_cache_enabled(runtime_app_config)
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
# Get subagent configuration
@@ -312,27 +346,32 @@ async def task_tool(
last_message_count = current_message_count
# Check if task completed, failed, or timed out
+ usage = _summarize_usage(getattr(result, "token_usage_records", None))
if result.status == SubagentStatus.COMPLETED:
+ _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
- writer({"type": "task_completed", "task_id": task_id, "result": result.result})
+ writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage})
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
cleanup_background_task(task_id)
return f"Task Succeeded. Result: {result.result}"
elif result.status == SubagentStatus.FAILED:
+ _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
- writer({"type": "task_failed", "task_id": task_id, "error": result.error})
+ writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage})
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
cleanup_background_task(task_id)
return f"Task failed. Error: {result.error}"
elif result.status == SubagentStatus.CANCELLED:
+ _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
- writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
+ writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage})
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
cleanup_background_task(task_id)
return "Task cancelled by user."
elif result.status == SubagentStatus.TIMED_OUT:
+ _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
- writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
+ writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage})
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
cleanup_background_task(task_id)
return f"Task timed out. Error: {result.error}"
@@ -351,7 +390,9 @@ async def task_tool(
timeout_minutes = config.timeout_seconds // 60
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
_report_subagent_usage(runtime, result)
- writer({"type": "task_timed_out", "task_id": task_id})
+ usage = _summarize_usage(getattr(result, "token_usage_records", None))
+ _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
+ writer({"type": "task_timed_out", "task_id": task_id, "usage": usage})
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
except asyncio.CancelledError:
# Signal the background subagent thread to stop cooperatively.
@@ -374,4 +415,8 @@ async def task_tool(
cleanup_background_task(task_id)
else:
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
+ _subagent_usage_cache.pop(tool_call_id, None)
+ raise
+ except Exception:
+ _subagent_usage_cache.pop(tool_call_id, None)
raise
diff --git a/backend/tests/test_memory_queue_user_isolation.py b/backend/tests/test_memory_queue_user_isolation.py
index cf068e095..79250817c 100644
--- a/backend/tests/test_memory_queue_user_isolation.py
+++ b/backend/tests/test_memory_queue_user_isolation.py
@@ -3,6 +3,7 @@
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
+from deerflow.config.memory_config import MemoryConfig
def test_conversation_context_has_user_id():
@@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none():
def test_queue_add_stores_user_id():
q = MemoryUpdateQueue()
- with patch.object(q, "_reset_timer"):
+ with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
assert len(q._queue) == 1
assert q._queue[0].user_id == "alice"
@@ -26,7 +27,7 @@ def test_queue_add_stores_user_id():
def test_queue_process_passes_user_id_to_updater():
q = MemoryUpdateQueue()
- with patch.object(q, "_reset_timer"):
+ with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
mock_updater = MagicMock()
diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/test_task_tool_core_logic.py
index 0591c0e8d..658968d65 100644
--- a/backend/tests/test_task_tool_core_logic.py
+++ b/backend/tests/test_task_tool_core_logic.py
@@ -59,12 +59,15 @@ def _make_result(
ai_messages: list[dict] | None = None,
result: str | None = None,
error: str | None = None,
+ token_usage_records: list[dict] | None = None,
) -> SimpleNamespace:
return SimpleNamespace(
status=status,
ai_messages=ai_messages or [],
result=result,
error=error,
+ token_usage_records=token_usage_records or [],
+ usage_reported=False,
)
@@ -1132,3 +1135,153 @@ def test_cancellation_reports_subagent_usage(monkeypatch):
assert len(report_calls) == 1
assert report_calls[0][1] is cancel_result
assert cleanup_calls == ["tc-cancel-report"]
+
+
+@pytest.mark.parametrize(
+ "status, expected_type",
+ [
+ (FakeSubagentStatus.COMPLETED, "task_completed"),
+ (FakeSubagentStatus.FAILED, "task_failed"),
+ (FakeSubagentStatus.CANCELLED, "task_cancelled"),
+ (FakeSubagentStatus.TIMED_OUT, "task_timed_out"),
+ ],
+)
+def test_terminal_events_include_usage(monkeypatch, status, expected_type):
+ """Terminal task events include a usage summary from token_usage_records."""
+ config = _make_subagent_config()
+ runtime = _make_runtime()
+ events = []
+
+ records = [
+ {"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
+ {"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280},
+ ]
+ result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records)
+
+ monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
+ monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
+ monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
+ monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
+ monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
+ monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
+ monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
+ monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
+
+ _run_task_tool(
+ runtime=runtime,
+ description="test",
+ prompt="do work",
+ subagent_type="general-purpose",
+ tool_call_id="tc-usage",
+ )
+
+ terminal_events = [e for e in events if e["type"] == expected_type]
+ assert len(terminal_events) == 1
+ assert terminal_events[0]["usage"] == {
+ "input_tokens": 300,
+ "output_tokens": 130,
+ "total_tokens": 430,
+ }
+
+
+def test_terminal_event_usage_none_when_no_records(monkeypatch):
+ """Terminal event has usage=None when token_usage_records is empty."""
+ config = _make_subagent_config()
+ runtime = _make_runtime()
+ events = []
+
+ result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[])
+
+ monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
+ monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
+ monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
+ monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
+ monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
+ monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
+ monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
+ monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
+
+ _run_task_tool(
+ runtime=runtime,
+ description="test",
+ prompt="do work",
+ subagent_type="general-purpose",
+ tool_call_id="tc-no-records",
+ )
+
+ completed = [e for e in events if e["type"] == "task_completed"]
+ assert len(completed) == 1
+ assert completed[0]["usage"] is None
+
+
+def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
+ monkeypatch.setattr(
+ task_tool_module,
+ "get_app_config",
+ MagicMock(side_effect=FileNotFoundError("missing config")),
+ )
+
+ assert task_tool_module._token_usage_cache_enabled(None) is False
+
+
+def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch):
+ config = _make_subagent_config()
+ app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False))
+ runtime = _make_runtime(app_config=app_config)
+ records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}]
+ result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records)
+
+ task_tool_module._subagent_usage_cache.clear()
+ monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
+ monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
+ monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
+ monkeypatch.setattr(
+ task_tool_module,
+ "SubagentExecutor",
+ type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
+ )
+ monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
+ monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
+ monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
+ monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
+ monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
+
+ _run_task_tool(
+ runtime=runtime,
+ description="test",
+ prompt="do work",
+ subagent_type="general-purpose",
+ tool_call_id="tc-disabled-cache",
+ )
+
+ assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None
+
+
+def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch):
+ config = _make_subagent_config()
+ app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True))
+ runtime = _make_runtime(app_config=app_config)
+
+ task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
+ monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
+ monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
+ monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
+ monkeypatch.setattr(
+ task_tool_module,
+ "SubagentExecutor",
+ type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
+ )
+ monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed")))
+ monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
+ monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
+
+ with pytest.raises(RuntimeError, match="poll failed"):
+ _run_task_tool(
+ runtime=runtime,
+ description="test",
+ prompt="do work",
+ subagent_type="general-purpose",
+ tool_call_id="tc-error",
+ )
+
+ assert task_tool_module.pop_cached_subagent_usage("tc-error") is None
diff --git a/backend/tests/test_token_usage_middleware.py b/backend/tests/test_token_usage_middleware.py
index b24ff7b16..9686455c0 100644
--- a/backend/tests/test_token_usage_middleware.py
+++ b/backend/tests/test_token_usage_middleware.py
@@ -1,9 +1,10 @@
"""Tests for TokenUsageMiddleware attribution annotations."""
+import importlib
import logging
from unittest.mock import MagicMock
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, ToolMessage
from deerflow.agents.middlewares.token_usage_middleware import (
TOKEN_USAGE_ATTRIBUTION_KEY,
@@ -232,3 +233,49 @@ class TestTokenUsageMiddleware:
"tool_call_id": "write_todos:remove",
}
]
+
+ def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
+ middleware = TokenUsageMiddleware()
+ first_dispatch = AIMessage(
+ content="",
+ tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
+ )
+ second_dispatch = AIMessage(
+ content="",
+ tool_calls=[
+ {"id": "task:second-a", "name": "task", "args": {}},
+ {"id": "task:second-b", "name": "task", "args": {}},
+ ],
+ )
+ messages = [
+ first_dispatch,
+ ToolMessage(content="first", tool_call_id="task:first"),
+ second_dispatch,
+ ToolMessage(content="second-a", tool_call_id="task:second-a"),
+ ToolMessage(content="second-b", tool_call_id="task:second-b"),
+ AIMessage(content="done"),
+ ]
+ cached_usage = {
+ "task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
+ "task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
+ }
+
+ task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
+ monkeypatch.setattr(
+ task_tool_module,
+ "pop_cached_subagent_usage",
+ lambda tool_call_id: cached_usage.pop(tool_call_id, None),
+ )
+
+ result = middleware.after_model({"messages": messages}, _make_runtime())
+
+ assert result is not None
+ usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)]
+ assert len(usage_updates) == 1
+ updated = usage_updates[0]
+ assert updated.tool_calls == second_dispatch.tool_calls
+ assert updated.usage_metadata == {
+ "input_tokens": 30,
+ "output_tokens": 12,
+ "total_tokens": 42,
+ }
diff --git a/frontend/src/components/workspace/messages/message-token-usage.tsx b/frontend/src/components/workspace/messages/message-token-usage.tsx
index cc8d0debb..84f8a8057 100644
--- a/frontend/src/components/workspace/messages/message-token-usage.tsx
+++ b/frontend/src/components/workspace/messages/message-token-usage.tsx
@@ -12,13 +12,11 @@ function TokenUsageSummary({
inputTokens,
outputTokens,
totalTokens,
- unavailable = false,
}: {
className?: string;
inputTokens?: number;
outputTokens?: number;
totalTokens?: number;
- unavailable?: boolean;
}) {
const { t } = useI18n();
@@ -33,21 +31,15 @@ function TokenUsageSummary({
{t.tokenUsage.label}
- {!unavailable ? (
- <>
-
- {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
-
-
- {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
-
-
- {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
-
- >
- ) : (
- {t.tokenUsage.unavailableShort}
- )}
+
+ {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
+
+
+ {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
+
+
+ {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
+
);
}
@@ -55,7 +47,7 @@ function TokenUsageSummary({
export function MessageTokenUsageList({
className,
enabled = false,
- isLoading = false,
+ isLoading: _isLoading = false,
messages,
}: {
className?: string;
@@ -63,7 +55,7 @@ export function MessageTokenUsageList({
isLoading?: boolean;
messages: Message[];
}) {
- if (!enabled || isLoading) {
+ if (!enabled) {
return null;
}
@@ -75,13 +67,16 @@ export function MessageTokenUsageList({
const usage = accumulateUsage(aiMessages);
+ if (!usage) {
+ return null;
+ }
+
return (
);
}
diff --git a/frontend/src/core/messages/usage.ts b/frontend/src/core/messages/usage.ts
index 4679dffa5..01e3a59e1 100644
--- a/frontend/src/core/messages/usage.ts
+++ b/frontend/src/core/messages/usage.ts
@@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null {
return hasUsage ? cumulative : null;
}
-function hasNonZeroUsage(
+export function hasNonZeroUsage(
usage: TokenUsage | null | undefined,
): usage is TokenUsage {
return (
@@ -75,7 +75,7 @@ function hasNonZeroUsage(
);
}
-function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
+export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
return {
inputTokens: base.inputTokens + delta.inputTokens,
outputTokens: base.outputTokens + delta.outputTokens,
diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts
index 0ac790eb2..adf9dbbb6 100644
--- a/frontend/src/core/threads/hooks.ts
+++ b/frontend/src/core/threads/hooks.ts
@@ -296,7 +296,11 @@ export function useThreadStream({
onError(error) {
setOptimisticMessages([]);
toast.error(getStreamErrorMessage(error));
- pendingUsageBaselineMessageIdsRef.current = new Set();
+ pendingUsageBaselineMessageIdsRef.current = new Set(
+ messagesRef.current
+ .map(messageIdentity)
+ .filter((id): id is string => Boolean(id)),
+ );
if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
@@ -305,7 +309,11 @@ export function useThreadStream({
},
onFinish(state) {
listeners.current.onFinish?.(state.values);
- pendingUsageBaselineMessageIdsRef.current = new Set();
+ pendingUsageBaselineMessageIdsRef.current = new Set(
+ messagesRef.current
+ .map(messageIdentity)
+ .filter((id): id is string => Boolean(id)),
+ );
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({
@@ -339,7 +347,11 @@ export function useThreadStream({
useEffect(() => {
startedRef.current = false;
sendInFlightRef.current = false;
- pendingUsageBaselineMessageIdsRef.current = new Set();
+ pendingUsageBaselineMessageIdsRef.current = new Set(
+ messagesRef.current
+ .map(messageIdentity)
+ .filter((id): id is string => Boolean(id)),
+ );
prevHumanMsgCountRef.current =
latestMessageCountsRef.current.humanMessageCount;
}, [threadId]);