"""Tests for TokenUsageMiddleware attribution annotations.""" import logging from unittest.mock import MagicMock from langchain_core.messages import AIMessage from deerflow.agents.middlewares.token_usage_middleware import ( TOKEN_USAGE_ATTRIBUTION_KEY, TokenUsageMiddleware, ) def _make_runtime(): runtime = MagicMock() runtime.context = {"thread_id": "test-thread"} return runtime class TestTokenUsageMiddleware: def test_logs_cache_token_details(self, caplog): middleware = TokenUsageMiddleware() message = AIMessage( content="Here is the final answer.", usage_metadata={ "input_tokens": 350, "output_tokens": 240, "total_tokens": 590, "input_token_details": { "audio": 10, "cache_creation": 200, "cache_read": 100, }, "output_token_details": { "audio": 10, "reasoning": 200, }, }, ) with caplog.at_level( logging.INFO, logger="deerflow.agents.middlewares.token_usage_middleware", ): result = middleware.after_model({"messages": [message]}, _make_runtime()) assert result is not None assert "LLM token usage: input=350 output=240 total=590" in caplog.text assert "input_token_details={'audio': 10, 'cache_creation': 200, 'cache_read': 100}" in caplog.text assert "output_token_details={'audio': 10, 'reasoning': 200}" in caplog.text def test_logs_basic_tokens_when_no_detail_fields_in_usage_metadata(self, caplog): """When usage_metadata has only totals (no input_token_details), log just the counts.""" middleware = TokenUsageMiddleware() message = AIMessage( content="Here is the final answer.", usage_metadata={ "input_tokens": 350, "output_tokens": 240, "total_tokens": 590, }, ) with caplog.at_level( logging.INFO, logger="deerflow.agents.middlewares.token_usage_middleware", ): result = middleware.after_model({"messages": [message]}, _make_runtime()) assert result is not None assert "LLM token usage: input=350 output=240 total=590" in caplog.text assert "input_token_details" not in caplog.text def test_no_log_when_usage_metadata_is_missing(self, caplog): """When usage_metadata is absent, no token usage line is logged.""" middleware = TokenUsageMiddleware() message = AIMessage( content="Here is the final answer.", response_metadata={ "usage": { "input_tokens": 350, "output_tokens": 240, "total_tokens": 590, } }, ) with caplog.at_level( logging.INFO, logger="deerflow.agents.middlewares.token_usage_middleware", ): result = middleware.after_model({"messages": [message]}, _make_runtime()) assert result is not None assert "LLM token usage" not in caplog.text def test_annotates_todo_updates_with_structured_actions(self): middleware = TokenUsageMiddleware() message = AIMessage( content="", tool_calls=[ { "id": "write_todos:1", "name": "write_todos", "args": { "todos": [ {"content": "Inspect streaming path", "status": "completed"}, {"content": "Design token attribution schema", "status": "in_progress"}, ] }, } ], usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120}, ) state = { "messages": [message], "todos": [ {"content": "Inspect streaming path", "status": "in_progress"}, {"content": "Design token attribution schema", "status": "pending"}, ], } result = middleware.after_model(state, _make_runtime()) assert result is not None updated_message = result["messages"][0] attribution = updated_message.additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] assert attribution["kind"] == "tool_batch" assert attribution["shared_attribution"] is True assert attribution["tool_call_ids"] == ["write_todos:1"] assert attribution["actions"] == [ { "kind": "todo_complete", "content": "Inspect streaming path", "tool_call_id": "write_todos:1", }, { "kind": "todo_start", "content": "Design token attribution schema", "tool_call_id": "write_todos:1", }, ] def test_annotates_subagent_and_search_steps(self): middleware = TokenUsageMiddleware() message = AIMessage( content="", tool_calls=[ { "id": "task:1", "name": "task", "args": { "description": "spec-coder patch message grouping", "subagent_type": "general-purpose", }, }, { "id": "web_search:1", "name": "web_search", "args": {"query": "LangGraph useStream messages tuple"}, }, ], ) result = middleware.after_model({"messages": [message]}, _make_runtime()) assert result is not None attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] assert attribution["kind"] == "tool_batch" assert attribution["shared_attribution"] is True assert attribution["actions"] == [ { "kind": "subagent", "description": "spec-coder patch message grouping", "subagent_type": "general-purpose", "tool_call_id": "task:1", }, { "kind": "search", "tool_name": "web_search", "query": "LangGraph useStream messages tuple", "tool_call_id": "web_search:1", }, ] def test_marks_final_answer_when_no_tools(self): middleware = TokenUsageMiddleware() message = AIMessage(content="Here is the final answer.") result = middleware.after_model({"messages": [message]}, _make_runtime()) assert result is not None attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] assert attribution["kind"] == "final_answer" assert attribution["shared_attribution"] is False assert attribution["actions"] == [] def test_annotates_removed_todos(self): middleware = TokenUsageMiddleware() message = AIMessage( content="", tool_calls=[ { "id": "write_todos:remove", "name": "write_todos", "args": { "todos": [], }, } ], ) result = middleware.after_model( { "messages": [message], "todos": [ {"content": "Archive obsolete plan", "status": "pending"}, ], }, _make_runtime(), ) assert result is not None attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] assert attribution["kind"] == "todo_update" assert attribution["shared_attribution"] is False assert attribution["actions"] == [ { "kind": "todo_remove", "content": "Archive obsolete plan", "tool_call_id": "write_todos:remove", } ]