mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-14 04:33:45 +00:00
feat: stream subagent token usage to header via terminal task events (#2882)
* feat: real-time subagent token usage display in header and per-turn

  Backend:
  - Persist subagent token usage to AIMessage.usage_metadata via TokenUsageMiddleware, so accumulateUsage() naturally includes subagent tokens without frontend state management
  - Cache subagent usage by tool_call_id in task_tool, write back to the dispatching AIMessage on next model response
  - Emit subagent token usage on all terminal task events (task_completed, task_failed, task_cancelled, task_timed_out)
  - Report subagent usage to parent RunJournal for API totals
  - Search backward from ToolMessage to find dispatching AIMessage for correct multi-tool-call attribution

  Frontend:
  - Remove subagentUsage state, custom event handling, and prop threading — subagent tokens are now embedded in message metadata
  - Simplify selectHeaderTokenUsage (no subagentUsage parameter)
  - Per-turn inline badges show turn-specific usage via message accumulation
  - Remove isLoading guard from MessageTokenUsageList for dynamic updates during streaming

* fix: prevent header token double counting from baseline reset race

  onFinish, onError, and the thread-switch useEffect all reset pendingUsageBaselineMessageIdsRef to an empty Set. If thread.isLoading is still true on the next render, all messages pass the getMessagesAfterBaseline filter and their tokens are added to backendUsage (which already includes them), causing the header to display up to 2× the actual token count.

  Capture current message IDs instead of using an empty Set so that getMessagesAfterBaseline correctly returns no pending messages even if thread.isLoading lags behind the stream end.

* fix: write back subagent tokens for all concurrent task tool calls

  TokenUsageMiddleware only processed messages[-2], so when a single model response dispatched multiple task tool calls, only the last ToolMessage had its cached subagent usage written back to the dispatching AIMessage.usage_metadata. Earlier tasks' usage stayed in _subagent_usage_cache indefinitely (a leak) and never appeared in the per-turn inline token display.

  Walk backward through all consecutive ToolMessages before the new AIMessage, and accumulate updates targeting the same dispatch message into one state update so overlapping writes don't clobber each other.

* fix: clean up subagent usage cache entry on task cancellation

  When a task_tool invocation is cancelled via CancelledError, any cached subagent usage entry leaked because the TokenUsageMiddleware writeback path never fires after cancellation. Pop the cache entry before re-raising to prevent unbounded growth of the module-level _subagent_usage_cache dict.

* fix: address token usage review feedback

* fix: handle missing config for subagent usage cache

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
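The writeback described above comes down to a token-count merge onto the dispatching message's existing usage. A minimal sketch of that merge (hypothetical helper name, plain dicts rather than the LangChain message classes):

```python
def merge_usage(prev: dict, subagent: dict) -> dict:
    # Add subagent token counts onto the dispatching message's existing
    # usage_metadata, preserving any other keys already present.
    return {
        **prev,
        "input_tokens": prev.get("input_tokens", 0) + subagent["input_tokens"],
        "output_tokens": prev.get("output_tokens", 0) + subagent["output_tokens"],
        "total_tokens": prev.get("total_tokens", 0) + subagent["total_tokens"],
    }

# Two concurrent task calls dispatched by the same AIMessage accumulate:
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
usage = merge_usage(usage, {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150})
usage = merge_usage(usage, {"input_tokens": 200, "output_tokens": 80, "total_tokens": 280})
```

Accumulating both updates before emitting a single state update is what keeps overlapping writes from clobbering each other.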
This commit is contained in:
parent
f1a0ab699a
commit
eab7ae3d62
@@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl
 
 Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
 
-The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output.
+The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step.
 
 This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
@@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
 
 8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
 9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
 10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
-11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
+11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
 12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
 13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
 14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
@@ -9,7 +9,7 @@ from typing import Any, override
 
 from langchain.agents import AgentState
 from langchain.agents.middleware import AgentMiddleware
 from langchain.agents.middleware.todo import Todo
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, ToolMessage
 from langgraph.runtime import Runtime
 
 logger = logging.getLogger(__name__)
@@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
     return "thinking"
 
 
+def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
+    """Return True if the AIMessage contains a tool_call with the given id."""
+    for tc in message.tool_calls or []:
+        if isinstance(tc, dict):
+            if tc.get("id") == tool_call_id:
+                return True
+        elif hasattr(tc, "id") and tc.id == tool_call_id:
+            return True
+    return False
+
+
 def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
     tool_calls = getattr(message, "tool_calls", None) or []
     actions: list[dict[str, Any]] = []
@@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware):
         if not messages:
             return None
 
+        # Annotate subagent token usage onto the AIMessage that dispatched it.
+        # When a task tool completes, its usage is cached by tool_call_id. Detect
+        # the ToolMessage → search backward for the corresponding AIMessage → merge.
+        # Walk backward through consecutive ToolMessages before the new AIMessage
+        # so that multiple concurrent task tool calls all get their subagent tokens
+        # written back to the same dispatch message (merging into one update).
+        state_updates: dict[int, AIMessage] = {}
+        if len(messages) >= 2:
+            from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
+
+            idx = len(messages) - 2
+            while idx >= 0:
+                tool_msg = messages[idx]
+                if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
+                    break
+
+                subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
+                if subagent_usage:
+                    # Search backward from the ToolMessage to find the AIMessage
+                    # that dispatched it. A single model response can dispatch
+                    # multiple task tool calls, so we can't assume a fixed offset.
+                    dispatch_idx = idx - 1
+                    while dispatch_idx >= 0:
+                        candidate = messages[dispatch_idx]
+                        if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
+                            # Accumulate into an existing update for the same
+                            # AIMessage (multiple task calls in one response),
+                            # or merge fresh from the original message.
+                            existing_update = state_updates.get(dispatch_idx)
+                            prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
+                            merged = {
+                                **prev,
+                                "input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
+                                "output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
+                                "total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
+                            }
+                            state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
+                            break
+                        dispatch_idx -= 1
+                idx -= 1
+
         last = messages[-1]
         if not isinstance(last, AIMessage):
+            if state_updates:
+                return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
             return None
 
         usage = getattr(last, "usage_metadata", None)
@@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware):
         additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
 
         if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
-            return None
+            return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
 
         additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
         updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
-        return {"messages": [updated_msg]}
+        state_updates[len(messages) - 1] = updated_msg
+        return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
 
     @override
     def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
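The backward walk in the hunk above can be illustrated on plain tuples (a simplified model, not the LangChain message classes): scan the consecutive tool results sitting just before the newest model message, then search backward for the dispatcher that owns each tool_call_id.

```python
# Simplified message model: ("ai", {tool_call_ids}) or ("tool", tool_call_id).
def attribute(messages: list[tuple]) -> dict[int, list[str]]:
    """Map dispatcher index -> tool_call_ids attributed to it."""
    out: dict[int, list[str]] = {}
    idx = len(messages) - 2  # skip the newest message (the fresh AI response)
    while idx >= 0:
        kind, payload = messages[idx]
        if kind != "tool":
            break  # only the consecutive run of tool results is considered
        # Search backward for the AI message that issued this tool call;
        # a fixed offset would be wrong when one response dispatches many calls.
        for d in range(idx - 1, -1, -1):
            d_kind, d_payload = messages[d]
            if d_kind == "ai" and payload in d_payload:
                out.setdefault(d, []).append(payload)
                break
        idx -= 1
    return out

msgs = [
    ("ai", {"a"}), ("tool", "a"),
    ("ai", {"b", "c"}), ("tool", "b"), ("tool", "c"),
    ("ai", set()),  # newest AI response
]
```

Note that both of the second dispatcher's tool results resolve to index 2, which is why the real middleware accumulates them into a single state update.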
@@ -26,6 +26,28 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can
+# write it back to the triggering AIMessage's usage_metadata.
+_subagent_usage_cache: dict[str, dict[str, int]] = {}
+
+
+def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
+    if app_config is None:
+        try:
+            app_config = get_app_config()
+        except FileNotFoundError:
+            return False
+    return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
+
+
+def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None:
+    if enabled and usage:
+        _subagent_usage_cache[tool_call_id] = usage
+
+
+def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
+    return _subagent_usage_cache.pop(tool_call_id, None)
+
+
 def _is_subagent_terminal(result: Any) -> bool:
     """Return whether a background subagent result is safe to clean up."""
@@ -92,6 +114,17 @@ def _find_usage_recorder(runtime: Any) -> Any | None:
     return None
 
 
+def _summarize_usage(records: list[dict] | None) -> dict | None:
+    """Summarize token usage records into a compact dict for SSE events."""
+    if not records:
+        return None
+    return {
+        "input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
+        "output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
+        "total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
+    }
+
+
 def _report_subagent_usage(runtime: Any, result: Any) -> None:
     """Report subagent token usage to the parent RunJournal, if available.
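The summarizer above is a straight per-key sum with None-safety (the `or 0` absorbs explicit `None` values, and empty input yields `None` rather than a zero summary). A self-contained sketch of the same logic, with made-up record values:

```python
def summarize(records):
    # Mirrors the shape of _summarize_usage: sum each counter, treating
    # missing or None values as zero; empty input yields None.
    if not records:
        return None
    return {
        "input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
        "output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
        "total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
    }

records = [
    {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
    {"input_tokens": 200, "output_tokens": None, "total_tokens": 280},
]
```

Returning `None` for empty input is what lets terminal events carry `usage: null` instead of a misleading all-zero summary.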
@@ -177,6 +210,7 @@ async def task_tool(
         subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
     """
     runtime_app_config = _get_runtime_app_config(runtime)
+    cache_token_usage = _token_usage_cache_enabled(runtime_app_config)
     available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
 
     # Get subagent configuration
@@ -312,27 +346,32 @@ async def task_tool(
             last_message_count = current_message_count
 
             # Check if task completed, failed, or timed out
+            usage = _summarize_usage(getattr(result, "token_usage_records", None))
             if result.status == SubagentStatus.COMPLETED:
+                _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
                 _report_subagent_usage(runtime, result)
-                writer({"type": "task_completed", "task_id": task_id, "result": result.result})
+                writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage})
                 logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
                 cleanup_background_task(task_id)
                 return f"Task Succeeded. Result: {result.result}"
             elif result.status == SubagentStatus.FAILED:
+                _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
                 _report_subagent_usage(runtime, result)
-                writer({"type": "task_failed", "task_id": task_id, "error": result.error})
+                writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage})
                 logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
                 cleanup_background_task(task_id)
                 return f"Task failed. Error: {result.error}"
             elif result.status == SubagentStatus.CANCELLED:
+                _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
                 _report_subagent_usage(runtime, result)
-                writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
+                writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage})
                 logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
                 cleanup_background_task(task_id)
                 return "Task cancelled by user."
             elif result.status == SubagentStatus.TIMED_OUT:
+                _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
                 _report_subagent_usage(runtime, result)
-                writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
+                writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage})
                 logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
                 cleanup_background_task(task_id)
                 return f"Task timed out. Error: {result.error}"
@@ -351,7 +390,9 @@ async def task_tool(
         timeout_minutes = config.timeout_seconds // 60
         logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
         _report_subagent_usage(runtime, result)
-        writer({"type": "task_timed_out", "task_id": task_id})
+        usage = _summarize_usage(getattr(result, "token_usage_records", None))
+        _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
+        writer({"type": "task_timed_out", "task_id": task_id, "usage": usage})
         return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
     except asyncio.CancelledError:
         # Signal the background subagent thread to stop cooperatively.
@@ -374,4 +415,8 @@ async def task_tool(
             cleanup_background_task(task_id)
         else:
             _schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
+        _subagent_usage_cache.pop(tool_call_id, None)
         raise
+    except Exception:
+        _subagent_usage_cache.pop(tool_call_id, None)
+        raise
@@ -3,6 +3,7 @@
 
 from unittest.mock import MagicMock, patch
 
 from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
+from deerflow.config.memory_config import MemoryConfig
 
 
 def test_conversation_context_has_user_id():
@@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none():
 
 def test_queue_add_stores_user_id():
     q = MemoryUpdateQueue()
-    with patch.object(q, "_reset_timer"):
+    with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
         q.add(thread_id="t1", messages=["msg"], user_id="alice")
     assert len(q._queue) == 1
     assert q._queue[0].user_id == "alice"
@@ -26,7 +27,7 @@ def test_queue_add_stores_user_id():
 
 def test_queue_process_passes_user_id_to_updater():
     q = MemoryUpdateQueue()
-    with patch.object(q, "_reset_timer"):
+    with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
         q.add(thread_id="t1", messages=["msg"], user_id="alice")
 
     mock_updater = MagicMock()
@@ -59,12 +59,15 @@ def _make_result(
     ai_messages: list[dict] | None = None,
     result: str | None = None,
     error: str | None = None,
+    token_usage_records: list[dict] | None = None,
 ) -> SimpleNamespace:
     return SimpleNamespace(
         status=status,
         ai_messages=ai_messages or [],
         result=result,
         error=error,
+        token_usage_records=token_usage_records or [],
+        usage_reported=False,
     )
@@ -1132,3 +1135,153 @@ def test_cancellation_reports_subagent_usage(monkeypatch):
     assert len(report_calls) == 1
     assert report_calls[0][1] is cancel_result
     assert cleanup_calls == ["tc-cancel-report"]
+
+
+@pytest.mark.parametrize(
+    "status, expected_type",
+    [
+        (FakeSubagentStatus.COMPLETED, "task_completed"),
+        (FakeSubagentStatus.FAILED, "task_failed"),
+        (FakeSubagentStatus.CANCELLED, "task_cancelled"),
+        (FakeSubagentStatus.TIMED_OUT, "task_timed_out"),
+    ],
+)
+def test_terminal_events_include_usage(monkeypatch, status, expected_type):
+    """Terminal task events include a usage summary from token_usage_records."""
+    config = _make_subagent_config()
+    runtime = _make_runtime()
+    events = []
+
+    records = [
+        {"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
+        {"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280},
+    ]
+    result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records)
+
+    monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
+    monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
+    monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
+    monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
+    monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
+    monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
+    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
+
+    _run_task_tool(
+        runtime=runtime,
+        description="test",
+        prompt="do work",
+        subagent_type="general-purpose",
+        tool_call_id="tc-usage",
+    )
+
+    terminal_events = [e for e in events if e["type"] == expected_type]
+    assert len(terminal_events) == 1
+    assert terminal_events[0]["usage"] == {
+        "input_tokens": 300,
+        "output_tokens": 130,
+        "total_tokens": 430,
+    }
+
+
+def test_terminal_event_usage_none_when_no_records(monkeypatch):
+    """Terminal event has usage=None when token_usage_records is empty."""
+    config = _make_subagent_config()
+    runtime = _make_runtime()
+    events = []
+
+    result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[])
+
+    monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
+    monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
+    monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
+    monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
+    monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
+    monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
+    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
+
+    _run_task_tool(
+        runtime=runtime,
+        description="test",
+        prompt="do work",
+        subagent_type="general-purpose",
+        tool_call_id="tc-no-records",
+    )
+
+    completed = [e for e in events if e["type"] == "task_completed"]
+    assert len(completed) == 1
+    assert completed[0]["usage"] is None
+
+
+def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
+    monkeypatch.setattr(
+        task_tool_module,
+        "get_app_config",
+        MagicMock(side_effect=FileNotFoundError("missing config")),
+    )
+
+    assert task_tool_module._token_usage_cache_enabled(None) is False
+
+
+def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch):
+    config = _make_subagent_config()
+    app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False))
+    runtime = _make_runtime(app_config=app_config)
+    records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}]
+    result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records)
+
+    task_tool_module._subagent_usage_cache.clear()
+    monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
+    monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
+    monkeypatch.setattr(
+        task_tool_module,
+        "SubagentExecutor",
+        type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
+    )
+    monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
+    monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
+    monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
+    monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
+    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
+
+    _run_task_tool(
+        runtime=runtime,
+        description="test",
+        prompt="do work",
+        subagent_type="general-purpose",
+        tool_call_id="tc-disabled-cache",
+    )
+
+    assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None
+
+
+def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch):
+    config = _make_subagent_config()
+    app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True))
+    runtime = _make_runtime(app_config=app_config)
+
+    task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
+    monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
+    monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
+    monkeypatch.setattr(
+        task_tool_module,
+        "SubagentExecutor",
+        type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
+    )
+    monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed")))
+    monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
+    monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
+
+    with pytest.raises(RuntimeError, match="poll failed"):
+        _run_task_tool(
+            runtime=runtime,
+            description="test",
+            prompt="do work",
+            subagent_type="general-purpose",
+            tool_call_id="tc-error",
+        )
+
+    assert task_tool_module.pop_cached_subagent_usage("tc-error") is None
@@ -1,9 +1,10 @@
 """Tests for TokenUsageMiddleware attribution annotations."""
 
+import importlib
 import logging
 from unittest.mock import MagicMock
 
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, ToolMessage
 
 from deerflow.agents.middlewares.token_usage_middleware import (
     TOKEN_USAGE_ATTRIBUTION_KEY,
@@ -232,3 +233,49 @@ class TestTokenUsageMiddleware:
                 "tool_call_id": "write_todos:remove",
             }
         ]
+
+    def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
+        middleware = TokenUsageMiddleware()
+        first_dispatch = AIMessage(
+            content="",
+            tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
+        )
+        second_dispatch = AIMessage(
+            content="",
+            tool_calls=[
+                {"id": "task:second-a", "name": "task", "args": {}},
+                {"id": "task:second-b", "name": "task", "args": {}},
+            ],
+        )
+        messages = [
+            first_dispatch,
+            ToolMessage(content="first", tool_call_id="task:first"),
+            second_dispatch,
+            ToolMessage(content="second-a", tool_call_id="task:second-a"),
+            ToolMessage(content="second-b", tool_call_id="task:second-b"),
+            AIMessage(content="done"),
+        ]
+        cached_usage = {
+            "task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
+            "task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
+        }
+
+        task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
+        monkeypatch.setattr(
+            task_tool_module,
+            "pop_cached_subagent_usage",
+            lambda tool_call_id: cached_usage.pop(tool_call_id, None),
+        )
+
+        result = middleware.after_model({"messages": messages}, _make_runtime())
+
+        assert result is not None
+        usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)]
+        assert len(usage_updates) == 1
+        updated = usage_updates[0]
+        assert updated.tool_calls == second_dispatch.tool_calls
+        assert updated.usage_metadata == {
+            "input_tokens": 30,
+            "output_tokens": 12,
+            "total_tokens": 42,
+        }
@@ -12,13 +12,11 @@ function TokenUsageSummary({
   inputTokens,
   outputTokens,
   totalTokens,
-  unavailable = false,
 }: {
   className?: string;
   inputTokens?: number;
   outputTokens?: number;
   totalTokens?: number;
-  unavailable?: boolean;
 }) {
   const { t } = useI18n();
 
@@ -33,21 +31,15 @@ function TokenUsageSummary({
         <CoinsIcon className="size-3" />
         {t.tokenUsage.label}
       </span>
-      {!unavailable ? (
-        <>
-          <span>
-            {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
-          </span>
-          <span>
-            {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
-          </span>
-          <span className="font-medium">
-            {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
-          </span>
-        </>
-      ) : (
-        <span>{t.tokenUsage.unavailableShort}</span>
-      )}
+      <span>
+        {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
+      </span>
+      <span>
+        {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
+      </span>
+      <span className="font-medium">
+        {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
+      </span>
     </div>
   );
 }
@@ -55,7 +47,7 @@ function TokenUsageSummary({
 export function MessageTokenUsageList({
   className,
   enabled = false,
-  isLoading = false,
+  isLoading: _isLoading = false,
   messages,
 }: {
   className?: string;
@@ -63,7 +55,7 @@ export function MessageTokenUsageList({
   isLoading?: boolean;
   messages: Message[];
 }) {
-  if (!enabled || isLoading) {
+  if (!enabled) {
     return null;
   }
@@ -75,13 +67,16 @@ export function MessageTokenUsageList({
 
   const usage = accumulateUsage(aiMessages);
 
+  if (!usage) {
+    return null;
+  }
+
   return (
     <TokenUsageSummary
       className={className}
-      inputTokens={usage?.inputTokens}
-      outputTokens={usage?.outputTokens}
-      totalTokens={usage?.totalTokens}
-      unavailable={!usage}
+      inputTokens={usage.inputTokens}
+      outputTokens={usage.outputTokens}
+      totalTokens={usage.totalTokens}
     />
   );
 }
@@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null {
   return hasUsage ? cumulative : null;
 }
 
-function hasNonZeroUsage(
+export function hasNonZeroUsage(
   usage: TokenUsage | null | undefined,
 ): usage is TokenUsage {
   return (
@@ -75,7 +75,7 @@ function hasNonZeroUsage(
   );
 }
 
-function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
+export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
   return {
     inputTokens: base.inputTokens + delta.inputTokens,
     outputTokens: base.outputTokens + delta.outputTokens,
@@ -296,7 +296,11 @@ export function useThreadStream({
     onError(error) {
       setOptimisticMessages([]);
       toast.error(getStreamErrorMessage(error));
-      pendingUsageBaselineMessageIdsRef.current = new Set();
+      pendingUsageBaselineMessageIdsRef.current = new Set(
+        messagesRef.current
+          .map(messageIdentity)
+          .filter((id): id is string => Boolean(id)),
+      );
       if (threadIdRef.current && !isMock) {
         void queryClient.invalidateQueries({
           queryKey: threadTokenUsageQueryKey(threadIdRef.current),
@@ -305,7 +309,11 @@ export function useThreadStream({
     },
     onFinish(state) {
       listeners.current.onFinish?.(state.values);
-      pendingUsageBaselineMessageIdsRef.current = new Set();
+      pendingUsageBaselineMessageIdsRef.current = new Set(
+        messagesRef.current
+          .map(messageIdentity)
+          .filter((id): id is string => Boolean(id)),
+      );
       void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
       if (threadIdRef.current && !isMock) {
         void queryClient.invalidateQueries({
@@ -339,7 +347,11 @@ export function useThreadStream({
   useEffect(() => {
     startedRef.current = false;
     sendInFlightRef.current = false;
-    pendingUsageBaselineMessageIdsRef.current = new Set();
+    pendingUsageBaselineMessageIdsRef.current = new Set(
+      messagesRef.current
+        .map(messageIdentity)
+        .filter((id): id is string => Boolean(id)),
+    );
     prevHumanMsgCountRef.current =
       latestMessageCountsRef.current.humanMessageCount;
   }, [threadId]);
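The double-counting fix in the hunks above is a set-baseline trick: only messages not in the baseline contribute tokens on top of the backend total. The arithmetic can be sketched language-independently (a simplified Python model, not the React hook):

```python
def pending_usage(all_ids: set[str], baseline: set[str]) -> set[str]:
    # Messages after the baseline are the only ones whose tokens may be
    # added on top of the backend-reported total.
    return all_ids - baseline

messages = {"m1", "m2", "m3"}

# Bug: resetting the baseline to an empty set while the stream still looks
# live makes every message "pending", so its tokens are counted twice.
assert pending_usage(messages, set()) == messages

# Fix: capture the current message ids as the new baseline; nothing is
# pending even if the loading flag lags behind the stream end.
assert pending_usage(messages, set(messages)) == set()
```

Capturing the ids at reset time makes the filter correct regardless of how late `isLoading` flips.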