feat: stream subagent token usage to header via terminal task events (#2882)

* feat: real-time subagent token usage display in the header and in per-turn badges

Backend:
- Persist subagent token usage to AIMessage.usage_metadata via
  TokenUsageMiddleware, so accumulateUsage() naturally includes
  subagent tokens without frontend state management
- Cache subagent usage by tool_call_id in task_tool and write it
  back to the dispatching AIMessage on the next model response
  (sketched below)
- Emit subagent token usage on all terminal task events
  (task_completed, task_failed, task_cancelled, task_timed_out)
- Report subagent usage to parent RunJournal for API totals
- Search backward from ToolMessage to find dispatching AIMessage
  for correct multi-tool-call attribution
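
A minimal sketch of the cache-and-write-back flow, assuming
simplified stand-in message types (the real code uses langchain's
AIMessage/ToolMessage and their usage_metadata):

  from dataclasses import dataclass, field

  _subagent_usage_cache: dict[str, dict[str, int]] = {}

  @dataclass
  class ToolMsg:
      tool_call_id: str

  @dataclass
  class AIMsg:
      tool_call_ids: list[str] = field(default_factory=list)
      usage_metadata: dict[str, int] = field(default_factory=dict)

  def cache_subagent_usage(tool_call_id: str, usage: dict[str, int]) -> None:
      # task_tool side: record usage when a task reaches a terminal state
      _subagent_usage_cache[tool_call_id] = usage

  def write_back(messages: list) -> None:
      # middleware side: pop each ToolMessage's cached usage and merge
      # it into the AIMessage that dispatched that tool call
      for i, msg in enumerate(messages):
          if not isinstance(msg, ToolMsg):
              continue
          usage = _subagent_usage_cache.pop(msg.tool_call_id, None)
          if usage is None:
              continue
          for j in range(i - 1, -1, -1):  # search backward for dispatcher
              prev = messages[j]
              if isinstance(prev, AIMsg) and msg.tool_call_id in prev.tool_call_ids:
                  for key, value in usage.items():
                      prev.usage_metadata[key] = prev.usage_metadata.get(key, 0) + value
                  break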

Frontend:
- Remove subagentUsage state, custom event handling, and prop
  threading — subagent tokens are now embedded in message metadata
- Simplify selectHeaderTokenUsage (no subagentUsage parameter)
- Per-turn inline badges show turn-specific usage via message
  accumulation
- Remove isLoading guard from MessageTokenUsageList for dynamic
  updates during streaming
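
Per-turn accumulation then falls out of the plain message walk. A
sketch of the idea (Python here for uniformity; the real helper is
the frontend's accumulateUsage):

  def accumulate_usage(ai_messages: list[dict]) -> dict | None:
      total = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
      has_usage = False
      for msg in ai_messages:
          usage = msg.get("usage_metadata")
          if not usage:
              continue
          has_usage = True
          for key in total:
              total[key] += usage.get(key, 0)
      return total if has_usage else None

  # A dispatch AIMessage that now embeds subagent tokens, plus the
  # final reply of the same turn:
  turn = [
      {"usage_metadata": {"input_tokens": 420, "output_tokens": 150, "total_tokens": 570}},
      {"usage_metadata": {"input_tokens": 30, "output_tokens": 20, "total_tokens": 50}},
  ]
  assert accumulate_usage(turn)["total_tokens"] == 620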

* fix: prevent header token double counting from baseline reset race

onFinish, onError, and the thread-switch useEffect all reset
pendingUsageBaselineMessageIdsRef to an empty Set. If
thread.isLoading is still true on the next render, all messages
pass the getMessagesAfterBaseline filter and their tokens are
added to backendUsage (which already includes them), causing
the header to display up to 2× the actual token count.

Capture current message IDs instead of using an empty Set so
that getMessagesAfterBaseline correctly returns no pending
messages even if thread.isLoading lags behind the stream end.
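
A minimal sketch of the double count, with hypothetical message ids
and token counts:

  def header_total(backend_usage: int,
                   messages: list[tuple[str, int]],
                   baseline_ids: set[str]) -> int:
      # backend_usage already covers every persisted message; only
      # messages newer than the baseline should be added on top
      pending = sum(t for mid, t in messages if mid not in baseline_ids)
      return backend_usage + pending

  messages = [("m1", 100), ("m2", 50)]  # both already persisted

  # Buggy reset: empty baseline while thread.isLoading lags -> 2x
  assert header_total(150, messages, set()) == 300

  # Fixed reset: capture the current message ids as the baseline
  assert header_total(150, messages, {"m1", "m2"}) == 150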

* fix: write back subagent tokens for all concurrent task tool calls

TokenUsageMiddleware only processed messages[-2], so when a
single model response dispatched multiple task tool calls, only
the last ToolMessage had its cached subagent usage written back
to the dispatching AIMessage.usage_metadata. Earlier tasks' usage
stayed in _subagent_usage_cache indefinitely (leak) and never
appeared in the per-turn inline token display.

Walk backward through all consecutive ToolMessages before the
new AIMessage, and accumulate updates targeting the same
dispatching message into one state update so overlapping writes
don't clobber each other.
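
For illustration, a hypothetical turn where one AIMessage dispatches
two task tool calls:

  # [0] AIMessage(tool_calls=["task:a", "task:b"])  <- dispatch message
  # [1] ToolMessage(tool_call_id="task:a")
  # [2] ToolMessage(tool_call_id="task:b")
  # [3] AIMessage("done")                           <- new model response
  msgs = ["AI(dispatch a,b)", "Tool(a)", "Tool(b)", "AI(done)"]

  idx = len(msgs) - 2  # the old code stopped here: only Tool(b)
  while idx >= 0 and msgs[idx].startswith("Tool"):
      print("write back:", msgs[idx])  # visits Tool(b), then Tool(a)
      idx -= 1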

* fix: clean up subagent usage cache entry on task cancellation

When a task_tool invocation is cancelled via CancelledError, its
cached subagent usage entry leaks: the TokenUsageMiddleware
write-back path never fires after cancellation. Pop the cache entry
before re-raising to prevent unbounded growth of the module-level
_subagent_usage_cache dict.
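
The shape of the fix, sketched (the middleware write-back can never
run once the tool call is cancelled, so the tool cleans up itself):

  import asyncio

  _subagent_usage_cache: dict[str, dict[str, int]] = {}

  async def task_tool_body(tool_call_id: str) -> str:
      try:
          ...  # poll the background subagent until a terminal status
          return "done"
      except asyncio.CancelledError:
          # No write-back will fire after cancellation; drop the
          # cache entry before re-raising to avoid a leak.
          _subagent_usage_cache.pop(tool_call_id, None)
          raise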

* fix: address token usage review feedback

* fix: handle missing config for subagent usage cache

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
YuJitang 2026-05-13 23:52:19 +08:00 committed by GitHub
parent f1a0ab699a
commit eab7ae3d62
10 changed files with 349 additions and 41 deletions


@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output.
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step.
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.


@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled, and is merged back into the dispatching AIMessage by message position rather than by message id
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)


@ -9,7 +9,7 @@ from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.todo import Todo
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
return "thinking"
def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
"""Return True if the AIMessage contains a tool_call with the given id."""
for tc in message.tool_calls or []:
if isinstance(tc, dict):
if tc.get("id") == tool_call_id:
return True
elif hasattr(tc, "id") and tc.id == tool_call_id:
return True
return False
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
tool_calls = getattr(message, "tool_calls", None) or []
actions: list[dict[str, Any]] = []
@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware):
if not messages:
return None
# Annotate subagent token usage onto the AIMessage that dispatched it.
# When a task tool completes, its usage is cached by tool_call_id. Detect
# the ToolMessage → search backward for the corresponding AIMessage → merge.
# Walk backward through consecutive ToolMessages before the new AIMessage
# so that multiple concurrent task tool calls all get their subagent tokens
# written back to the same dispatch message (merging into one update).
state_updates: dict[int, AIMessage] = {}
if len(messages) >= 2:
from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
idx = len(messages) - 2
while idx >= 0:
tool_msg = messages[idx]
if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
break
subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
if subagent_usage:
# Search backward from the ToolMessage to find the AIMessage
# that dispatched it. A single model response can dispatch
# multiple task tool calls, so we can't assume a fixed offset.
dispatch_idx = idx - 1
while dispatch_idx >= 0:
candidate = messages[dispatch_idx]
if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
# Accumulate into an existing update for the same
# AIMessage (multiple task calls in one response),
# or merge fresh from the original message.
existing_update = state_updates.get(dispatch_idx)
prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
merged = {
**prev,
"input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
"output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
"total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
}
state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
break
dispatch_idx -= 1
idx -= 1
last = messages[-1]
if not isinstance(last, AIMessage):
if state_updates:
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
return None
usage = getattr(last, "usage_metadata", None)
@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware):
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
return None
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
return {"messages": [updated_msg]}
state_updates[len(messages) - 1] = updated_msg
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:


@ -26,6 +26,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can
# write it back to the triggering AIMessage's usage_metadata.
_subagent_usage_cache: dict[str, dict[str, int]] = {}
def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
if app_config is None:
try:
app_config = get_app_config()
except FileNotFoundError:
return False
return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None:
if enabled and usage:
_subagent_usage_cache[tool_call_id] = usage
def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
return _subagent_usage_cache.pop(tool_call_id, None)
def _is_subagent_terminal(result: Any) -> bool:
"""Return whether a background subagent result is safe to clean up."""
@ -92,6 +114,17 @@ def _find_usage_recorder(runtime: Any) -> Any | None:
return None
def _summarize_usage(records: list[dict] | None) -> dict | None:
"""Summarize token usage records into a compact dict for SSE events."""
if not records:
return None
return {
"input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
"output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
"total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
}
def _report_subagent_usage(runtime: Any, result: Any) -> None:
"""Report subagent token usage to the parent RunJournal, if available.
@ -177,6 +210,7 @@ async def task_tool(
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
"""
runtime_app_config = _get_runtime_app_config(runtime)
cache_token_usage = _token_usage_cache_enabled(runtime_app_config)
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
# Get subagent configuration
@ -312,27 +346,32 @@ async def task_tool(
last_message_count = current_message_count
# Check if task completed, failed, or timed out
usage = _summarize_usage(getattr(result, "token_usage_records", None))
if result.status == SubagentStatus.COMPLETED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage})
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
cleanup_background_task(task_id)
return f"Task Succeeded. Result: {result.result}"
elif result.status == SubagentStatus.FAILED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage})
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
cleanup_background_task(task_id)
return f"Task failed. Error: {result.error}"
elif result.status == SubagentStatus.CANCELLED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage})
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
cleanup_background_task(task_id)
return "Task cancelled by user."
elif result.status == SubagentStatus.TIMED_OUT:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage})
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
cleanup_background_task(task_id)
return f"Task timed out. Error: {result.error}"
@ -351,7 +390,9 @@ async def task_tool(
timeout_minutes = config.timeout_seconds // 60
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
_report_subagent_usage(runtime, result)
writer({"type": "task_timed_out", "task_id": task_id})
usage = _summarize_usage(getattr(result, "token_usage_records", None))
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
writer({"type": "task_timed_out", "task_id": task_id, "usage": usage})
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
except asyncio.CancelledError:
# Signal the background subagent thread to stop cooperatively.
@ -374,4 +415,8 @@ async def task_tool(
cleanup_background_task(task_id)
else:
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
_subagent_usage_cache.pop(tool_call_id, None)
raise
except Exception:
_subagent_usage_cache.pop(tool_call_id, None)
raise


@ -3,6 +3,7 @@
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
def test_conversation_context_has_user_id():
@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none():
def test_queue_add_stores_user_id():
q = MemoryUpdateQueue()
with patch.object(q, "_reset_timer"):
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
assert len(q._queue) == 1
assert q._queue[0].user_id == "alice"
@ -26,7 +27,7 @@ def test_queue_add_stores_user_id():
def test_queue_process_passes_user_id_to_updater():
q = MemoryUpdateQueue()
with patch.object(q, "_reset_timer"):
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
mock_updater = MagicMock()


@ -59,12 +59,15 @@ def _make_result(
ai_messages: list[dict] | None = None,
result: str | None = None,
error: str | None = None,
token_usage_records: list[dict] | None = None,
) -> SimpleNamespace:
return SimpleNamespace(
status=status,
ai_messages=ai_messages or [],
result=result,
error=error,
token_usage_records=token_usage_records or [],
usage_reported=False,
)
@ -1132,3 +1135,153 @@ def test_cancellation_reports_subagent_usage(monkeypatch):
assert len(report_calls) == 1
assert report_calls[0][1] is cancel_result
assert cleanup_calls == ["tc-cancel-report"]
@pytest.mark.parametrize(
"status, expected_type",
[
(FakeSubagentStatus.COMPLETED, "task_completed"),
(FakeSubagentStatus.FAILED, "task_failed"),
(FakeSubagentStatus.CANCELLED, "task_cancelled"),
(FakeSubagentStatus.TIMED_OUT, "task_timed_out"),
],
)
def test_terminal_events_include_usage(monkeypatch, status, expected_type):
"""Terminal task events include a usage summary from token_usage_records."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
records = [
{"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
{"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280},
]
result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records)
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-usage",
)
terminal_events = [e for e in events if e["type"] == expected_type]
assert len(terminal_events) == 1
assert terminal_events[0]["usage"] == {
"input_tokens": 300,
"output_tokens": 130,
"total_tokens": 430,
}
def test_terminal_event_usage_none_when_no_records(monkeypatch):
"""Terminal event has usage=None when token_usage_records is empty."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[])
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-no-records",
)
completed = [e for e in events if e["type"] == "task_completed"]
assert len(completed) == 1
assert completed[0]["usage"] is None
def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
monkeypatch.setattr(
task_tool_module,
"get_app_config",
MagicMock(side_effect=FileNotFoundError("missing config")),
)
assert task_tool_module._token_usage_cache_enabled(None) is False
def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False))
runtime = _make_runtime(app_config=app_config)
records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}]
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records)
task_tool_module._subagent_usage_cache.clear()
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-disabled-cache",
)
assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None
def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True))
runtime = _make_runtime(app_config=app_config)
task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed")))
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
with pytest.raises(RuntimeError, match="poll failed"):
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-error",
)
assert task_tool_module.pop_cached_subagent_usage("tc-error") is None


@ -1,9 +1,10 @@
"""Tests for TokenUsageMiddleware attribution annotations."""
import importlib
import logging
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, ToolMessage
from deerflow.agents.middlewares.token_usage_middleware import (
TOKEN_USAGE_ATTRIBUTION_KEY,
@ -232,3 +233,49 @@ class TestTokenUsageMiddleware:
"tool_call_id": "write_todos:remove",
}
]
def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
middleware = TokenUsageMiddleware()
first_dispatch = AIMessage(
content="",
tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
)
second_dispatch = AIMessage(
content="",
tool_calls=[
{"id": "task:second-a", "name": "task", "args": {}},
{"id": "task:second-b", "name": "task", "args": {}},
],
)
messages = [
first_dispatch,
ToolMessage(content="first", tool_call_id="task:first"),
second_dispatch,
ToolMessage(content="second-a", tool_call_id="task:second-a"),
ToolMessage(content="second-b", tool_call_id="task:second-b"),
AIMessage(content="done"),
]
cached_usage = {
"task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
"task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
}
task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
monkeypatch.setattr(
task_tool_module,
"pop_cached_subagent_usage",
lambda tool_call_id: cached_usage.pop(tool_call_id, None),
)
result = middleware.after_model({"messages": messages}, _make_runtime())
assert result is not None
usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)]
assert len(usage_updates) == 1
updated = usage_updates[0]
assert updated.tool_calls == second_dispatch.tool_calls
assert updated.usage_metadata == {
"input_tokens": 30,
"output_tokens": 12,
"total_tokens": 42,
}


@ -12,13 +12,11 @@ function TokenUsageSummary({
inputTokens,
outputTokens,
totalTokens,
unavailable = false,
}: {
className?: string;
inputTokens?: number;
outputTokens?: number;
totalTokens?: number;
unavailable?: boolean;
}) {
const { t } = useI18n();
@ -33,21 +31,15 @@ function TokenUsageSummary({
<CoinsIcon className="size-3" />
{t.tokenUsage.label}
</span>
{!unavailable ? (
<>
<span>
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
</span>
<span>
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
</span>
<span className="font-medium">
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
</span>
</>
) : (
<span>{t.tokenUsage.unavailableShort}</span>
)}
<span>
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
</span>
<span>
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
</span>
<span className="font-medium">
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
</span>
</div>
);
}
@ -55,7 +47,7 @@ function TokenUsageSummary({
export function MessageTokenUsageList({
className,
enabled = false,
isLoading = false,
isLoading: _isLoading = false,
messages,
}: {
className?: string;
@ -63,7 +55,7 @@ export function MessageTokenUsageList({
isLoading?: boolean;
messages: Message[];
}) {
if (!enabled || isLoading) {
if (!enabled) {
return null;
}
@ -75,13 +67,16 @@ export function MessageTokenUsageList({
const usage = accumulateUsage(aiMessages);
if (!usage) {
return null;
}
return (
<TokenUsageSummary
className={className}
inputTokens={usage?.inputTokens}
outputTokens={usage?.outputTokens}
totalTokens={usage?.totalTokens}
unavailable={!usage}
inputTokens={usage.inputTokens}
outputTokens={usage.outputTokens}
totalTokens={usage.totalTokens}
/>
);
}


@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null {
return hasUsage ? cumulative : null;
}
function hasNonZeroUsage(
export function hasNonZeroUsage(
usage: TokenUsage | null | undefined,
): usage is TokenUsage {
return (
@ -75,7 +75,7 @@ function hasNonZeroUsage(
);
}
function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
return {
inputTokens: base.inputTokens + delta.inputTokens,
outputTokens: base.outputTokens + delta.outputTokens,


@ -296,7 +296,11 @@ export function useThreadStream({
onError(error) {
setOptimisticMessages([]);
toast.error(getStreamErrorMessage(error));
pendingUsageBaselineMessageIdsRef.current = new Set();
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
@ -305,7 +309,11 @@ export function useThreadStream({
},
onFinish(state) {
listeners.current.onFinish?.(state.values);
pendingUsageBaselineMessageIdsRef.current = new Set();
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({
@ -339,7 +347,11 @@ export function useThreadStream({
useEffect(() => {
startedRef.current = false;
sendInFlightRef.current = false;
pendingUsageBaselineMessageIdsRef.current = new Set();
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
prevHumanMsgCountRef.current =
latestMessageCountsRef.current.humanMessageCount;
}, [threadId]);