diff --git a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py index 59c3423d4..3fd6d6132 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py @@ -1,31 +1,270 @@ -"""Middleware for logging LLM token usage.""" +"""Middleware for logging token usage and annotating step attribution.""" + +from __future__ import annotations import logging -from typing import override +from collections import defaultdict +from typing import Any, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware +from langchain.agents.middleware.todo import Todo +from langchain_core.messages import AIMessage from langgraph.runtime import Runtime logger = logging.getLogger(__name__) +TOKEN_USAGE_ATTRIBUTION_KEY = "token_usage_attribution" + + +def _string_arg(value: Any) -> str | None: + if isinstance(value, str): + normalized = value.strip() + return normalized or None + return None + + +def _normalize_todos(value: Any) -> list[Todo]: + if not isinstance(value, list): + return [] + + normalized: list[Todo] = [] + for item in value: + if not isinstance(item, dict): + continue + + todo: Todo = {} + content = _string_arg(item.get("content")) + status = item.get("status") + + if content is not None: + todo["content"] = content + if status in {"pending", "in_progress", "completed"}: + todo["status"] = status + + normalized.append(todo) + + return normalized + + +def _todo_action_kind(previous: Todo | None, current: Todo) -> str: + status = current.get("status") + previous_content = previous.get("content") if previous else None + current_content = current.get("content") + + if previous is None: + if status == "completed": + return "todo_complete" + if status == "in_progress": + return "todo_start" + return "todo_update" + + if previous_content != current_content: + return "todo_update" + + if status == "completed": + return "todo_complete" + if status == "in_progress": + return "todo_start" + return "todo_update" + + +def _build_todo_actions(previous_todos: list[Todo], next_todos: list[Todo]) -> list[dict[str, Any]]: + # This is the single source of truth for precise write_todos token + # attribution. The frontend intentionally falls back to a generic + # "Update to-do list" label when this metadata is missing or malformed. 
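+    # Worked example (hypothetical todos, purely illustrative):
+    #   previous: [{"content": "A", "status": "in_progress"}]
+    #   next:     [{"content": "A", "status": "completed"},
+    #              {"content": "B", "status": "pending"}]
+    # produces:
+    #   [{"kind": "todo_complete", "content": "A"},
+    #    {"kind": "todo_update", "content": "B"}]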
+ previous_by_content: dict[str, list[tuple[int, Todo]]] = defaultdict(list) + matched_previous_indices: set[int] = set() + + for index, todo in enumerate(previous_todos): + content = todo.get("content") + if isinstance(content, str) and content: + previous_by_content[content].append((index, todo)) + + actions: list[dict[str, Any]] = [] + + for index, todo in enumerate(next_todos): + content = todo.get("content") + if not isinstance(content, str) or not content: + continue + + previous_match: Todo | None = None + content_matches = previous_by_content.get(content) + if content_matches: + while content_matches and content_matches[0][0] in matched_previous_indices: + content_matches.pop(0) + if content_matches: + previous_index, previous_match = content_matches.pop(0) + matched_previous_indices.add(previous_index) + + if previous_match is None and index < len(previous_todos) and index not in matched_previous_indices: + previous_match = previous_todos[index] + matched_previous_indices.add(index) + + if previous_match is not None: + previous_content = previous_match.get("content") + previous_status = previous_match.get("status") + if previous_content == content and previous_status == todo.get("status"): + continue + + actions.append( + { + "kind": _todo_action_kind(previous_match, todo), + "content": content, + } + ) + + for index, todo in enumerate(previous_todos): + if index in matched_previous_indices: + continue + + content = todo.get("content") + if not isinstance(content, str) or not content: + continue + + actions.append( + { + "kind": "todo_remove", + "content": content, + } + ) + + return actions + + +def _describe_tool_call(tool_call: dict[str, Any], todos: list[Todo]) -> list[dict[str, Any]]: + name = _string_arg(tool_call.get("name")) or "unknown" + args = tool_call.get("args") if isinstance(tool_call.get("args"), dict) else {} + tool_call_id = _string_arg(tool_call.get("id")) + + if name == "write_todos": + next_todos = _normalize_todos(args.get("todos")) + actions = _build_todo_actions(todos, next_todos) + if not actions: + return [ + { + "kind": "tool", + "tool_name": name, + "tool_call_id": tool_call_id, + } + ] + return [ + { + **action, + "tool_call_id": tool_call_id, + } + for action in actions + ] + + if name == "task": + return [ + { + "kind": "subagent", + "description": _string_arg(args.get("description")), + "subagent_type": _string_arg(args.get("subagent_type")), + "tool_call_id": tool_call_id, + } + ] + + if name in {"web_search", "image_search"}: + query = _string_arg(args.get("query")) + return [ + { + "kind": "search", + "tool_name": name, + "query": query, + "tool_call_id": tool_call_id, + } + ] + + if name == "present_files": + return [ + { + "kind": "present_files", + "tool_call_id": tool_call_id, + } + ] + + if name == "ask_clarification": + return [ + { + "kind": "clarification", + "tool_call_id": tool_call_id, + } + ] + + return [ + { + "kind": "tool", + "tool_name": name, + "description": _string_arg(args.get("description")), + "tool_call_id": tool_call_id, + } + ] + + +def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str: + if actions: + first_kind = actions[0].get("kind") + if len(actions) == 1 and first_kind in {"todo_start", "todo_complete", "todo_update", "todo_remove"}: + return "todo_update" + if len(actions) == 1 and first_kind == "subagent": + return "subagent_dispatch" + return "tool_batch" + + if message.content: + return "final_answer" + return "thinking" + + +def _build_attribution(message: AIMessage, todos: list[Todo]) 
-> dict[str, Any]: + tool_calls = getattr(message, "tool_calls", None) or [] + actions: list[dict[str, Any]] = [] + current_todos = list(todos) + + for raw_tool_call in tool_calls: + if not isinstance(raw_tool_call, dict): + continue + + described_actions = _describe_tool_call(raw_tool_call, current_todos) + actions.extend(described_actions) + + if raw_tool_call.get("name") == "write_todos": + args = raw_tool_call.get("args") if isinstance(raw_tool_call.get("args"), dict) else {} + current_todos = _normalize_todos(args.get("todos")) + + tool_call_ids: list[str] = [] + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + + tool_call_id = _string_arg(tool_call.get("id")) + if tool_call_id is not None: + tool_call_ids.append(tool_call_id) + + return { + # Schema changes should remain additive where possible so older + # frontends can ignore unknown fields and fall back safely. + "version": 1, + "kind": _infer_step_kind(message, actions), + "shared_attribution": len(actions) > 1, + "tool_call_ids": tool_call_ids, + "actions": actions, + } + class TokenUsageMiddleware(AgentMiddleware): - """Logs token usage from model response usage_metadata.""" + """Logs token usage from model responses and annotates the AI step.""" - @override - def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: - return self._log_usage(state) - - @override - async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: - return self._log_usage(state) - - def _log_usage(self, state: AgentState) -> None: + def _apply(self, state: AgentState) -> dict | None: messages = state.get("messages", []) if not messages: return None + last = messages[-1] + if not isinstance(last, AIMessage): + return None + usage = getattr(last, "usage_metadata", None) if usage: logger.info( @@ -34,4 +273,22 @@ class TokenUsageMiddleware(AgentMiddleware): usage.get("output_tokens", "?"), usage.get("total_tokens", "?"), ) - return None + + todos = state.get("todos") or [] + attribution = _build_attribution(last, todos if isinstance(todos, list) else []) + additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {}) + + if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution: + return None + + additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution + updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs}) + return {"messages": [updated_msg]} + + @override + def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: + return self._apply(state) + + @override + async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: + return self._apply(state) diff --git a/backend/packages/harness/deerflow/client.py b/backend/packages/harness/deerflow/client.py index 2ba9302cc..786e7372f 100644 --- a/backend/packages/harness/deerflow/client.py +++ b/backend/packages/harness/deerflow/client.py @@ -264,25 +264,35 @@ class DeerFlowClient: return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls] @staticmethod - def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent": - """Build a ``messages-tuple`` AI text event, attaching usage when present.""" + def _serialize_additional_kwargs(msg) -> dict[str, Any] | None: + """Copy message additional_kwargs when present.""" + additional_kwargs = getattr(msg, "additional_kwargs", None) + if isinstance(additional_kwargs, dict) and additional_kwargs: + return dict(additional_kwargs) + return None + + @staticmethod + def 
_ai_text_event(msg_id: str | None, text: str, usage: dict | None, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent": + """Build a ``messages-tuple`` AI text event.""" data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id} if usage: data["usage_metadata"] = usage + if additional_kwargs: + data["additional_kwargs"] = additional_kwargs return StreamEvent(type="messages-tuple", data=data) @staticmethod - def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent": + def _ai_tool_calls_event(msg_id: str | None, tool_calls, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent": """Build a ``messages-tuple`` AI tool-calls event.""" - return StreamEvent( - type="messages-tuple", - data={ - "type": "ai", - "content": "", - "id": msg_id, - "tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls), - }, - ) + data: dict[str, Any] = { + "type": "ai", + "content": "", + "id": msg_id, + "tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls), + } + if additional_kwargs: + data["additional_kwargs"] = additional_kwargs + return StreamEvent(type="messages-tuple", data=data) @staticmethod def _tool_message_event(msg: ToolMessage) -> "StreamEvent": @@ -307,19 +317,30 @@ class DeerFlowClient: d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls) if getattr(msg, "usage_metadata", None): d["usage_metadata"] = msg.usage_metadata + if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg): + d["additional_kwargs"] = additional_kwargs return d if isinstance(msg, ToolMessage): - return { + d = { "type": "tool", "content": DeerFlowClient._extract_text(msg.content), "name": getattr(msg, "name", None), "tool_call_id": getattr(msg, "tool_call_id", None), "id": getattr(msg, "id", None), } + if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg): + d["additional_kwargs"] = additional_kwargs + return d if isinstance(msg, HumanMessage): - return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)} + d = {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)} + if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg): + d["additional_kwargs"] = additional_kwargs + return d if isinstance(msg, SystemMessage): - return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)} + d = {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)} + if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg): + d["additional_kwargs"] = additional_kwargs + return d return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)} @staticmethod @@ -542,6 +563,7 @@ class DeerFlowClient: - type="messages-tuple" data={"type": "ai", "content": , "id": str} - type="messages-tuple" data={"type": "ai", "content": , "id": str, "usage_metadata": {...}} - type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]} + - type="messages-tuple" data={"type": "ai", "content": "", "id": str, "additional_kwargs": {...}} - type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str} - type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}} """ @@ -564,6 +586,7 @@ class DeerFlowClient: # in both the final ``messages`` chunk and the values snapshot — # count it only on whichever arrives first. 
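+        # Example: if the streamed chunk for message "ai-1" already carried
+        # usage_metadata, the identical usage on the assembled message in the
+        # later values snapshot is skipped, so each message id contributes to
+        # the cumulative totals exactly once.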
counted_usage_ids: set[str] = set() + sent_additional_kwargs_by_id: dict[str, dict[str, Any]] = {} cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} def _account_usage(msg_id: str | None, usage: Any) -> dict | None: @@ -593,6 +616,20 @@ class DeerFlowClient: "total_tokens": total_tokens, } + def _unsent_additional_kwargs(msg_id: str | None, additional_kwargs: dict[str, Any] | None) -> dict[str, Any] | None: + if not additional_kwargs: + return None + if not msg_id: + return additional_kwargs + + sent = sent_additional_kwargs_by_id.setdefault(msg_id, {}) + delta = {key: value for key, value in additional_kwargs.items() if sent.get(key) != value} + if not delta: + return None + + sent.update(delta) + return delta + for item in self._agent.stream( state, config=config, @@ -620,17 +657,31 @@ class DeerFlowClient: if isinstance(msg_chunk, AIMessage): text = self._extract_text(msg_chunk.content) + additional_kwargs = self._serialize_additional_kwargs(msg_chunk) counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata) + sent_additional_kwargs = False if text: if msg_id: streamed_ids.add(msg_id) - yield self._ai_text_event(msg_id, text, counted_usage) + additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs) + yield self._ai_text_event( + msg_id, + text, + counted_usage, + additional_kwargs_delta, + ) + sent_additional_kwargs = bool(additional_kwargs_delta) if msg_chunk.tool_calls: if msg_id: streamed_ids.add(msg_id) - yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls) + additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs) + yield self._ai_tool_calls_event( + msg_id, + msg_chunk.tool_calls, + additional_kwargs_delta, + ) elif isinstance(msg_chunk, ToolMessage): if msg_id: @@ -653,17 +704,45 @@ class DeerFlowClient: if msg_id and msg_id in streamed_ids: if isinstance(msg, AIMessage): _account_usage(msg_id, getattr(msg, "usage_metadata", None)) + additional_kwargs = self._serialize_additional_kwargs(msg) + additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs) + if additional_kwargs_delta: + # Metadata-only follow-up: ``messages-tuple`` has no + # dedicated attribution event, so clients should + # merge this empty-content AI event by message id + # and ignore it for text rendering. + yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta) continue if isinstance(msg, AIMessage): counted_usage = _account_usage(msg_id, msg.usage_metadata) + additional_kwargs = self._serialize_additional_kwargs(msg) + sent_additional_kwargs = False if msg.tool_calls: - yield self._ai_tool_calls_event(msg_id, msg.tool_calls) + additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs) + yield self._ai_tool_calls_event( + msg_id, + msg.tool_calls, + additional_kwargs_delta, + ) + sent_additional_kwargs = bool(additional_kwargs_delta) text = self._extract_text(msg.content) if text: - yield self._ai_text_event(msg_id, text, counted_usage) + additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs) + yield self._ai_text_event( + msg_id, + text, + counted_usage, + additional_kwargs_delta, + ) + elif msg_id: + additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs) + if not additional_kwargs_delta: + continue + # See the metadata-only follow-up convention above. 
+ yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta) elif isinstance(msg, ToolMessage): yield self._tool_message_event(msg) diff --git a/backend/tests/test_client.py b/backend/tests/test_client.py index 8397af163..f0e918a08 100644 --- a/backend/tests/test_client.py +++ b/backend/tests/test_client.py @@ -437,6 +437,85 @@ class TestStream: call_kwargs = agent.stream.call_args.kwargs assert "messages" in call_kwargs["stream_mode"] + def test_stream_emits_additional_kwargs_updates_for_streamed_ai_messages(self, client): + """stream() emits a follow-up AI event when attribution metadata arrives via values.""" + assembled = AIMessage( + content="Hello!", + id="ai-1", + additional_kwargs={ + "token_usage_attribution": { + "version": 1, + "kind": "final_answer", + "shared_attribution": False, + "actions": [], + } + }, + ) + agent = MagicMock() + agent.stream.return_value = iter( + [ + ("messages", (AIMessageChunk(content="Hello!", id="ai-1"), {})), + ("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}), + ] + ) + + with ( + patch.object(client, "_ensure_agent"), + patch.object(client, "_agent", agent), + ): + events = list(client.stream("hi", thread_id="t-stream-kwargs")) + + ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"] + assert any(event.data.get("content") == "Hello!" for event in ai_events) + assert any(event.data.get("additional_kwargs", {}).get("token_usage_attribution", {}).get("kind") == "final_answer" for event in ai_events) + + def test_stream_emits_new_additional_kwargs_after_prior_metadata(self, client): + """stream() emits later attribution metadata even after earlier kwargs for the same id.""" + attribution = { + "version": 1, + "kind": "final_answer", + "shared_attribution": False, + "actions": [], + } + assembled = AIMessage( + content="Hello!", + id="ai-1", + additional_kwargs={ + "reasoning_content": "Thinking first.", + "token_usage_attribution": attribution, + }, + ) + agent = MagicMock() + agent.stream.return_value = iter( + [ + ( + "messages", + ( + AIMessageChunk( + content="Hello!", + id="ai-1", + additional_kwargs={"reasoning_content": "Thinking first."}, + ), + {}, + ), + ), + ("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}), + ] + ) + + with ( + patch.object(client, "_ensure_agent"), + patch.object(client, "_agent", agent), + ): + events = list(client.stream("hi", thread_id="t-stream-kwargs-delta")) + + ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"] + metadata_events = [event for event in ai_events if event.data.get("additional_kwargs")] + + assert metadata_events[0].data["additional_kwargs"] == {"reasoning_content": "Thinking first."} + assert metadata_events[1].data["content"] == "" + assert metadata_events[1].data["additional_kwargs"] == {"token_usage_attribution": attribution} + def test_chat_accumulates_streamed_deltas(self, client): """chat() concatenates per-id deltas from messages mode.""" agent = MagicMock() diff --git a/backend/tests/test_client_message_serialization.py b/backend/tests/test_client_message_serialization.py new file mode 100644 index 000000000..de2e57a32 --- /dev/null +++ b/backend/tests/test_client_message_serialization.py @@ -0,0 +1,53 @@ +"""Tests for DeerFlowClient message serialization helpers.""" + +from langchain_core.messages import AIMessage, HumanMessage + +from deerflow.client 
import DeerFlowClient + + +def test_serialize_ai_message_preserves_additional_kwargs(): + message = AIMessage( + content="done", + additional_kwargs={ + "token_usage_attribution": { + "version": 1, + "kind": "final_answer", + "shared_attribution": False, + "actions": [], + } + }, + usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15}, + ) + + serialized = DeerFlowClient._serialize_message(message) + + assert serialized["type"] == "ai" + assert serialized["usage_metadata"] == { + "input_tokens": 12, + "output_tokens": 3, + "total_tokens": 15, + } + assert serialized["additional_kwargs"] == { + "token_usage_attribution": { + "version": 1, + "kind": "final_answer", + "shared_attribution": False, + "actions": [], + } + } + + +def test_serialize_human_message_preserves_additional_kwargs(): + message = HumanMessage( + content="hello", + additional_kwargs={"files": [{"name": "diagram.png"}]}, + ) + + serialized = DeerFlowClient._serialize_message(message) + + assert serialized == { + "type": "human", + "content": "hello", + "id": None, + "additional_kwargs": {"files": [{"name": "diagram.png"}]}, + } diff --git a/backend/tests/test_token_usage_middleware.py b/backend/tests/test_token_usage_middleware.py index 66a1f2229..c3b1ffc4e 100644 --- a/backend/tests/test_token_usage_middleware.py +++ b/backend/tests/test_token_usage_middleware.py @@ -1,32 +1,157 @@ -from unittest.mock import MagicMock, patch +"""Tests for TokenUsageMiddleware attribution annotations.""" + +from unittest.mock import MagicMock from langchain_core.messages import AIMessage -from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware +from deerflow.agents.middlewares.token_usage_middleware import ( + TOKEN_USAGE_ATTRIBUTION_KEY, + TokenUsageMiddleware, +) -def test_after_model_logs_usage_metadata_counts(): - middleware = TokenUsageMiddleware() - state = { - "messages": [ - AIMessage( - content="done", - usage_metadata={ - "input_tokens": 10, - "output_tokens": 5, - "total_tokens": 15, - }, - ) +def _make_runtime(): + runtime = MagicMock() + runtime.context = {"thread_id": "test-thread"} + return runtime + + +class TestTokenUsageMiddleware: + def test_annotates_todo_updates_with_structured_actions(self): + middleware = TokenUsageMiddleware() + message = AIMessage( + content="", + tool_calls=[ + { + "id": "write_todos:1", + "name": "write_todos", + "args": { + "todos": [ + {"content": "Inspect streaming path", "status": "completed"}, + {"content": "Design token attribution schema", "status": "in_progress"}, + ] + }, + } + ], + usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120}, + ) + + state = { + "messages": [message], + "todos": [ + {"content": "Inspect streaming path", "status": "in_progress"}, + {"content": "Design token attribution schema", "status": "pending"}, + ], + } + + result = middleware.after_model(state, _make_runtime()) + + assert result is not None + updated_message = result["messages"][0] + attribution = updated_message.additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] + assert attribution["kind"] == "tool_batch" + assert attribution["shared_attribution"] is True + assert attribution["tool_call_ids"] == ["write_todos:1"] + assert attribution["actions"] == [ + { + "kind": "todo_complete", + "content": "Inspect streaming path", + "tool_call_id": "write_todos:1", + }, + { + "kind": "todo_start", + "content": "Design token attribution schema", + "tool_call_id": "write_todos:1", + }, ] - } - with 
patch("deerflow.agents.middlewares.token_usage_middleware.logger.info") as info_mock: - result = middleware.after_model(state=state, runtime=MagicMock()) + def test_annotates_subagent_and_search_steps(self): + middleware = TokenUsageMiddleware() + message = AIMessage( + content="", + tool_calls=[ + { + "id": "task:1", + "name": "task", + "args": { + "description": "spec-coder patch message grouping", + "subagent_type": "general-purpose", + }, + }, + { + "id": "web_search:1", + "name": "web_search", + "args": {"query": "LangGraph useStream messages tuple"}, + }, + ], + ) - assert result is None - info_mock.assert_called_once_with( - "LLM token usage: input=%s output=%s total=%s", - 10, - 5, - 15, - ) + result = middleware.after_model({"messages": [message]}, _make_runtime()) + + assert result is not None + attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] + assert attribution["kind"] == "tool_batch" + assert attribution["shared_attribution"] is True + assert attribution["actions"] == [ + { + "kind": "subagent", + "description": "spec-coder patch message grouping", + "subagent_type": "general-purpose", + "tool_call_id": "task:1", + }, + { + "kind": "search", + "tool_name": "web_search", + "query": "LangGraph useStream messages tuple", + "tool_call_id": "web_search:1", + }, + ] + + def test_marks_final_answer_when_no_tools(self): + middleware = TokenUsageMiddleware() + message = AIMessage(content="Here is the final answer.") + + result = middleware.after_model({"messages": [message]}, _make_runtime()) + + assert result is not None + attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] + assert attribution["kind"] == "final_answer" + assert attribution["shared_attribution"] is False + assert attribution["actions"] == [] + + def test_annotates_removed_todos(self): + middleware = TokenUsageMiddleware() + message = AIMessage( + content="", + tool_calls=[ + { + "id": "write_todos:remove", + "name": "write_todos", + "args": { + "todos": [], + }, + } + ], + ) + + result = middleware.after_model( + { + "messages": [message], + "todos": [ + {"content": "Archive obsolete plan", "status": "pending"}, + ], + }, + _make_runtime(), + ) + + assert result is not None + attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] + assert attribution["kind"] == "todo_update" + assert attribution["shared_attribution"] is False + assert attribution["actions"] == [ + { + "kind": "todo_remove", + "content": "Archive obsolete plan", + "tool_call_id": "write_todos:remove", + } + ] diff --git a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx index 198dead3c..7a456a0d3 100644 --- a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx @@ -25,7 +25,7 @@ import { useAgent } from "@/core/agents"; import { useI18n } from "@/core/i18n/hooks"; import { useModels } from "@/core/models/hooks"; import { useNotification } from "@/core/notification/hooks"; -import { useThreadSettings } from "@/core/settings"; +import { useLocalSettings, useThreadSettings } from "@/core/settings"; import { useThreadStream } from "@/core/threads/hooks"; import { textOfMessage } from "@/core/threads/utils"; import { env } from "@/env"; @@ -45,6 +45,7 @@ export default function AgentChatPage() { const { threadId, setThreadId, isNewThread, setIsNewThread } = 
useThreadChat(); const [settings, setSettings] = useThreadSettings(threadId); + const [localSettings, setLocalSettings] = useLocalSettings(); const { tokenUsageEnabled } = useModels(); const { showNotification } = useNotification(); @@ -100,6 +101,9 @@ export default function AgentChatPage() { ? MESSAGE_LIST_DEFAULT_PADDING_BOTTOM + MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM : undefined; + const tokenUsageInlineMode = tokenUsageEnabled + ? localSettings.tokenUsage.inlineMode + : "off"; return ( @@ -139,6 +143,10 @@ export default function AgentChatPage() { + setLocalSettings("tokenUsage", preferences) + } /> @@ -152,10 +160,10 @@ export default function AgentChatPage() { threadId={threadId} thread={thread} paddingBottom={messageListPaddingBottom} - tokenUsageEnabled={tokenUsageEnabled} hasMoreHistory={hasMoreHistory} loadMoreHistory={loadMoreHistory} isHistoryLoading={isHistoryLoading} + tokenUsageInlineMode={tokenUsageInlineMode} /> diff --git a/frontend/src/app/workspace/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/chats/[thread_id]/page.tsx index b7c069ed6..f7b21064b 100644 --- a/frontend/src/app/workspace/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/chats/[thread_id]/page.tsx @@ -24,7 +24,7 @@ import { Welcome } from "@/components/workspace/welcome"; import { useI18n } from "@/core/i18n/hooks"; import { useModels } from "@/core/models/hooks"; import { useNotification } from "@/core/notification/hooks"; -import { useThreadSettings } from "@/core/settings"; +import { useLocalSettings, useThreadSettings } from "@/core/settings"; import { useThreadStream } from "@/core/threads/hooks"; import { textOfMessage } from "@/core/threads/utils"; import { env } from "@/env"; @@ -36,6 +36,7 @@ export default function ChatPage() { const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } = useThreadChat(); const [settings, setSettings] = useThreadSettings(threadId); + const [localSettings, setLocalSettings] = useLocalSettings(); const { tokenUsageEnabled } = useModels(); const mountedRef = useRef(false); useSpecificChatMode(); @@ -99,6 +100,9 @@ export default function ChatPage() { ? MESSAGE_LIST_DEFAULT_PADDING_BOTTOM + MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM : undefined; + const tokenUsageInlineMode = tokenUsageEnabled + ? localSettings.tokenUsage.inlineMode + : "off"; return ( @@ -119,6 +123,10 @@ export default function ChatPage() { + setLocalSettings("tokenUsage", preferences) + } /> @@ -131,10 +139,10 @@ export default function ChatPage() { threadId={threadId} thread={thread} paddingBottom={messageListPaddingBottom} - tokenUsageEnabled={tokenUsageEnabled} hasMoreHistory={hasMoreHistory} loadMoreHistory={loadMoreHistory} isHistoryLoading={isHistoryLoading} + tokenUsageInlineMode={tokenUsageInlineMode} />
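// Sketch of the persisted local-settings shape after this change (assumed to
// match DEFAULT_LOCAL_SETTINGS in frontend/src/core/settings/local.ts below;
// values are illustrative):
//   {
//     notification: { enabled: true },
//     tokenUsage: { headerTotal: true, inlineMode: "per_turn" },
//     context: { ... },
//   }
// Both chat pages force the inline mode to "off" whenever tokenUsageEnabled
// is false, so a stale stored preference cannot re-enable the inline UI.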
diff --git a/frontend/src/components/workspace/messages/message-group.tsx b/frontend/src/components/workspace/messages/message-group.tsx index 0b2d2519c..636930bed 100644 --- a/frontend/src/components/workspace/messages/message-group.tsx +++ b/frontend/src/components/workspace/messages/message-group.tsx @@ -2,6 +2,7 @@ import type { Message } from "@langchain/langgraph-sdk"; import { BookOpenTextIcon, ChevronUp, + CoinsIcon, FolderOpenIcon, GlobeIcon, LightbulbIcon, @@ -24,6 +25,8 @@ import { import { CodeBlock } from "@/components/ai-elements/code-block"; import { Button } from "@/components/ui/button"; import { useI18n } from "@/core/i18n/hooks"; +import { formatTokenCount } from "@/core/messages/usage"; +import type { TokenDebugStep } from "@/core/messages/usage-model"; import { extractReasoningContentFromMessage, findToolCallResult, @@ -43,10 +46,14 @@ export function MessageGroup({ className, messages, isLoading = false, + tokenDebugSteps = [], + showTokenDebugSummaries = false, }: { className?: string; messages: Message[]; isLoading?: boolean; + tokenDebugSteps?: TokenDebugStep[]; + showTokenDebugSummaries?: boolean; }) { const { t } = useI18n(); const [showAbove, setShowAbove] = useState( @@ -56,6 +63,28 @@ export function MessageGroup({ env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true", ); const steps = useMemo(() => convertToSteps(messages), [messages]); + const debugStepByMessageId = useMemo( + () => + new Map( + tokenDebugSteps.map( + (step) => [step.messageId || step.id, step] as const, + ), + ), + [tokenDebugSteps], + ); + const toolCallCountByMessageId = useMemo(() => { + const counts = new Map(); + + for (const step of steps) { + if (step.type !== "toolCall" || !step.messageId) { + continue; + } + + counts.set(step.messageId, (counts.get(step.messageId) ?? 0) + 1); + } + + return counts; + }, [steps]); const lastToolCallStep = useMemo(() => { const filteredSteps = steps.filter((step) => step.type === "toolCall"); return filteredSteps[filteredSteps.length - 1]; @@ -77,6 +106,125 @@ export function MessageGroup({ } }, [lastToolCallStep, steps]); const rehypePlugins = useRehypeSplitWordsIntoSpans(isLoading); + const firstEligibleDebugSummaryStepIndexByMessageId = useMemo(() => { + const firstIndices = new Map(); + + if (!showTokenDebugSummaries) { + return firstIndices; + } + + for (const [index, step] of steps.entries()) { + const messageId = step.messageId; + if (!messageId || firstIndices.has(messageId)) { + continue; + } + + const debugStep = debugStepByMessageId.get(messageId); + if (!debugStep) { + continue; + } + + const toolCallCount = toolCallCountByMessageId.get(messageId) ?? 0; + if (!debugStep.sharedAttribution && toolCallCount > 0) { + continue; + } + if ( + !debugStep.sharedAttribution && + toolCallCount === 0 && + debugStep.label === t.common.thinking && + debugStep.secondaryLabels.length === 0 + ) { + continue; + } + + firstIndices.set(messageId, index); + } + + return firstIndices; + }, [ + debugStepByMessageId, + showTokenDebugSummaries, + steps, + t.common.thinking, + toolCallCountByMessageId, + ]); + + const renderDebugSummary = ( + messageId: string | undefined, + stepIndex: number, + ) => { + if (!showTokenDebugSummaries || !messageId) { + return null; + } + + const debugStep = debugStepByMessageId.get(messageId); + if (!debugStep) { + return null; + } + if ( + firstEligibleDebugSummaryStepIndexByMessageId.get(messageId) !== stepIndex + ) { + return null; + } + + return ( + + } + description={ + debugStep.sharedAttribution + ? 
t.tokenUsage.sharedAttribution + : undefined + } + > + {debugStep.secondaryLabels.length > 0 && ( + + {debugStep.secondaryLabels.map((label, index) => ( + + {label} + + ))} + + )} + + ); + }; + + const renderToolCall = ( + step: CoTToolCallStep, + options?: { isLast?: boolean }, + ) => { + const debugStep = + showTokenDebugSummaries && step.messageId + ? debugStepByMessageId.get(step.messageId) + : undefined; + + return ( + + ); + }; + + const lastReasoningDebugStep = + showTokenDebugSummaries && lastReasoningStep?.messageId + ? debugStepByMessageId.get(lastReasoningStep.messageId) + : undefined; + return ( {showAbove && - aboveLastToolCallSteps.map((step) => - step.type === "reasoning" ? ( - - } - > - ) : ( - - ), - )} + aboveLastToolCallSteps.flatMap((step) => { + const stepIndex = steps.indexOf(step); + if (step.type === "reasoning") { + return [ + renderDebugSummary(step.messageId, stepIndex), + + } + >, + ]; + } + + return [ + renderDebugSummary(step.messageId, stepIndex), + renderToolCall(step), + ]; + })} + {renderDebugSummary( + lastToolCallStep.messageId, + steps.indexOf(lastToolCallStep), + )} {lastToolCallStep && ( - + {renderToolCall(lastToolCallStep, { isLast: true })} )} )} {lastReasoningStep && ( <> + {renderDebugSummary( + lastReasoningStep.messageId, + steps.indexOf(lastReasoningStep), + )} - - -
-          {t.tokenUsage.title}
+
+
+
+
+          {t.tokenUsage.title}
+
           {usage ? (
-
+
               {t.tokenUsage.input}
@@ -75,14 +96,53 @@
-
+
           ) : (
-
+
               {t.tokenUsage.unavailable}
           )}
-
-
+
+          {t.tokenUsage.view}
+
+            onPreferencesChange(
+              tokenUsagePreferencesFromPreset(value as TokenUsageViewPreset),
+            )
+          }
+        >
+          {(
+            ["off", "summary", "per_turn", "debug"] as TokenUsageViewPreset[]
+          ).map((value) => {
+            const translationKey = presetKeyToTranslationKey(value);
+            return (
+
+
+                {t.tokenUsage.presets[translationKey]}
+
+                {t.tokenUsage.presetDescriptions[translationKey]}
+
+
+            );
+          })}
+
+
+
+        {t.tokenUsage.note}
+
+ ); } + +function presetKeyToTranslationKey(preset: TokenUsageViewPreset) { + switch (preset) { + case "per_turn": + return "perTurn" as const; + default: + return preset; + } +} diff --git a/frontend/src/core/i18n/locales/en-US.ts b/frontend/src/core/i18n/locales/en-US.ts index 615ce808e..fbe54cc31 100644 --- a/frontend/src/core/i18n/locales/en-US.ts +++ b/frontend/src/core/i18n/locales/en-US.ts @@ -306,9 +306,32 @@ export const enUS: Translations = { input: "Input", output: "Output", total: "Total", + view: "Display", unavailable: "No token usage yet. Usage appears only after a successful model response when the provider returns usage_metadata.", unavailableShort: "No usage returned", + note: "Shown from provider-returned usage_metadata. Totals are best-effort conversation totals and may differ from provider billing pages.", + presets: { + off: "Off", + summary: "Summary", + perTurn: "Per turn", + debug: "Debug", + }, + presetDescriptions: { + off: "Hide token usage in the header and conversation.", + summary: "Show only the current conversation total in the header.", + perTurn: + "Show the header total and one token summary per assistant turn.", + debug: "Show the header total and step-level token debugging details.", + }, + finalAnswer: "Final answer", + stepTotal: "Step total", + sharedAttribution: "Shared across multiple actions in this step", + subagent: (description: string) => `Subagent: ${description}`, + startTodo: (content: string) => `Start To-do: ${content}`, + completeTodo: (content: string) => `Complete To-do: ${content}`, + updateTodo: (content: string) => `Update To-do: ${content}`, + removeTodo: (content: string) => `Remove To-do: ${content}`, }, // Shortcuts diff --git a/frontend/src/core/i18n/locales/types.ts b/frontend/src/core/i18n/locales/types.ts index b57c4be46..2eb170c0b 100644 --- a/frontend/src/core/i18n/locales/types.ts +++ b/frontend/src/core/i18n/locales/types.ts @@ -236,8 +236,30 @@ export interface Translations { input: string; output: string; total: string; + view: string; unavailable: string; unavailableShort: string; + note: string; + presets: { + off: string; + summary: string; + perTurn: string; + debug: string; + }; + presetDescriptions: { + off: string; + summary: string; + perTurn: string; + debug: string; + }; + finalAnswer: string; + stepTotal: string; + sharedAttribution: string; + subagent: (description: string) => string; + startTodo: (content: string) => string; + completeTodo: (content: string) => string; + updateTodo: (content: string) => string; + removeTodo: (content: string) => string; }; // Shortcuts diff --git a/frontend/src/core/i18n/locales/zh-CN.ts b/frontend/src/core/i18n/locales/zh-CN.ts index 544f72505..39f800c65 100644 --- a/frontend/src/core/i18n/locales/zh-CN.ts +++ b/frontend/src/core/i18n/locales/zh-CN.ts @@ -292,9 +292,31 @@ export const zhCN: Translations = { input: "输入", output: "输出", total: "总计", + view: "显示方式", unavailable: "暂无 Token 用量。只有模型成功返回且供应商提供 usage_metadata 时才会显示。", unavailableShort: "未返回用量", + note: "基于供应商返回的 usage_metadata 展示。当前总量是 best-effort 的会话参考值,可能与平台账单页不完全一致。", + presets: { + off: "关闭", + summary: "总览", + perTurn: "每轮", + debug: "调试", + }, + presetDescriptions: { + off: "隐藏顶部和会话内的 token 展示。", + summary: "只在顶部显示当前对话累计 token。", + perTurn: "显示顶部累计,并为每轮 assistant 回复显示一条汇总 token。", + debug: "显示顶部累计,并展示按步骤归类的 token 调试信息。", + }, + finalAnswer: "最终回复", + stepTotal: "步骤总计", + sharedAttribution: "该 token 由此步骤中的多个动作共同消耗", + subagent: (description: string) => `子任务:${description}`, + startTodo: (content: string) => 
`开始 To-do:${content}`, + completeTodo: (content: string) => `完成 To-do:${content}`, + updateTodo: (content: string) => `更新 To-do:${content}`, + removeTodo: (content: string) => `移除 To-do:${content}`, }, // Shortcuts diff --git a/frontend/src/core/messages/usage-model.ts b/frontend/src/core/messages/usage-model.ts new file mode 100644 index 000000000..229071602 --- /dev/null +++ b/frontend/src/core/messages/usage-model.ts @@ -0,0 +1,440 @@ +import type { Message } from "@langchain/langgraph-sdk"; + +import type { Translations } from "@/core/i18n/locales/types"; + +import { getUsageMetadata, type TokenUsage } from "./usage"; +import { hasContent } from "./utils"; + +export type TokenUsageInlineMode = "off" | "per_turn" | "step_debug"; + +export interface TokenUsagePreferences { + headerTotal: boolean; + inlineMode: TokenUsageInlineMode; +} + +export type TokenUsageViewPreset = "off" | "summary" | "per_turn" | "debug"; + +export interface TokenDebugStep { + id: string; + messageId: string; + label: string; + secondaryLabels: string[]; + usage: TokenUsage | null; + sharedAttribution: boolean; +} + +type TokenUsageAttributionAction = + | { + kind: "todo_start" | "todo_complete" | "todo_update" | "todo_remove"; + content?: string; + tool_call_id?: string; + } + | { + kind: "subagent"; + description?: string | null; + subagent_type?: string | null; + tool_call_id?: string; + } + | { + kind: "search"; + query?: string | null; + tool_name?: string | null; + tool_call_id?: string; + } + | { + kind: "present_files" | "clarification"; + tool_call_id?: string; + } + | { + kind: "tool"; + tool_name?: string | null; + description?: string | null; + tool_call_id?: string; + }; + +interface TokenUsageAttribution { + version?: number; + kind?: + | "thinking" + | "final_answer" + | "tool_batch" + | "todo_update" + | "subagent_dispatch"; + shared_attribution?: boolean; + tool_call_ids?: string[]; + actions?: TokenUsageAttributionAction[]; +} + +// Precise write_todos labels come from the backend attribution payload. +// The frontend fallback intentionally stays generic so we do not duplicate +// backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py +//::_build_todo_actions and risk the two diffing algorithms drifting apart. 
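+//
+// Example attribution payload this module consumes, mirroring what the
+// backend middleware emits (values are illustrative):
+//   {
+//     version: 1,
+//     kind: "todo_update",
+//     shared_attribution: false,
+//     tool_call_ids: ["write_todos:1"],
+//     actions: [
+//       { kind: "todo_start", content: "Draft the plan", tool_call_id: "write_todos:1" },
+//     ],
+//   }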
+ +export function getTokenUsageViewPreset( + preferences: TokenUsagePreferences, +): TokenUsageViewPreset { + if (!preferences.headerTotal && preferences.inlineMode === "off") { + return "off"; + } + if (preferences.headerTotal && preferences.inlineMode === "off") { + return "summary"; + } + if (preferences.inlineMode === "step_debug") { + return "debug"; + } + return "per_turn"; +} + +export function tokenUsagePreferencesFromPreset( + preset: TokenUsageViewPreset, +): TokenUsagePreferences { + switch (preset) { + case "off": + return { headerTotal: false, inlineMode: "off" }; + case "summary": + return { headerTotal: true, inlineMode: "off" }; + case "debug": + return { headerTotal: true, inlineMode: "step_debug" }; + case "per_turn": + default: + return { headerTotal: true, inlineMode: "per_turn" }; + } +} + +export function buildTokenDebugSteps( + messages: Message[], + t: Translations, +): TokenDebugStep[] { + const steps: TokenDebugStep[] = []; + + for (const [index, message] of messages.entries()) { + if (message.type !== "ai") { + continue; + } + + const usage = getUsageMetadata(message); + const attribution = getTokenUsageAttribution(message); + const actionLabels: string[] = []; + + if (attribution) { + actionLabels.push(...buildActionLabelsFromAttribution(attribution, t)); + + if (actionLabels.length === 0) { + if (attribution.kind === "final_answer") { + actionLabels.push(t.tokenUsage.finalAnswer); + } else if (attribution.kind === "thinking") { + actionLabels.push(t.common.thinking); + } + } + + if (actionLabels.length > 0) { + const sharedAttribution = + attribution.shared_attribution ?? actionLabels.length > 1; + steps.push({ + id: message.id ?? `token-step-${index}`, + messageId: message.id ?? `token-step-${index}`, + label: + sharedAttribution && actionLabels.length > 1 + ? t.tokenUsage.stepTotal + : actionLabels[0]!, + secondaryLabels: + sharedAttribution && actionLabels.length > 1 ? actionLabels : [], + usage, + sharedAttribution, + }); + continue; + } + } + + for (const toolCall of message.tool_calls ?? []) { + const toolArgs = (toolCall.args ?? {}) as Record; + + if (toolCall.name === "write_todos") { + actionLabels.push(t.toolCalls.writeTodos); + continue; + } + + actionLabels.push( + describeToolCall( + { + name: toolCall.name, + args: toolArgs, + }, + t, + ), + ); + } + + if (actionLabels.length === 0) { + if (hasContent(message)) { + actionLabels.push(t.tokenUsage.finalAnswer); + } else { + actionLabels.push(t.common.thinking); + } + } + + steps.push({ + id: message.id ?? `token-step-${index}`, + messageId: message.id ?? `token-step-${index}`, + label: + actionLabels.length === 1 ? actionLabels[0]! : t.tokenUsage.stepTotal, + secondaryLabels: actionLabels.length > 1 ? actionLabels : [], + usage, + sharedAttribution: actionLabels.length > 1, + }); + } + + return steps; +} + +function getTokenUsageAttribution( + message: Message, +): TokenUsageAttribution | null { + if (message.type !== "ai") { + return null; + } + + const additionalKwargs = message.additional_kwargs; + if (!additionalKwargs || typeof additionalKwargs !== "object") { + return null; + } + + const attribution = (additionalKwargs as Record) + .token_usage_attribution; + const normalized = normalizeTokenUsageAttribution(attribution); + if (!normalized) { + return null; + } + + return normalized; +} + +function buildActionLabelsFromAttribution( + attribution: TokenUsageAttribution, + t: Translations, +): string[] { + return (attribution.actions ?? 
[]) + .map((action) => describeAttributionAction(action, t)) + .filter((label): label is string => !!label); +} + +function describeAttributionAction( + action: TokenUsageAttributionAction, + t: Translations, +): string | null { + switch (action.kind) { + case "todo_start": + return action.content + ? t.tokenUsage.startTodo(action.content) + : t.toolCalls.writeTodos; + case "todo_complete": + return action.content + ? t.tokenUsage.completeTodo(action.content) + : t.toolCalls.writeTodos; + case "todo_update": + return action.content + ? t.tokenUsage.updateTodo(action.content) + : t.toolCalls.writeTodos; + case "todo_remove": + return action.content + ? t.tokenUsage.removeTodo(action.content) + : t.toolCalls.writeTodos; + case "subagent": + return t.tokenUsage.subagent(action.description ?? t.subtasks.subtask); + case "search": + if (action.query) { + return t.toolCalls.searchFor(action.query); + } + return t.toolCalls.useTool(action.tool_name ?? "search"); + case "present_files": + return t.toolCalls.presentFiles; + case "clarification": + return t.toolCalls.needYourHelp; + case "tool": + return describeToolCall( + { + name: action.tool_name ?? "tool", + args: action.description ? { description: action.description } : {}, + }, + t, + ); + default: + return null; + } +} + +function describeToolCall( + toolCall: { + name: string; + args: Record; + }, + t: Translations, +): string { + if (toolCall.name === "task") { + const description = + typeof toolCall.args.description === "string" + ? toolCall.args.description + : t.subtasks.subtask; + return t.tokenUsage.subagent(description); + } + + if ( + (toolCall.name === "web_search" || toolCall.name === "image_search") && + typeof toolCall.args.query === "string" + ) { + return t.toolCalls.searchFor(toolCall.args.query); + } + + if (toolCall.name === "web_fetch") { + return t.toolCalls.viewWebPage; + } + + if (toolCall.name === "present_files") { + return t.toolCalls.presentFiles; + } + + if (toolCall.name === "ask_clarification") { + return t.toolCalls.needYourHelp; + } + + if (typeof toolCall.args.description === "string") { + return toolCall.args.description; + } + + return t.toolCalls.useTool(toolCall.name); +} + +function normalizeTokenUsageAttribution( + value: unknown, +): TokenUsageAttribution | null { + const record = asRecord(value); + if (!record) { + return null; + } + + const rawActions = record.actions; + if (rawActions !== undefined && !Array.isArray(rawActions)) { + return null; + } + + return { + // Versioning is additive for now: the frontend should ignore unknown + // fields and fall back when required fields become incompatible. + version: typeof record.version === "number" ? record.version : undefined, + kind: isTokenUsageAttributionKind(record.kind) ? record.kind : undefined, + shared_attribution: + typeof record.shared_attribution === "boolean" + ? record.shared_attribution + : undefined, + tool_call_ids: Array.isArray(record.tool_call_ids) + ? record.tool_call_ids.filter( + (toolCallId): toolCallId is string => + typeof toolCallId === "string" && toolCallId.trim().length > 0, + ) + : undefined, + actions: Array.isArray(rawActions) + ? 
rawActions + .map((action) => normalizeTokenUsageAttributionAction(action)) + .filter( + (action): action is TokenUsageAttributionAction => action !== null, + ) + : undefined, + }; +} + +function normalizeTokenUsageAttributionAction( + value: unknown, +): TokenUsageAttributionAction | null { + const record = asRecord(value); + if (!record) { + return null; + } + + const kind = record.kind; + if ( + kind !== "todo_start" && + kind !== "todo_complete" && + kind !== "todo_update" && + kind !== "todo_remove" && + kind !== "subagent" && + kind !== "search" && + kind !== "present_files" && + kind !== "clarification" && + kind !== "tool" + ) { + return null; + } + + const content = readString(record.content); + const toolCallId = readString(record.tool_call_id); + + switch (kind) { + case "todo_start": + case "todo_complete": + case "todo_update": + case "todo_remove": + return { + kind, + content, + tool_call_id: toolCallId, + }; + case "subagent": + return { + kind, + description: readString(record.description), + subagent_type: readString(record.subagent_type), + tool_call_id: toolCallId, + }; + case "search": + return { + kind, + query: readString(record.query), + tool_name: readString(record.tool_name), + tool_call_id: toolCallId, + }; + case "present_files": + case "clarification": + return { + kind, + tool_call_id: toolCallId, + }; + case "tool": + return { + kind, + tool_name: readString(record.tool_name), + description: readString(record.description), + tool_call_id: toolCallId, + }; + default: + return null; + } +} + +function asRecord(value: unknown): Record | null { + if (!value || typeof value !== "object" || Array.isArray(value)) { + return null; + } + + return value as Record; +} + +function readString(value: unknown): string | undefined { + if (typeof value !== "string") { + return undefined; + } + + const normalized = value.trim(); + return normalized.length > 0 ? 
normalized : undefined; +} + +function isTokenUsageAttributionKind( + value: unknown, +): value is NonNullable { + return ( + value === "thinking" || + value === "final_answer" || + value === "tool_batch" || + value === "todo_update" || + value === "subagent_dispatch" + ); +} diff --git a/frontend/src/core/messages/utils.ts b/frontend/src/core/messages/utils.ts index 3c7d4afdc..e20daa1b6 100644 --- a/frontend/src/core/messages/utils.ts +++ b/frontend/src/core/messages/utils.ts @@ -18,7 +18,7 @@ interface AssistantClarificationGroup extends GenericMessageGroup<"assistant:cla interface AssistantSubagentGroup extends GenericMessageGroup<"assistant:subagent"> {} -type MessageGroup = +export type MessageGroup = | HumanMessageGroup | AssistantProcessingGroup | AssistantMessageGroup @@ -26,10 +26,7 @@ type MessageGroup = | AssistantClarificationGroup | AssistantSubagentGroup; -export function groupMessages( - messages: Message[], - mapper: (group: MessageGroup) => T, -): T[] { +export function getMessageGroups(messages: Message[]): MessageGroup[] { if (messages.length === 0) { return []; } @@ -124,11 +121,52 @@ export function groupMessages( } } - return groups + return groups; +} + +export function groupMessages( + messages: Message[], + mapper: (group: MessageGroup) => T, +): T[] { + return getMessageGroups(messages) .map(mapper) .filter((result) => result !== undefined && result !== null) as T[]; } +export function getAssistantTurnUsageMessages(groups: MessageGroup[]) { + const usageMessagesByGroupIndex: Array = Array.from( + { length: groups.length }, + () => null, + ); + + let turnStartIndex: number | null = null; + + for (const [index, group] of groups.entries()) { + if (group.type === "human") { + turnStartIndex = null; + continue; + } + + turnStartIndex ??= index; + + const nextGroup = groups[index + 1]; + const isTurnEnd = !nextGroup || nextGroup.type === "human"; + + if (!isTurnEnd) { + continue; + } + + usageMessagesByGroupIndex[index] = groups + .slice(turnStartIndex, index + 1) + .flatMap((currentGroup) => currentGroup.messages) + .filter((message) => message.type === "ai"); + + turnStartIndex = null; + } + + return usageMessagesByGroupIndex; +} + export function extractTextFromMessage(message: Message) { if (typeof message.content === "string") { return ( diff --git a/frontend/src/core/settings/local.ts b/frontend/src/core/settings/local.ts index bc76e8fa8..aa370c053 100644 --- a/frontend/src/core/settings/local.ts +++ b/frontend/src/core/settings/local.ts @@ -1,9 +1,14 @@ +import type { TokenUsageInlineMode } from "../messages/usage-model"; import type { AgentThreadContext } from "../threads"; export const DEFAULT_LOCAL_SETTINGS: LocalSettings = { notification: { enabled: true, }, + tokenUsage: { + headerTotal: true, + inlineMode: "per_turn", + }, context: { model_name: undefined, mode: undefined, @@ -22,6 +27,10 @@ export interface LocalSettings { notification: { enabled: boolean; }; + tokenUsage: { + headerTotal: boolean; + inlineMode: TokenUsageInlineMode; + }; context: Omit< AgentThreadContext, | "thread_id" @@ -44,6 +53,10 @@ function mergeLocalSettings(settings?: Partial): LocalSettings { ...DEFAULT_LOCAL_SETTINGS.context, ...settings?.context, }, + tokenUsage: { + ...DEFAULT_LOCAL_SETTINGS.tokenUsage, + ...settings?.tokenUsage, + }, notification: { ...DEFAULT_LOCAL_SETTINGS.notification, ...settings?.notification, diff --git a/frontend/tests/unit/core/messages/usage-model.test.ts b/frontend/tests/unit/core/messages/usage-model.test.ts new file mode 100644 index 
000000000..ce6eda279 --- /dev/null +++ b/frontend/tests/unit/core/messages/usage-model.test.ts @@ -0,0 +1,396 @@ +import type { Message } from "@langchain/langgraph-sdk"; +import { expect, test } from "vitest"; + +import { enUS } from "@/core/i18n"; +import { + buildTokenDebugSteps, + getTokenUsageViewPreset, + tokenUsagePreferencesFromPreset, +} from "@/core/messages/usage-model"; + +test("maps token usage presets to persisted preferences", () => { + expect(tokenUsagePreferencesFromPreset("off")).toEqual({ + headerTotal: false, + inlineMode: "off", + }); + expect(tokenUsagePreferencesFromPreset("summary")).toEqual({ + headerTotal: true, + inlineMode: "off", + }); + expect(tokenUsagePreferencesFromPreset("per_turn")).toEqual({ + headerTotal: true, + inlineMode: "per_turn", + }); + expect(tokenUsagePreferencesFromPreset("debug")).toEqual({ + headerTotal: true, + inlineMode: "step_debug", + }); +}); + +test("derives the active preset from persisted preferences", () => { + expect( + getTokenUsageViewPreset({ + headerTotal: false, + inlineMode: "off", + }), + ).toBe("off"); + + expect( + getTokenUsageViewPreset({ + headerTotal: true, + inlineMode: "off", + }), + ).toBe("summary"); + + expect( + getTokenUsageViewPreset({ + headerTotal: true, + inlineMode: "per_turn", + }), + ).toBe("per_turn"); + + expect( + getTokenUsageViewPreset({ + headerTotal: true, + inlineMode: "step_debug", + }), + ).toBe("debug"); +}); + +test("uses generic todo labels when backend attribution is absent", () => { + const messages = [ + { + id: "ai-1", + type: "ai", + content: "", + tool_calls: [ + { + id: "write_todos:1", + name: "write_todos", + args: { + todos: [{ content: "Draft the plan", status: "in_progress" }], + }, + }, + ], + usage_metadata: { + input_tokens: 100, + output_tokens: 20, + total_tokens: 120, + }, + }, + { + id: "tool-1", + type: "tool", + name: "write_todos", + tool_call_id: "write_todos:1", + content: "ok", + }, + { + id: "ai-2", + type: "ai", + content: "", + tool_calls: [ + { + id: "write_todos:2", + name: "write_todos", + args: { + todos: [{ content: "Draft the plan", status: "completed" }], + }, + }, + ], + usage_metadata: { input_tokens: 50, output_tokens: 10, total_tokens: 60 }, + }, + { + id: "ai-3", + type: "ai", + content: "Here is the result", + usage_metadata: { input_tokens: 40, output_tokens: 15, total_tokens: 55 }, + }, + ] as Message[]; + + expect(buildTokenDebugSteps(messages, enUS)).toEqual([ + expect.objectContaining({ + messageId: "ai-1", + label: "Update to-do list", + sharedAttribution: false, + }), + expect.objectContaining({ + messageId: "ai-2", + label: "Update to-do list", + sharedAttribution: false, + }), + expect.objectContaining({ + messageId: "ai-3", + label: "Final answer", + sharedAttribution: false, + }), + ]); +}); + +test("marks multi-action AI steps as shared attribution", () => { + const messages = [ + { + id: "ai-1", + type: "ai", + content: "", + tool_calls: [ + { + id: "web_search:1", + name: "web_search", + args: { query: "LangGraph stream mode" }, + }, + { + id: "write_todos:1", + name: "write_todos", + args: { + todos: [ + { + content: "Inspect stream mode handling", + status: "in_progress", + }, + ], + }, + }, + ], + usage_metadata: { + input_tokens: 120, + output_tokens: 30, + total_tokens: 150, + }, + }, + ] as Message[]; + + expect(buildTokenDebugSteps(messages, enUS)).toEqual([ + expect.objectContaining({ + messageId: "ai-1", + label: "Step total", + sharedAttribution: true, + secondaryLabels: [ + 'Search for "LangGraph stream mode"', + "Update 
to-do list", + ], + }), + ]); +}); + +test("prefers backend attribution metadata when available", () => { + const messages = [ + { + id: "ai-1", + type: "ai", + content: "", + tool_calls: [ + { + id: "write_todos:1", + name: "write_todos", + args: { + todos: [ + { + content: "Fallback label should not win", + status: "in_progress", + }, + ], + }, + }, + ], + additional_kwargs: { + token_usage_attribution: { + version: 1, + kind: "todo_update", + shared_attribution: false, + actions: [{ kind: "todo_start", content: "Use backend attribution" }], + }, + }, + usage_metadata: { input_tokens: 25, output_tokens: 5, total_tokens: 30 }, + }, + ] as Message[]; + + expect(buildTokenDebugSteps(messages, enUS)).toEqual([ + expect.objectContaining({ + messageId: "ai-1", + label: "Start To-do: Use backend attribution", + sharedAttribution: false, + }), + ]); +}); + +test("falls back safely when attribution payload is malformed", () => { + const messages = [ + { + id: "ai-1", + type: "ai", + content: "", + tool_calls: [ + { + id: "web_search:1", + name: "web_search", + args: { query: "LangGraph stream mode" }, + }, + ], + additional_kwargs: { + token_usage_attribution: { + version: 1, + kind: "tool_batch", + actions: { broken: true }, + }, + }, + usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 }, + }, + ] as Message[]; + + expect(buildTokenDebugSteps(messages, enUS)).toEqual([ + expect.objectContaining({ + messageId: "ai-1", + label: 'Search for "LangGraph stream mode"', + sharedAttribution: false, + }), + ]); +}); + +test("ignores attribution actions that are not objects", () => { + const messages = [ + { + id: "ai-1", + type: "ai", + content: "", + tool_calls: [], + additional_kwargs: { + token_usage_attribution: { + version: 1, + kind: "tool_batch", + shared_attribution: true, + actions: [ + null, + "bad-action", + { kind: "search", query: "valid search", ignored: "extra-field" }, + ], + }, + }, + usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 }, + }, + ] as Message[]; + + expect(buildTokenDebugSteps(messages, enUS)).toEqual([ + expect.objectContaining({ + messageId: "ai-1", + label: 'Search for "valid search"', + }), + ]); +}); + +test("ignores malformed attribution fields and falls back to message content", () => { + const messages = [ + { + id: "ai-1", + type: "ai", + content: "Real final answer", + tool_calls: [], + additional_kwargs: { + token_usage_attribution: { + version: 1, + kind: null, + shared_attribution: null, + tool_call_ids: [null, "tool-1", 123], + actions: [{ query: "missing kind" }], + }, + }, + usage_metadata: { input_tokens: 9, output_tokens: 3, total_tokens: 12 }, + }, + ] as Message[]; + + expect(buildTokenDebugSteps(messages, enUS)).toEqual([ + expect.objectContaining({ + messageId: "ai-1", + label: "Final answer", + sharedAttribution: false, + }), + ]); +}); + +test("ignores unknown top-level attribution fields", () => { + const messages = [ + { + id: "ai-1", + type: "ai", + content: "", + tool_calls: [], + additional_kwargs: { + token_usage_attribution: { + version: 1, + kind: "tool_batch", + shared_attribution: false, + unknown_field: "ignored", + actions: [{ kind: "subagent", description: "Inspect the fix" }], + }, + }, + usage_metadata: { input_tokens: 12, output_tokens: 4, total_tokens: 16 }, + }, + ] as Message[]; + + expect(buildTokenDebugSteps(messages, enUS)).toEqual([ + expect.objectContaining({ + messageId: "ai-1", + label: "Subagent: Inspect the fix", + sharedAttribution: false, + }), + ]); +}); + +test("falls back to 
generic todo labels when backend attribution has no actions", () => { + const messages = [ + { + id: "ai-1", + type: "ai", + content: "", + tool_calls: [ + { + id: "write_todos:1", + name: "write_todos", + args: { + todos: [{ content: "Clean up stale tasks", status: "in_progress" }], + }, + }, + ], + usage_metadata: { + input_tokens: 100, + output_tokens: 20, + total_tokens: 120, + }, + }, + { + id: "ai-2", + type: "ai", + content: "", + tool_calls: [ + { + id: "write_todos:2", + name: "write_todos", + args: { + todos: [], + }, + }, + ], + additional_kwargs: { + token_usage_attribution: { + version: 1, + kind: "todo_update", + shared_attribution: false, + actions: [], + }, + }, + usage_metadata: { input_tokens: 30, output_tokens: 8, total_tokens: 38 }, + }, + ] as Message[]; + + expect(buildTokenDebugSteps(messages, enUS)).toEqual([ + expect.objectContaining({ + messageId: "ai-1", + label: "Update to-do list", + }), + expect.objectContaining({ + messageId: "ai-2", + label: "Update to-do list", + sharedAttribution: false, + }), + ]); +}); diff --git a/frontend/tests/unit/core/messages/utils.test.ts b/frontend/tests/unit/core/messages/utils.test.ts new file mode 100644 index 000000000..24d014c7e --- /dev/null +++ b/frontend/tests/unit/core/messages/utils.test.ts @@ -0,0 +1,65 @@ +import type { Message } from "@langchain/langgraph-sdk"; +import { expect, test } from "vitest"; + +import { + getAssistantTurnUsageMessages, + getMessageGroups, +} from "@/core/messages/utils"; + +test("aggregates token usage messages once per assistant turn", () => { + const messages = [ + { + id: "human-1", + type: "human", + content: "Plan a trip", + }, + { + id: "ai-1", + type: "ai", + content: "", + tool_calls: [{ id: "tool-1", name: "web_search", args: {} }], + usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 }, + }, + { + id: "tool-1-result", + type: "tool", + name: "web_search", + tool_call_id: "tool-1", + content: "[]", + }, + { + id: "ai-2", + type: "ai", + content: "Here is the itinerary", + usage_metadata: { input_tokens: 2, output_tokens: 8, total_tokens: 10 }, + }, + { + id: "human-2", + type: "human", + content: "Make it shorter", + }, + { + id: "ai-3", + type: "ai", + content: "Short version", + usage_metadata: { input_tokens: 1, output_tokens: 1, total_tokens: 2 }, + }, + ] as Message[]; + + const groups = getMessageGroups(messages); + const usageMessagesByGroupIndex = getAssistantTurnUsageMessages(groups); + + expect(groups.map((group) => group.type)).toEqual([ + "human", + "assistant:processing", + "assistant", + "human", + "assistant", + ]); + + expect( + usageMessagesByGroupIndex.map( + (groupMessages) => groupMessages?.map((message) => message.id) ?? null, + ), + ).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]); +});
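// Sketch of how these helpers are expected to compose in a message list
// (assumed wiring; the actual integration lives in message-group.tsx and the
// chat pages above):
//   const groups = getMessageGroups(messages);
//   const usageByGroup = getAssistantTurnUsageMessages(groups);
//   const debugSteps = buildTokenDebugSteps(messages, t);
//   // per_turn: render one usage summary per assistant turn from the ai
//   // messages in usageByGroup; step_debug: pass debugSteps plus
//   // showTokenDebugSummaries into MessageGroup for step-level labels.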