From eab7ae3d6283a51fbe759e761d39fce2308cc4a3 Mon Sep 17 00:00:00 2001 From: YuJitang Date: Wed, 13 May 2026 23:52:19 +0800 Subject: [PATCH] feat: stream subagent token usage to header via terminal task events (#2882) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: real-time subagent token usage display in header and per-turn Backend: - Persist subagent token usage to AIMessage.usage_metadata via TokenUsageMiddleware, so accumulateUsage() naturally includes subagent tokens without frontend state management - Cache subagent usage by tool_call_id in task_tool, write back to the dispatching AIMessage on next model response - Emit subagent token usage on all terminal task events (task_completed, task_failed, task_cancelled, task_timed_out) - Report subagent usage to parent RunJournal for API totals - Search backward from ToolMessage to find dispatching AIMessage for correct multi-tool-call attribution Frontend: - Remove subagentUsage state, custom event handling, and prop threading — subagent tokens are now embedded in message metadata - Simplify selectHeaderTokenUsage (no subagentUsage parameter) - Per-turn inline badges show turn-specific usage via message accumulation - Remove isLoading guard from MessageTokenUsageList for dynamic updates during streaming * fix: prevent header token double counting from baseline reset race onFinish, onError, and thread-switch useEffect all reset pendingUsageBaselineMessageIdsRef to an empty Set. If thread.isLoading is still true on the next render, all messages pass the getMessagesAfterBaseline filter and their tokens are added to backendUsage (which already includes them), causing the header to display up to 2× the actual token count. Capture current message IDs instead of using an empty Set so that getMessagesAfterBaseline correctly returns no pending messages even if thread.isLoading lags behind the stream end. * fix: write back subagent tokens for all concurrent task tool calls TokenUsageMiddleware only processed messages[-2], so when a single model response dispatched multiple task tool calls only the last ToolMessage had its cached subagent usage written back to the dispatch AIMessage.usage_metadata. Earlier tasks' usage stayed in _subagent_usage_cache indefinitely (leak) and never appeared in the per-turn inline token display. Walk backward through all consecutive ToolMessages before the new AIMessage, and accumulate updates targeting the same dispatch message into one state update so overlapping writes don't clobber each other. * fix: clean up subagent usage cache entry on task cancellation When a task_tool invocation is cancelled via CancelledError, any cached subagent usage entry leaked because the TokenUsageMiddleware writeback path never fires after cancellation. Pop the cache entry before re-raising to prevent unbounded growth of the module-level _subagent_usage_cache dict. * fix: address token usage review feedback * fix: handle missing config for subagent usage cache --------- Co-authored-by: Willem Jiang --- README.md | 2 +- backend/CLAUDE.md | 2 +- .../middlewares/token_usage_middleware.py | 61 ++++++- .../deerflow/tools/builtins/task_tool.py | 55 ++++++- .../tests/test_memory_queue_user_isolation.py | 5 +- backend/tests/test_task_tool_core_logic.py | 153 ++++++++++++++++++ backend/tests/test_token_usage_middleware.py | 49 +++++- .../messages/message-token-usage.tsx | 41 +++-- frontend/src/core/messages/usage.ts | 4 +- frontend/src/core/threads/hooks.ts | 18 ++- 10 files changed, 349 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 9ff1d501b..8248e8fe4 100644 --- a/README.md +++ b/README.md @@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl Complex tasks rarely fit in a single pass. DeerFlow decomposes them. -The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. +The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step. This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands. diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 67ee9cc7e..5e0aebfdb 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har 8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting 9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled) 10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode) -11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional) +11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id 12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model 13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses) 14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support) diff --git a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py index f59e7f2b7..0d3607faf 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py @@ -9,7 +9,7 @@ from typing import Any, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware.todo import Todo -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, ToolMessage from langgraph.runtime import Runtime logger = logging.getLogger(__name__) @@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str: return "thinking" +def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool: + """Return True if the AIMessage contains a tool_call with the given id.""" + for tc in message.tool_calls or []: + if isinstance(tc, dict): + if tc.get("id") == tool_call_id: + return True + elif hasattr(tc, "id") and tc.id == tool_call_id: + return True + return False + + def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]: tool_calls = getattr(message, "tool_calls", None) or [] actions: list[dict[str, Any]] = [] @@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware): if not messages: return None + # Annotate subagent token usage onto the AIMessage that dispatched it. + # When a task tool completes, its usage is cached by tool_call_id. Detect + # the ToolMessage → search backward for the corresponding AIMessage → merge. + # Walk backward through consecutive ToolMessages before the new AIMessage + # so that multiple concurrent task tool calls all get their subagent tokens + # written back to the same dispatch message (merging into one update). + state_updates: dict[int, AIMessage] = {} + if len(messages) >= 2: + from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage + + idx = len(messages) - 2 + while idx >= 0: + tool_msg = messages[idx] + if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id: + break + + subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id) + if subagent_usage: + # Search backward from the ToolMessage to find the AIMessage + # that dispatched it. A single model response can dispatch + # multiple task tool calls, so we can't assume a fixed offset. + dispatch_idx = idx - 1 + while dispatch_idx >= 0: + candidate = messages[dispatch_idx] + if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id): + # Accumulate into an existing update for the same + # AIMessage (multiple task calls in one response), + # or merge fresh from the original message. + existing_update = state_updates.get(dispatch_idx) + prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {}) + merged = { + **prev, + "input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"], + "output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"], + "total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"], + } + state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged}) + break + dispatch_idx -= 1 + idx -= 1 + last = messages[-1] if not isinstance(last, AIMessage): + if state_updates: + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} return None usage = getattr(last, "usage_metadata", None) @@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware): additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {}) if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution: - return None + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs}) - return {"messages": [updated_msg]} + state_updates[len(messages) - 1] = updated_msg + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} @override def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index 861c45b45..cf9281ff4 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -26,6 +26,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can +# write it back to the triggering AIMessage's usage_metadata. +_subagent_usage_cache: dict[str, dict[str, int]] = {} + + +def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool: + if app_config is None: + try: + app_config = get_app_config() + except FileNotFoundError: + return False + return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False)) + + +def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None: + if enabled and usage: + _subagent_usage_cache[tool_call_id] = usage + + +def pop_cached_subagent_usage(tool_call_id: str) -> dict | None: + return _subagent_usage_cache.pop(tool_call_id, None) + def _is_subagent_terminal(result: Any) -> bool: """Return whether a background subagent result is safe to clean up.""" @@ -92,6 +114,17 @@ def _find_usage_recorder(runtime: Any) -> Any | None: return None +def _summarize_usage(records: list[dict] | None) -> dict | None: + """Summarize token usage records into a compact dict for SSE events.""" + if not records: + return None + return { + "input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records), + "output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records), + "total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records), + } + + def _report_subagent_usage(runtime: Any, result: Any) -> None: """Report subagent token usage to the parent RunJournal, if available. @@ -177,6 +210,7 @@ async def task_tool( subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD. """ runtime_app_config = _get_runtime_app_config(runtime) + cache_token_usage = _token_usage_cache_enabled(runtime_app_config) available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names() # Get subagent configuration @@ -312,27 +346,32 @@ async def task_tool( last_message_count = current_message_count # Check if task completed, failed, or timed out + usage = _summarize_usage(getattr(result, "token_usage_records", None)) if result.status == SubagentStatus.COMPLETED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_completed", "task_id": task_id, "result": result.result}) + writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage}) logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls") cleanup_background_task(task_id) return f"Task Succeeded. Result: {result.result}" elif result.status == SubagentStatus.FAILED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_failed", "task_id": task_id, "error": result.error}) + writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage}) logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}") cleanup_background_task(task_id) return f"Task failed. Error: {result.error}" elif result.status == SubagentStatus.CANCELLED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_cancelled", "task_id": task_id, "error": result.error}) + writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage}) logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}") cleanup_background_task(task_id) return "Task cancelled by user." elif result.status == SubagentStatus.TIMED_OUT: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_timed_out", "task_id": task_id, "error": result.error}) + writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage}) logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}") cleanup_background_task(task_id) return f"Task timed out. Error: {result.error}" @@ -351,7 +390,9 @@ async def task_tool( timeout_minutes = config.timeout_seconds // 60 logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)") _report_subagent_usage(runtime, result) - writer({"type": "task_timed_out", "task_id": task_id}) + usage = _summarize_usage(getattr(result, "token_usage_records", None)) + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) + writer({"type": "task_timed_out", "task_id": task_id, "usage": usage}) return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}" except asyncio.CancelledError: # Signal the background subagent thread to stop cooperatively. @@ -374,4 +415,8 @@ async def task_tool( cleanup_background_task(task_id) else: _schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count) + _subagent_usage_cache.pop(tool_call_id, None) + raise + except Exception: + _subagent_usage_cache.pop(tool_call_id, None) raise diff --git a/backend/tests/test_memory_queue_user_isolation.py b/backend/tests/test_memory_queue_user_isolation.py index cf068e095..79250817c 100644 --- a/backend/tests/test_memory_queue_user_isolation.py +++ b/backend/tests/test_memory_queue_user_isolation.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue +from deerflow.config.memory_config import MemoryConfig def test_conversation_context_has_user_id(): @@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none(): def test_queue_add_stores_user_id(): q = MemoryUpdateQueue() - with patch.object(q, "_reset_timer"): + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): q.add(thread_id="t1", messages=["msg"], user_id="alice") assert len(q._queue) == 1 assert q._queue[0].user_id == "alice" @@ -26,7 +27,7 @@ def test_queue_add_stores_user_id(): def test_queue_process_passes_user_id_to_updater(): q = MemoryUpdateQueue() - with patch.object(q, "_reset_timer"): + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): q.add(thread_id="t1", messages=["msg"], user_id="alice") mock_updater = MagicMock() diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/test_task_tool_core_logic.py index 0591c0e8d..658968d65 100644 --- a/backend/tests/test_task_tool_core_logic.py +++ b/backend/tests/test_task_tool_core_logic.py @@ -59,12 +59,15 @@ def _make_result( ai_messages: list[dict] | None = None, result: str | None = None, error: str | None = None, + token_usage_records: list[dict] | None = None, ) -> SimpleNamespace: return SimpleNamespace( status=status, ai_messages=ai_messages or [], result=result, error=error, + token_usage_records=token_usage_records or [], + usage_reported=False, ) @@ -1132,3 +1135,153 @@ def test_cancellation_reports_subagent_usage(monkeypatch): assert len(report_calls) == 1 assert report_calls[0][1] is cancel_result assert cleanup_calls == ["tc-cancel-report"] + + +@pytest.mark.parametrize( + "status, expected_type", + [ + (FakeSubagentStatus.COMPLETED, "task_completed"), + (FakeSubagentStatus.FAILED, "task_failed"), + (FakeSubagentStatus.CANCELLED, "task_cancelled"), + (FakeSubagentStatus.TIMED_OUT, "task_timed_out"), + ], +) +def test_terminal_events_include_usage(monkeypatch, status, expected_type): + """Terminal task events include a usage summary from token_usage_records.""" + config = _make_subagent_config() + runtime = _make_runtime() + events = [] + + records = [ + {"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + {"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280}, + ] + result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records) + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-usage", + ) + + terminal_events = [e for e in events if e["type"] == expected_type] + assert len(terminal_events) == 1 + assert terminal_events[0]["usage"] == { + "input_tokens": 300, + "output_tokens": 130, + "total_tokens": 430, + } + + +def test_terminal_event_usage_none_when_no_records(monkeypatch): + """Terminal event has usage=None when token_usage_records is empty.""" + config = _make_subagent_config() + runtime = _make_runtime() + events = [] + + result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[]) + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-no-records", + ) + + completed = [e for e in events if e["type"] == "task_completed"] + assert len(completed) == 1 + assert completed[0]["usage"] is None + + +def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch): + monkeypatch.setattr( + task_tool_module, + "get_app_config", + MagicMock(side_effect=FileNotFoundError("missing config")), + ) + + assert task_tool_module._token_usage_cache_enabled(None) is False + + +def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch): + config = _make_subagent_config() + app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False)) + runtime = _make_runtime(app_config=app_config) + records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}] + result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records) + + task_tool_module._subagent_usage_cache.clear() + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"]) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-disabled-cache", + ) + + assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None + + +def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch): + config = _make_subagent_config() + app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True)) + runtime = _make_runtime(app_config=app_config) + + task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2} + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"]) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed"))) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + with pytest.raises(RuntimeError, match="poll failed"): + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-error", + ) + + assert task_tool_module.pop_cached_subagent_usage("tc-error") is None diff --git a/backend/tests/test_token_usage_middleware.py b/backend/tests/test_token_usage_middleware.py index b24ff7b16..9686455c0 100644 --- a/backend/tests/test_token_usage_middleware.py +++ b/backend/tests/test_token_usage_middleware.py @@ -1,9 +1,10 @@ """Tests for TokenUsageMiddleware attribution annotations.""" +import importlib import logging from unittest.mock import MagicMock -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, ToolMessage from deerflow.agents.middlewares.token_usage_middleware import ( TOKEN_USAGE_ATTRIBUTION_KEY, @@ -232,3 +233,49 @@ class TestTokenUsageMiddleware: "tool_call_id": "write_todos:remove", } ] + + def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch): + middleware = TokenUsageMiddleware() + first_dispatch = AIMessage( + content="", + tool_calls=[{"id": "task:first", "name": "task", "args": {}}], + ) + second_dispatch = AIMessage( + content="", + tool_calls=[ + {"id": "task:second-a", "name": "task", "args": {}}, + {"id": "task:second-b", "name": "task", "args": {}}, + ], + ) + messages = [ + first_dispatch, + ToolMessage(content="first", tool_call_id="task:first"), + second_dispatch, + ToolMessage(content="second-a", tool_call_id="task:second-a"), + ToolMessage(content="second-b", tool_call_id="task:second-b"), + AIMessage(content="done"), + ] + cached_usage = { + "task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + "task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27}, + } + + task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool") + monkeypatch.setattr( + task_tool_module, + "pop_cached_subagent_usage", + lambda tool_call_id: cached_usage.pop(tool_call_id, None), + ) + + result = middleware.after_model({"messages": messages}, _make_runtime()) + + assert result is not None + usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)] + assert len(usage_updates) == 1 + updated = usage_updates[0] + assert updated.tool_calls == second_dispatch.tool_calls + assert updated.usage_metadata == { + "input_tokens": 30, + "output_tokens": 12, + "total_tokens": 42, + } diff --git a/frontend/src/components/workspace/messages/message-token-usage.tsx b/frontend/src/components/workspace/messages/message-token-usage.tsx index cc8d0debb..84f8a8057 100644 --- a/frontend/src/components/workspace/messages/message-token-usage.tsx +++ b/frontend/src/components/workspace/messages/message-token-usage.tsx @@ -12,13 +12,11 @@ function TokenUsageSummary({ inputTokens, outputTokens, totalTokens, - unavailable = false, }: { className?: string; inputTokens?: number; outputTokens?: number; totalTokens?: number; - unavailable?: boolean; }) { const { t } = useI18n(); @@ -33,21 +31,15 @@ function TokenUsageSummary({ {t.tokenUsage.label} - {!unavailable ? ( - <> - - {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)} - - - {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)} - - - {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)} - - - ) : ( - {t.tokenUsage.unavailableShort} - )} + + {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)} + + + {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)} + + + {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)} + ); } @@ -55,7 +47,7 @@ function TokenUsageSummary({ export function MessageTokenUsageList({ className, enabled = false, - isLoading = false, + isLoading: _isLoading = false, messages, }: { className?: string; @@ -63,7 +55,7 @@ export function MessageTokenUsageList({ isLoading?: boolean; messages: Message[]; }) { - if (!enabled || isLoading) { + if (!enabled) { return null; } @@ -75,13 +67,16 @@ export function MessageTokenUsageList({ const usage = accumulateUsage(aiMessages); + if (!usage) { + return null; + } + return ( ); } diff --git a/frontend/src/core/messages/usage.ts b/frontend/src/core/messages/usage.ts index 4679dffa5..01e3a59e1 100644 --- a/frontend/src/core/messages/usage.ts +++ b/frontend/src/core/messages/usage.ts @@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null { return hasUsage ? cumulative : null; } -function hasNonZeroUsage( +export function hasNonZeroUsage( usage: TokenUsage | null | undefined, ): usage is TokenUsage { return ( @@ -75,7 +75,7 @@ function hasNonZeroUsage( ); } -function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage { +export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage { return { inputTokens: base.inputTokens + delta.inputTokens, outputTokens: base.outputTokens + delta.outputTokens, diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index 0ac790eb2..adf9dbbb6 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -296,7 +296,11 @@ export function useThreadStream({ onError(error) { setOptimisticMessages([]); toast.error(getStreamErrorMessage(error)); - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); if (threadIdRef.current && !isMock) { void queryClient.invalidateQueries({ queryKey: threadTokenUsageQueryKey(threadIdRef.current), @@ -305,7 +309,11 @@ export function useThreadStream({ }, onFinish(state) { listeners.current.onFinish?.(state.values); - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); void queryClient.invalidateQueries({ queryKey: ["threads", "search"] }); if (threadIdRef.current && !isMock) { void queryClient.invalidateQueries({ @@ -339,7 +347,11 @@ export function useThreadStream({ useEffect(() => { startedRef.current = false; sendInFlightRef.current = false; - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); prevHumanMsgCountRef.current = latestMessageCountsRef.current.humanMessageCount; }, [threadId]);