From 9892a7d46823888351bef32beddae247ffb1d294 Mon Sep 17 00:00:00 2001 From: YuJitang Date: Sun, 10 May 2026 22:47:30 +0800 Subject: [PATCH] fix: bucket subagent token usage into parent run totals (#2838) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: bucket subagent token usage into RunRow.subagent_tokens Add caller-bucketed token tracking to RunJournal so subagent and middleware LLM calls are written to the correct RunRow columns instead of all falling into lead_agent_tokens (default 0). - RunJournal: accumulate _lead_agent_tokens / _subagent_tokens / _middleware_tokens in on_llm_end, deduped by langchain run_id. Add record_external_llm_usage_records() for external sources (respects track_token_usage flag). Return caller buckets from get_completion_data(). - SubagentTokenCollector: new lightweight callback handler that collects LLM usage within subagent execution. - SubagentExecutor: wire collector into subagent run_config and sync records to SubagentResult on every chunk (timeout/cancel safe). - SubagentResult: add token_usage_records and usage_reported fields. - task_tool: report subagent usage to parent RunJournal on every terminal status (COMPLETED/FAILED/CANCELLED/TIMED_OUT), including the CancelledError path, guarded against double-reporting. No DB migration needed — RunRow columns already exist. 
* Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix: address token usage review feedback * Address review follow-ups --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../harness/deerflow/runtime/journal.py | 68 ++++- .../harness/deerflow/subagents/executor.py | 16 ++ .../deerflow/subagents/token_collector.py | 63 +++++ .../deerflow/tools/builtins/task_tool.py | 139 +++++++--- backend/tests/test_run_journal.py | 238 ++++++++++++++++++ .../tests/test_subagent_token_collector.py | 161 ++++++++++++ backend/tests/test_task_tool_core_logic.py | 229 +++++++++++++---- frontend/tests/unit/core/threads/api.test.ts | 6 +- 8 files changed, 843 insertions(+), 77 deletions(-) create mode 100644 backend/packages/harness/deerflow/subagents/token_collector.py create mode 100644 backend/tests/test_subagent_token_collector.py diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index a0c2d029b..41e48efed 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -63,6 +63,15 @@ class RunJournal(BaseCallbackHandler): self._total_tokens = 0 self._llm_call_count = 0 + # Caller-bucketed token accumulators + self._lead_agent_tokens = 0 + self._subagent_tokens = 0 + self._middleware_tokens = 0 + + # Dedup: LangChain may fire on_llm_end multiple times for the same run_id + self._counted_llm_run_ids: set[str] = set() + self._counted_external_source_ids: set[str] = set() + # Convenience fields self._last_ai_msg: str | None = None self._first_human_msg: str | None = None @@ -214,19 +223,28 @@ class RunJournal(BaseCallbackHandler): }, ) - # Token accumulation + # Token accumulation (dedup by langchain run_id to avoid double-counting + # when the callback fires more than once for the same response) if self._track_tokens: 
input_tk = usage_dict.get("input_tokens", 0) or 0 output_tk = usage_dict.get("output_tokens", 0) or 0 total_tk = usage_dict.get("total_tokens", 0) or 0 if total_tk == 0: total_tk = input_tk + output_tk - if total_tk > 0: + if total_tk > 0 and rid not in self._counted_llm_run_ids: + self._counted_llm_run_ids.add(rid) self._total_input_tokens += input_tk self._total_output_tokens += output_tk self._total_tokens += total_tk self._llm_call_count += 1 + if caller.startswith("subagent:"): + self._subagent_tokens += total_tk + elif caller.startswith("middleware:"): + self._middleware_tokens += total_tk + else: + self._lead_agent_tokens += total_tk + def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: self._llm_start_times.pop(str(run_id), None) self._put(event_type="llm.error", category="trace", content=str(error)) @@ -330,6 +348,49 @@ class RunJournal(BaseCallbackHandler): # -- Public methods (called by worker) -- + def record_external_llm_usage_records( + self, + records: list[dict[str, int | str]], + ) -> None: + """Record token usage from external sources (e.g., subagents). + + Each record should contain: + source_run_id: Unique identifier to prevent double-counting + caller: Caller tag (e.g. 
"subagent:general-purpose") + input_tokens: Input token count + output_tokens: Output token count + total_tokens: Total token count (computed from input+output if 0/missing) + """ + if not self._track_tokens: + return + for record in records: + source_id = str(record.get("source_run_id", "")) + if not source_id: + continue + if source_id in self._counted_external_source_ids: + continue + + total_tk = record.get("total_tokens", 0) or 0 + if total_tk <= 0: + input_tk = record.get("input_tokens", 0) or 0 + output_tk = record.get("output_tokens", 0) or 0 + total_tk = input_tk + output_tk + if total_tk <= 0: + continue + + self._counted_external_source_ids.add(source_id) + self._total_input_tokens += record.get("input_tokens", 0) or 0 + self._total_output_tokens += record.get("output_tokens", 0) or 0 + self._total_tokens += total_tk + + caller = str(record.get("caller", "")) + if caller.startswith("subagent:"): + self._subagent_tokens += total_tk + elif caller.startswith("middleware:"): + self._middleware_tokens += total_tk + else: + self._lead_agent_tokens += total_tk + def set_first_human_message(self, content: str) -> None: """Record the first human message for convenience fields.""" self._first_human_msg = content[:2000] if content else None @@ -376,6 +437,9 @@ class RunJournal(BaseCallbackHandler): "total_output_tokens": self._total_output_tokens, "total_tokens": self._total_tokens, "llm_call_count": self._llm_call_count, + "lead_agent_tokens": self._lead_agent_tokens, + "subagent_tokens": self._subagent_tokens, + "middleware_tokens": self._middleware_tokens, "message_count": self._msg_count, "last_ai_message": self._last_ai_msg, "first_human_message": self._first_human_msg, diff --git a/backend/packages/harness/deerflow/subagents/executor.py b/backend/packages/harness/deerflow/subagents/executor.py index 64ba4c2c5..a2fec6432 100644 --- a/backend/packages/harness/deerflow/subagents/executor.py +++ b/backend/packages/harness/deerflow/subagents/executor.py @@ -26,6 
+26,7 @@ from deerflow.models import create_chat_model from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools from deerflow.skills.types import Skill from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name +from deerflow.subagents.token_collector import SubagentTokenCollector logger = logging.getLogger(__name__) @@ -70,6 +71,8 @@ class SubagentResult: started_at: datetime | None = None completed_at: datetime | None = None ai_messages: list[dict[str, Any]] | None = None + token_usage_records: list[dict[str, int | str]] = field(default_factory=list) + usage_reported: bool = False cancel_event: threading.Event = field(default_factory=threading.Event, repr=False) def __post_init__(self): @@ -412,13 +415,20 @@ class SubagentExecutor: ai_messages = [] result.ai_messages = ai_messages + collector: SubagentTokenCollector | None = None try: state, filtered_tools = await self._build_initial_state(task) agent = self._create_agent(filtered_tools) + # Token collector for subagent LLM calls + collector_caller = f"subagent:{self.config.name}" + collector = SubagentTokenCollector(caller=collector_caller) + # Build config with thread_id for sandbox access and recursion limit run_config: RunnableConfig = { "recursion_limit": self.config.max_turns, + "callbacks": [collector], + "tags": [collector_caller], } context: dict[str, Any] = {} if self.thread_id: @@ -441,6 +451,8 @@ class SubagentExecutor: result.status = SubagentStatus.CANCELLED result.error = "Cancelled by user" result.completed_at = datetime.now() + if collector is not None: + result.token_usage_records = collector.snapshot_records() return result async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type] @@ -455,6 +467,7 @@ class SubagentExecutor: result.status = SubagentStatus.CANCELLED result.error = "Cancelled by user" result.completed_at = datetime.now() + result.token_usage_records = 
collector.snapshot_records() return result final_state = chunk @@ -481,6 +494,7 @@ class SubagentExecutor: logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}") logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution") + result.token_usage_records = collector.snapshot_records() if final_state is None: logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state") @@ -560,6 +574,8 @@ class SubagentExecutor: result.status = SubagentStatus.FAILED result.error = str(e) result.completed_at = datetime.now() + if collector is not None: + result.token_usage_records = collector.snapshot_records() return result diff --git a/backend/packages/harness/deerflow/subagents/token_collector.py b/backend/packages/harness/deerflow/subagents/token_collector.py new file mode 100644 index 000000000..56b419f01 --- /dev/null +++ b/backend/packages/harness/deerflow/subagents/token_collector.py @@ -0,0 +1,63 @@ +"""Callback handler that collects LLM token usage within a subagent. + +Each subagent execution creates its own collector. After the subagent +finishes, the collected records are transferred to the parent RunJournal +via :meth:`RunJournal.record_external_llm_usage_records`. 
+""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.callbacks import BaseCallbackHandler + + +class SubagentTokenCollector(BaseCallbackHandler): + """Lightweight callback handler that collects LLM token usage within a subagent.""" + + def __init__(self, caller: str): + super().__init__() + self.caller = caller + self._records: list[dict[str, int | str]] = [] + self._counted_run_ids: set[str] = set() + + def on_llm_end( + self, + response: Any, + *, + run_id: Any, + tags: list[str] | None = None, + **kwargs: Any, + ) -> None: + rid = str(run_id) + if rid in self._counted_run_ids: + return + + for generation in response.generations: + for gen in generation: + if not hasattr(gen, "message"): + continue + usage = getattr(gen.message, "usage_metadata", None) + usage_dict = dict(usage) if usage else {} + input_tk = usage_dict.get("input_tokens", 0) or 0 + output_tk = usage_dict.get("output_tokens", 0) or 0 + total_tk = usage_dict.get("total_tokens", 0) or 0 + if total_tk <= 0: + total_tk = input_tk + output_tk + if total_tk <= 0: + continue + self._counted_run_ids.add(rid) + self._records.append( + { + "source_run_id": rid, + "caller": self.caller, + "input_tokens": input_tk, + "output_tokens": output_tk, + "total_tokens": total_tk, + } + ) + return + + def snapshot_records(self) -> list[dict[str, int | str]]: + """Return a copy of the accumulated usage records.""" + return list(self._records) diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index a124e00ba..861c45b45 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -27,6 +27,92 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _is_subagent_terminal(result: Any) -> bool: + """Return whether a background subagent result is safe to clean up.""" + return result.status in 
{SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None + + +async def _await_subagent_terminal(task_id: str, max_polls: int) -> Any | None: + """Poll until the background subagent reaches a terminal status or we run out of polls.""" + for _ in range(max_polls): + result = get_background_task_result(task_id) + if result is None: + return None + if _is_subagent_terminal(result): + return result + await asyncio.sleep(5) + return None + + +async def _deferred_cleanup_subagent_task(task_id: str, trace_id: str, max_polls: int) -> None: + """Keep polling a cancelled subagent until it can be safely removed.""" + cleanup_poll_count = 0 + while True: + result = get_background_task_result(task_id) + if result is None: + return + if _is_subagent_terminal(result): + cleanup_background_task(task_id) + return + if cleanup_poll_count >= max_polls: + logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls") + return + await asyncio.sleep(5) + cleanup_poll_count += 1 + + +def _log_cleanup_failure(cleanup_task: asyncio.Task[None], *, trace_id: str, task_id: str) -> None: + if cleanup_task.cancelled(): + return + + exc = cleanup_task.exception() + if exc is not None: + logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}") + + +def _schedule_deferred_subagent_cleanup(task_id: str, trace_id: str, max_polls: int) -> None: + logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}") + cleanup_task = asyncio.create_task(_deferred_cleanup_subagent_task(task_id, trace_id, max_polls)) + cleanup_task.add_done_callback(lambda task: _log_cleanup_failure(task, trace_id=trace_id, task_id=task_id)) + + +def _find_usage_recorder(runtime: Any) -> Any | None: + """Find a callback handler with ``record_external_llm_usage_records`` in the runtime config.""" + if runtime is None: + 
return None + config = getattr(runtime, "config", None) + if not isinstance(config, dict): + return None + callbacks = config.get("callbacks", []) + if not callbacks: + return None + for cb in callbacks: + if hasattr(cb, "record_external_llm_usage_records"): + return cb + return None + + +def _report_subagent_usage(runtime: Any, result: Any) -> None: + """Report subagent token usage to the parent RunJournal, if available. + + Each subagent task must be reported only once (guarded by usage_reported). + """ + if getattr(result, "usage_reported", True): + return + records = getattr(result, "token_usage_records", None) or [] + if not records: + return + journal = _find_usage_recorder(runtime) + if journal is None: + logger.debug("No usage recorder found in runtime callbacks — subagent token usage not recorded") + return + try: + journal.record_external_llm_usage_records(records) + result.usage_reported = True + except Exception: + logger.warning("Failed to report subagent token usage", exc_info=True) + + def _get_runtime_app_config(runtime: Any) -> "AppConfig | None": context = getattr(runtime, "context", None) if isinstance(context, dict): @@ -227,21 +313,25 @@ async def task_tool( # Check if task completed, failed, or timed out if result.status == SubagentStatus.COMPLETED: + _report_subagent_usage(runtime, result) writer({"type": "task_completed", "task_id": task_id, "result": result.result}) logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls") cleanup_background_task(task_id) return f"Task Succeeded. Result: {result.result}" elif result.status == SubagentStatus.FAILED: + _report_subagent_usage(runtime, result) writer({"type": "task_failed", "task_id": task_id, "error": result.error}) logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}") cleanup_background_task(task_id) return f"Task failed. 
Error: {result.error}" elif result.status == SubagentStatus.CANCELLED: + _report_subagent_usage(runtime, result) writer({"type": "task_cancelled", "task_id": task_id, "error": result.error}) logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}") cleanup_background_task(task_id) return "Task cancelled by user." elif result.status == SubagentStatus.TIMED_OUT: + _report_subagent_usage(runtime, result) writer({"type": "task_timed_out", "task_id": task_id, "error": result.error}) logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}") cleanup_background_task(task_id) @@ -260,43 +350,28 @@ async def task_tool( if poll_count > max_poll_count: timeout_minutes = config.timeout_seconds // 60 logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)") + _report_subagent_usage(runtime, result) writer({"type": "task_timed_out", "task_id": task_id}) return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}" except asyncio.CancelledError: # Signal the background subagent thread to stop cooperatively. - # Without this, the thread (running in ThreadPoolExecutor with its - # own event loop via asyncio.run) would continue executing even - # after the parent task is cancelled. request_cancel_background_task(task_id) - async def cleanup_when_done() -> None: - max_cleanup_polls = max_poll_count - cleanup_poll_count = 0 + # Wait (shielded) for the subagent to reach a terminal state so the + # final token usage snapshot is reported to the parent RunJournal + # before the parent worker persists get_completion_data(). 
+ terminal_result = None + try: + terminal_result = await asyncio.shield(_await_subagent_terminal(task_id, max_poll_count)) + except asyncio.CancelledError: + pass - while True: - result = get_background_task_result(task_id) - if result is None: - return - - if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None: - cleanup_background_task(task_id) - return - - if cleanup_poll_count > max_cleanup_polls: - logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls") - return - - await asyncio.sleep(5) - cleanup_poll_count += 1 - - def log_cleanup_failure(cleanup_task: asyncio.Task[None]) -> None: - if cleanup_task.cancelled(): - return - - exc = cleanup_task.exception() - if exc is not None: - logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}") - - logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}") - asyncio.create_task(cleanup_when_done()).add_done_callback(log_cleanup_failure) + # Report whatever the subagent collected (even if we timed out). 
+ final_result = terminal_result or get_background_task_result(task_id) + if final_result is not None: + _report_subagent_usage(runtime, final_result) + if final_result is not None and _is_subagent_terminal(final_result): + cleanup_background_task(task_id) + else: + _schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count) raise diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py index 2188eeef0..27c05619c 100644 --- a/backend/tests/test_run_journal.py +++ b/backend/tests/test_run_journal.py @@ -383,6 +383,244 @@ class TestMiddlewareEvents: assert "middleware:guardrail" in event_types +class TestCallerBucketing: + """Tests for caller-bucketed token accumulation (lead_agent / subagent / middleware).""" + + def test_lead_agent_bucketing(self, journal_setup): + j, _ = journal_setup + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + assert j._lead_agent_tokens == 15 + assert j._subagent_tokens == 0 + assert j._middleware_tokens == 0 + + def test_subagent_bucketing(self, journal_setup): + j, _ = journal_setup + usage = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30} + j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"]) + assert j._subagent_tokens == 30 + assert j._lead_agent_tokens == 0 + assert j._middleware_tokens == 0 + + def test_middleware_bucketing(self, journal_setup): + j, _ = journal_setup + usage = {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7} + j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:summarize"]) + assert j._middleware_tokens == 7 + assert j._lead_agent_tokens == 0 + assert j._subagent_tokens == 0 + + def test_mixed_callers_sum_independently(self, journal_setup): + j, _ = journal_setup + usage = {"input_tokens": 10, "output_tokens": 
5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:bash"]) + j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:title"]) + assert j._lead_agent_tokens == 15 + assert j._subagent_tokens == 15 + assert j._middleware_tokens == 15 + assert j._total_tokens == 45 + + def test_get_completion_data_includes_buckets(self, journal_setup): + j, _ = journal_setup + j._lead_agent_tokens = 100 + j._subagent_tokens = 200 + j._middleware_tokens = 50 + data = j.get_completion_data() + assert data["lead_agent_tokens"] == 100 + assert data["subagent_tokens"] == 200 + assert data["middleware_tokens"] == 50 + + def test_dedup_same_run_id(self, journal_setup): + """Same langchain run_id in on_llm_end must not double-count.""" + j, _ = journal_setup + run_id = uuid4() + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) + assert j._total_tokens == 15 + assert j._lead_agent_tokens == 15 + assert j._llm_call_count == 1 + + def test_first_no_usage_second_with_usage(self, journal_setup): + """First callback with no usage must not block second callback with usage for same run_id.""" + j, _ = journal_setup + run_id = uuid4() + j.on_llm_end(_make_llm_response("A", usage=None), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) + assert str(run_id) not in j._counted_llm_run_ids + # Second callback for the same run_id with actual usage must still count + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, 
tags=["lead_agent"]) + assert j._total_tokens == 15 + assert j._lead_agent_tokens == 15 + + def test_track_token_usage_false_skips_buckets(self): + """When token tracking is disabled, caller buckets stay at 0.""" + store = MemoryRunEventStore() + j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100) + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("X", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"]) + assert j._subagent_tokens == 0 + assert j._lead_agent_tokens == 0 + + def test_default_no_tags_buckets_as_lead_agent(self, journal_setup): + """LLM calls without explicit tags default to lead_agent bucket.""" + j, _ = journal_setup + usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10} + j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None) + assert j._lead_agent_tokens == 10 + assert j._subagent_tokens == 0 + assert j._middleware_tokens == 0 + + def test_unknown_tag_buckets_as_lead_agent(self, journal_setup): + """Calls with unrecognized tags (not lead_agent/subagent:/middleware:) go to lead_agent.""" + j, _ = journal_setup + usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10} + j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["some_random_tag"]) + assert j._lead_agent_tokens == 10 + + +class TestExternalUsageRecords: + """Tests for record_external_llm_usage_records.""" + + def test_records_added_to_subagent_bucket(self, journal_setup): + j, _ = journal_setup + records = [ + { + "source_run_id": "ext-1", + "caller": "subagent:general-purpose", + "input_tokens": 100, + "output_tokens": 50, + "total_tokens": 150, + } + ] + j.record_external_llm_usage_records(records) + assert j._subagent_tokens == 150 + assert j._total_tokens == 150 + assert j._total_input_tokens == 100 + assert j._total_output_tokens == 50 + + def 
test_records_added_to_middleware_bucket(self, journal_setup): + j, _ = journal_setup + records = [ + { + "source_run_id": "ext-2", + "caller": "middleware:summarize", + "input_tokens": 30, + "output_tokens": 10, + "total_tokens": 40, + } + ] + j.record_external_llm_usage_records(records) + assert j._middleware_tokens == 40 + assert j._lead_agent_tokens == 0 + assert j._subagent_tokens == 0 + + def test_records_added_to_lead_agent_bucket(self, journal_setup): + j, _ = journal_setup + records = [ + { + "source_run_id": "ext-3", + "caller": "lead_agent", + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + } + ] + j.record_external_llm_usage_records(records) + assert j._lead_agent_tokens == 15 + + def test_dedup_same_source_run_id(self, journal_setup): + """Same source_run_id must not be double-counted.""" + j, _ = journal_setup + records = [ + { + "source_run_id": "dup-1", + "caller": "subagent:research", + "input_tokens": 50, + "output_tokens": 25, + "total_tokens": 75, + } + ] + j.record_external_llm_usage_records(records) + j.record_external_llm_usage_records(records) + assert j._subagent_tokens == 75 + assert j._total_tokens == 75 + + def test_total_tokens_missing_computed_from_input_output(self, journal_setup): + j, _ = journal_setup + records = [ + { + "source_run_id": "ext-4", + "caller": "subagent:bash", + "input_tokens": 200, + "output_tokens": 100, + "total_tokens": 0, + } + ] + j.record_external_llm_usage_records(records) + assert j._subagent_tokens == 300 + assert j._total_tokens == 300 + + def test_total_tokens_zero_no_count(self, journal_setup): + """Records with zero total and zero input+output must not be counted.""" + j, _ = journal_setup + records = [ + { + "source_run_id": "ext-5", + "caller": "subagent:research", + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + ] + j.record_external_llm_usage_records(records) + assert j._total_tokens == 0 + assert j._subagent_tokens == 0 + + def 
test_empty_source_run_id_skipped(self, journal_setup): + j, _ = journal_setup + records = [ + { + "source_run_id": "", + "caller": "subagent:research", + "input_tokens": 50, + "output_tokens": 25, + "total_tokens": 75, + } + ] + j.record_external_llm_usage_records(records) + assert j._total_tokens == 0 + + def test_multiple_records_in_single_call(self, journal_setup): + j, _ = journal_setup + records = [ + {"source_run_id": "r1", "caller": "subagent:gp", "input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + {"source_run_id": "r2", "caller": "subagent:bash", "input_tokens": 20, "output_tokens": 10, "total_tokens": 30}, + ] + j.record_external_llm_usage_records(records) + assert j._subagent_tokens == 45 + assert j._total_tokens == 45 + + def test_external_records_coexist_with_inline_callbacks(self, journal_setup): + """External records and inline on_llm_end must not interfere.""" + j, _ = journal_setup + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + j.record_external_llm_usage_records([{"source_run_id": "ext-6", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}]) + assert j._lead_agent_tokens == 15 + assert j._subagent_tokens == 150 + assert j._total_tokens == 165 + + def test_track_token_usage_false_skips_external_records(self): + """When token tracking is disabled, external records must not accumulate.""" + store = MemoryRunEventStore() + j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100) + j.record_external_llm_usage_records([{"source_run_id": "ext-7", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}]) + assert j._total_tokens == 0 + assert j._subagent_tokens == 0 + + class TestChatModelStartHumanMessage: """Tests for on_chat_model_start extracting the first human message.""" diff --git 
a/backend/tests/test_subagent_token_collector.py b/backend/tests/test_subagent_token_collector.py new file mode 100644 index 000000000..76f003760 --- /dev/null +++ b/backend/tests/test_subagent_token_collector.py @@ -0,0 +1,161 @@ +"""Tests for SubagentTokenCollector callback handler.""" + +from unittest.mock import MagicMock +from uuid import uuid4 + +from deerflow.subagents.token_collector import SubagentTokenCollector + + +def _make_llm_response(content="Hello", usage=None): + """Create a mock LLM response with a message.""" + msg = MagicMock() + msg.content = content + msg.usage_metadata = usage + + gen = MagicMock() + gen.message = msg + + response = MagicMock() + response.generations = [[gen]] + return response + + +def _make_llm_response_from_usages(usages): + """Create a mock LLM response with one generation per usage entry.""" + generations = [] + for usage in usages: + msg = MagicMock() + msg.content = "chunk" + msg.usage_metadata = usage + + gen = MagicMock() + gen.message = msg + generations.append([gen]) + + response = MagicMock() + response.generations = generations + return response + + +class TestSubagentTokenCollector: + def test_collects_usage_from_response(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150} + collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 1 + assert records[0]["caller"] == "subagent:test" + assert records[0]["input_tokens"] == 100 + assert records[0]["output_tokens"] == 50 + assert records[0]["total_tokens"] == 150 + assert "source_run_id" in records[0] + + def test_total_tokens_zero_uses_input_plus_output(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 200, "output_tokens": 100, "total_tokens": 0} + collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4()) + records = 
collector.snapshot_records() + assert len(records) == 1 + assert records[0]["total_tokens"] == 300 + + def test_total_tokens_missing_uses_input_plus_output(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 30, "output_tokens": 20} + collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 1 + assert records[0]["total_tokens"] == 50 + + def test_dedup_same_run_id(self): + collector = SubagentTokenCollector(caller="subagent:test") + run_id = uuid4() + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id) + collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id) + records = collector.snapshot_records() + assert len(records) == 1 + + def test_no_usage_no_record(self): + collector = SubagentTokenCollector(caller="subagent:test") + collector.on_llm_end(_make_llm_response("Hi", usage=None), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 0 + + def test_zero_usage_no_record(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 0 + + def test_skips_empty_generation_and_records_later_usage(self): + collector = SubagentTokenCollector(caller="subagent:test") + response = _make_llm_response_from_usages( + [ + None, + {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}, + ] + ) + + collector.on_llm_end(response, run_id=uuid4()) + + records = collector.snapshot_records() + assert len(records) == 1 + assert records[0]["total_tokens"] == 30 + + def test_snapshot_returns_copy(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 10, 
"output_tokens": 5, "total_tokens": 15} + collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4()) + snap1 = collector.snapshot_records() + snap2 = collector.snapshot_records() + assert snap1 == snap2 + assert snap1 is not snap2 + # Mutating snapshot does not affect internal records + snap1.append({"source_run_id": "fake"}) + assert len(collector.snapshot_records()) == 1 + + def test_multiple_calls_accumulate(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4()) + collector.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 2 + + def test_different_run_ids_accumulate_separately(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30} + collector.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4()) + collector.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 2 + assert records[0]["total_tokens"] == 15 + assert records[1]["total_tokens"] == 30 + + def test_message_without_usage_metadata_skipped(self): + """A response where message has no usage_metadata attribute must be skipped.""" + collector = SubagentTokenCollector(caller="subagent:test") + + msg = MagicMock(spec=[]) # object without usage_metadata + gen = MagicMock() + gen.message = msg + response = MagicMock() + response.generations = [[gen]] + + collector.on_llm_end(response, run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 0 + + def test_generation_without_message_skipped(self): + """A generation without a message attribute must be skipped.""" + collector = 
SubagentTokenCollector(caller="subagent:test") + + gen = MagicMock(spec=[]) # object without message + response = MagicMock() + response.generations = [[gen]] + + collector.on_llm_end(response, run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 0 diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/test_task_tool_core_logic.py index 3be1e4b5c..0591c0e8d 100644 --- a/backend/tests/test_task_tool_core_logic.py +++ b/backend/tests/test_task_tool_core_logic.py @@ -777,22 +777,27 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch): def test_cleanup_scheduled_on_cancellation(monkeypatch): - """Verify cancellation schedules deferred cleanup for the background task.""" + """Verify cancellation handler synchronously cleans up after shielded wait.""" config = _make_subagent_config() events = [] cleanup_calls = [] - scheduled_cleanup_coros = [] poll_count = 0 def get_result(_: str): nonlocal poll_count poll_count += 1 - if poll_count == 1: + # Main loop polls RUNNING twice, then shielded wait gets COMPLETED + if poll_count <= 2: return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]) return _make_result(FakeSubagentStatus.COMPLETED, result="done") - async def cancel_on_first_sleep(_: float) -> None: - raise asyncio.CancelledError + sleep_count = 0 + + async def cancel_on_second_sleep(_: float) -> None: + nonlocal sleep_count + sleep_count += 1 + if sleep_count == 2: + raise asyncio.CancelledError monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) monkeypatch.setattr( @@ -804,12 +809,7 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch): monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) - monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep) - monkeypatch.setattr( - task_tool_module.asyncio, - "create_task", - lambda coro: 
scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(), - ) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_second_sleep) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr( task_tool_module, @@ -826,25 +826,48 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch): tool_call_id="tc-cancelled-cleanup", ) - assert cleanup_calls == [] - assert len(scheduled_cleanup_coros) == 1 - - asyncio.run(scheduled_cleanup_coros.pop()) - + # Cleanup happens synchronously within the cancellation handler assert cleanup_calls == ["tc-cancelled-cleanup"] def test_cancelled_cleanup_stops_after_timeout(monkeypatch): - """Verify deferred cleanup gives up after a bounded number of polls.""" + """Verify cancellation handler survives a shielded-wait timeout gracefully. + + When the subagent never reaches a terminal state, the shielded wait times + out (or is interrupted), the handler reports whatever usage it can, schedules + a deferred cleanup instead of cleaning up immediately, and re-raises. 
+ """ config = _make_subagent_config() - config.timeout_seconds = 1 events = [] + report_calls = [] cleanup_calls = [] - scheduled_cleanup_coros = [] + scheduled_cleanups = [] + + # Always return RUNNING — subagent never finishes + monkeypatch.setattr( + task_tool_module, + "get_background_task_result", + lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]), + ) async def cancel_on_first_sleep(_: float) -> None: raise asyncio.CancelledError + def fake_report_subagent_usage(runtime, result): + report_calls.append((runtime, result)) + + class DummyCleanupTask: + def __init__(self, coro): + self.coro = coro + + def add_done_callback(self, callback): + self.callback = callback + + def fake_create_task(coro): + scheduled_cleanups.append(coro) + coro.close() + return DummyCleanupTask(coro) + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) monkeypatch.setattr( task_tool_module, @@ -852,19 +875,10 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch): type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), ) monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) - - monkeypatch.setattr( - task_tool_module, - "get_background_task_result", - lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]), - ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep) - monkeypatch.setattr( - task_tool_module.asyncio, - "create_task", - lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(), - ) + monkeypatch.setattr(task_tool_module.asyncio, "create_task", fake_create_task) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr( task_tool_module, @@ -881,13 +895,73 @@ 
def test_cancelled_cleanup_stops_after_timeout(monkeypatch): tool_call_id="tc-cancelled-timeout", ) - async def bounded_sleep(_seconds: float) -> None: - return None - - monkeypatch.setattr(task_tool_module.asyncio, "sleep", bounded_sleep) - asyncio.run(scheduled_cleanup_coros.pop()) - + # Non-terminal tasks cannot be cleaned immediately; a deferred cleanup + # keeps polling after the parent cancellation path exits. assert cleanup_calls == [] + assert len(scheduled_cleanups) == 1 + # The patched _report_subagent_usage fake is still invoked exactly once + assert len(report_calls) == 1 + + +def test_cancellation_wait_uses_subagent_polling_budget(monkeypatch): + """Cancelled parent waits on the existing subagent polling budget, not a fixed timeout.""" + config = _make_subagent_config() + events = [] + report_calls = [] + cleanup_calls = [] + sleep_count = 0 + result_polls = 0 + terminal_result = _make_result(FakeSubagentStatus.COMPLETED, result="done") + + def get_result(_: str): + nonlocal result_polls + result_polls += 1 + if result_polls < 5: + return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]) + return terminal_result + + async def cancel_then_continue(_: float) -> None: + nonlocal sleep_count + sleep_count += 1 + if sleep_count == 1: + raise asyncio.CancelledError + + def fake_report_subagent_usage(runtime, result): + report_calls.append((runtime, result)) + + async def fail_on_fixed_timeout(awaitable, *, timeout=None): + raise AssertionError(f"cancellation wait should not use fixed timeout={timeout}") + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result) + 
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_then_continue) + monkeypatch.setattr(task_tool_module.asyncio, "wait_for", fail_on_fixed_timeout) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage) + monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) + monkeypatch.setattr( + task_tool_module, + "cleanup_background_task", + lambda task_id: cleanup_calls.append(task_id), + ) + + with pytest.raises(asyncio.CancelledError): + _run_task_tool( + runtime=_make_runtime(), + description="执行任务", + prompt="cancel task", + subagent_type="general-purpose", + tool_call_id="tc-cancel-budget", + ) + + assert report_calls == [(_make_runtime(), terminal_result)] + assert cleanup_calls == ["tc-cancel-budget"] def test_cancellation_calls_request_cancel(monkeypatch): @@ -895,7 +969,6 @@ def test_cancellation_calls_request_cancel(monkeypatch): config = _make_subagent_config() events = [] cancel_requests = [] - scheduled_cleanup_coros = [] async def cancel_on_first_sleep(_: float) -> None: raise asyncio.CancelledError @@ -915,11 +988,6 @@ def test_cancellation_calls_request_cancel(monkeypatch): ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep) - monkeypatch.setattr( - task_tool_module.asyncio, - "create_task", - lambda coro: (coro.close(), scheduled_cleanup_coros.append(None))[-1] or _DummyScheduledTask(), - ) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr( task_tool_module, @@ -987,3 +1055,80 @@ def test_task_tool_returns_cancelled_message(monkeypatch): assert output == "Task cancelled by user." 
assert any(e.get("type") == "task_cancelled" for e in events) assert cleanup_calls == ["tc-poll-cancelled"] + + +def test_cancellation_reports_subagent_usage(monkeypatch): + """Verify cancellation handler waits (shielded) for subagent terminal state, + then reports the final token usage before re-raising CancelledError. + + The report must happen synchronously within the cancellation handler so + the parent worker's finally block sees the updated journal totals. + """ + config = _make_subagent_config() + events = [] + report_calls = [] + cleanup_calls = [] + + # Terminal result with token usage collected after cancellation processing + cancel_result = _make_result(FakeSubagentStatus.CANCELLED, error="Cancelled by user") + cancel_result.token_usage_records = [{"source_run_id": "sub-run-1", "caller": "subagent:gp", "input_tokens": 50, "output_tokens": 25, "total_tokens": 75}] + cancel_result.usage_reported = False + + poll_count = 0 + + def get_result(_: str): + nonlocal poll_count + poll_count += 1 + # Main loop polls 3 times (RUNNING each time to keep looping) + if poll_count <= 3: + running = _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]) + running.token_usage_records = [] + running.usage_reported = False + return running + # Shielded wait poll gets the terminal result + return cancel_result + + sleep_count = 0 + + async def cancel_on_third_sleep(_: float) -> None: + nonlocal sleep_count + sleep_count += 1 + if sleep_count == 3: + raise asyncio.CancelledError + + def fake_report_subagent_usage(runtime, result): + report_calls.append((runtime, result)) + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, 
"get_background_task_result", get_result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_third_sleep) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage) + monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) + monkeypatch.setattr(task_tool_module, "request_cancel_background_task", lambda _: None) + monkeypatch.setattr( + task_tool_module, + "cleanup_background_task", + lambda task_id: cleanup_calls.append(task_id), + ) + + with pytest.raises(asyncio.CancelledError): + _run_task_tool( + runtime=_make_runtime(), + description="执行任务", + prompt="cancel me", + subagent_type="general-purpose", + tool_call_id="tc-cancel-report", + ) + + # _report_subagent_usage is called synchronously within the cancellation + # handler (after the shielded wait), before CancelledError is re-raised. + assert len(report_calls) == 1 + assert report_calls[0][1] is cancel_result + assert cleanup_calls == ["tc-cancel-report"] diff --git a/frontend/tests/unit/core/threads/api.test.ts b/frontend/tests/unit/core/threads/api.test.ts index d91a2bcdf..4d1268694 100644 --- a/frontend/tests/unit/core/threads/api.test.ts +++ b/frontend/tests/unit/core/threads/api.test.ts @@ -20,7 +20,11 @@ test("fetchThreadTokenUsage uses shared auth fetch without JSON GET headers", as total_tokens: 7, total_runs: 1, by_model: { unknown: { tokens: 7, runs: 1 } }, - by_caller: {}, + by_caller: { + lead_agent: 0, + subagent: 0, + middleware: 0, + }, }), });