From ca487578a40ad348f8b5b1013abe5719cee5325f Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Fri, 29 May 2026 22:59:26 +0800 Subject: [PATCH] feat(agent): add ToolOutputBudgetMiddleware for oversized tool output protection (#3303) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(agent): add ToolOutputBudgetMiddleware for oversized tool output protection Closes #3289. Adds a unified middleware that enforces per-result budgets on ALL tool outputs (MCP, sandbox, community, custom), preventing oversized external tool results from blowing the model context window. Design informed by claude-code (persistToolResult), hermes-agent (tool_result_storage), and pi (OutputAccumulator) — the three most mature implementations in production coding-agent frameworks. Key features: - Disk externalization: oversized outputs written to thread-local .tool-results/ directory, replaced with compact preview + file reference. Model can read full output via read_file with offset/limit. - Fallback truncation: head+tail truncation when disk is unavailable (no thread_data, write failure), ensuring the context is always protected. - read_file exemption: prevents persist-read-persist infinite loops (independently discovered by claude-code, hermes-agent, and pi). - Per-tool threshold overrides via config. - Line-boundary-aware truncation (no partial lines in previews). - Multimodal content passthrough (images/structured blocks skip budget). - Historical ToolMessage patching in wrap_model_call for checkpoint recovery scenarios. Related: #3222 (design RFC), #1844 (comprehensive context management), #3137 (write_file args compaction), #1677 (sandbox tool truncation). * test: add MCP content_and_artifact format coverage Add 5 tests for MCP tool output format (list of content blocks): - text content blocks are extracted and budgeted - multiple text blocks are joined and budgeted - image content blocks are skipped (multimodal passthrough) - mixed text+image blocks are skipped - small text blocks pass through unchanged Total test count: 59 (was 54). * fix(agent): address Codex review findings for ToolOutputBudgetMiddleware Three issues identified by Codex code review, all fixed: 1. `enabled` config field was unused — middleware now checks `config.enabled` and skips all processing when disabled. 2. `_build_fallback` could exceed `fallback_max_chars` — the marker text itself (~139 chars) was not deducted from the budget. Now pre-computes marker overhead and falls back to hard slice when max_chars is smaller than the marker. 3. Sync file I/O in async path — `awrap_tool_call` now delegates `_patch_result` to `asyncio.to_thread` to avoid blocking the event loop during disk writes. Tests updated to use realistic fallback_max_chars values (500+) that can accommodate the marker overhead, plus two new tests: - `test_result_never_exceeds_max_chars` (parametric across sizes) - `test_very_small_max_chars_does_not_crash` * fix(agent): address Copilot review — path traversal, async perf, shared config 1. Path traversal defense: sanitize tool_name via _sanitize_tool_name() (strips separators, .., absolute paths), validate storage_subdir is relative, and verify resolved filepath stays inside storage_dir. 2. Async hot-path optimization: add _needs_budget() cheap check before asyncio.to_thread offload — small outputs (99% of calls) skip the thread overhead entirely. 3. Replace shared module-level _DEFAULT_CONFIG with _default_config() factory to prevent cross-instance mutation of mutable fields. 12 new tests: TestSanitizeToolName (5), TestExternalizePathTraversal (3), TestNeedsBudget (4). * fix(agent): correct preview hint to match read_file actual API read_file uses start_line/end_line (1-indexed line numbers), not offset/limit. The previous wording was copied from hermes-agent which has a different read_file interface. * perf(agent): hoist hot-path imports, add model-call pre-scan (review #3303) Address maintainer review feedback: 1. Hoist inline imports to module level — `import asyncio` (was in awrap_tool_call hot path) and `from dataclasses import replace` (was in _patch_result) now live at module top. 2. Add a cheap pre-scan to _patch_model_messages so the historical message list is not rebuilt on every model call when nothing is oversized (the common case once results are budgeted at tool-call time). Also adds the same _needs_budget gate to the sync wrap_tool_call for symmetry with awrap_tool_call. The pre-scan is refactored into per-tool-aware helpers (_effective_trigger / _tool_message_over_budget) that mirror the exact trigger conditions in _budget_content — including tool_overrides — so the fast-path can never produce a false negative (silently skipping budgeting for a tool with a low per-tool threshold). 7 new regression tests lock the per-tool-override-through-pre-scan path and the model-call early return. --------- Co-authored-by: Willem Jiang --- .../tool_error_handling_middleware.py | 4 +- .../tool_output_budget_middleware.py | 489 ++++++++++ .../harness/deerflow/config/app_config.py | 2 + .../deerflow/config/tool_output_config.py | 62 ++ .../test_tool_error_handling_middleware.py | 10 +- .../test_tool_output_budget_middleware.py | 890 ++++++++++++++++++ config.example.yaml | 30 +- 7 files changed, 1481 insertions(+), 6 deletions(-) create mode 100644 backend/packages/harness/deerflow/agents/middlewares/tool_output_budget_middleware.py create mode 100644 backend/packages/harness/deerflow/config/tool_output_config.py create mode 100644 backend/tests/test_tool_output_budget_middleware.py diff --git a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py index ae3522454..fd590d5e9 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py @@ -77,9 +77,11 @@ def _build_runtime_middlewares( """Build shared base middlewares for agent execution.""" from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware + from deerflow.agents.middlewares.tool_output_budget_middleware import ToolOutputBudgetMiddleware from deerflow.sandbox.middleware import SandboxMiddleware middlewares: list[AgentMiddleware] = [ + ToolOutputBudgetMiddleware.from_app_config(app_config), ThreadDataMiddleware(lazy_init=lazy_init), SandboxMiddleware(lazy_init=lazy_init), ] @@ -87,7 +89,7 @@ def _build_runtime_middlewares( if include_uploads: from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware - middlewares.insert(1, UploadsMiddleware()) + middlewares.insert(2, UploadsMiddleware()) if include_dangling_tool_call_patch: from deerflow.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware diff --git a/backend/packages/harness/deerflow/agents/middlewares/tool_output_budget_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/tool_output_budget_middleware.py new file mode 100644 index 000000000..3e5785e84 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/middlewares/tool_output_budget_middleware.py @@ -0,0 +1,489 @@ +"""Middleware that enforces a per-result budget on tool outputs. + +Oversized tool results are persisted to disk and replaced with a compact +preview containing a file reference. When disk persistence is +unavailable the middleware falls back to head+tail truncation so the +model context is never blown by a single large tool return. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import uuid +from collections.abc import Awaitable, Callable +from dataclasses import replace as dc_replace +from typing import Any, override + +from langchain.agents import AgentState +from langchain.agents.middleware import AgentMiddleware +from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse +from langchain_core.messages import ToolMessage +from langgraph.prebuilt.tool_node import ToolCallRequest +from langgraph.types import Command + +from deerflow.config.tool_output_config import ToolOutputConfig + +logger = logging.getLogger(__name__) + + +def _default_config() -> ToolOutputConfig: + return ToolOutputConfig() + + +# --------------------------------------------------------------------------- +# Text helpers +# --------------------------------------------------------------------------- + + +def _message_text(content: Any) -> str | None: + """Extract a plain-text representation from a ToolMessage content field. + + Returns ``None`` for non-string / multimodal content so the caller + can skip budget enforcement (images, structured blocks, etc.). + """ + if isinstance(content, str): + return content + if content is None: + return None + if isinstance(content, list): + pieces: list[str] = [] + for part in content: + if isinstance(part, str): + pieces.append(part) + elif isinstance(part, dict) and isinstance(part.get("text"), str): + pieces.append(part["text"]) + else: + return None + return "\n".join(pieces) if pieces else None + return None + + +def _snap_to_line_boundary(text: str, pos: int) -> int: + """Return *pos* or the nearest preceding newline+1, whichever is closer. + + Used so that previews and truncations end on a complete line when + possible. If no newline exists in the second half of ``text[:pos]`` + the original *pos* is returned unchanged. + """ + if pos <= 0 or pos >= len(text): + return pos + half = pos // 2 + nl = text.rfind("\n", half, pos) + if nl >= 0: + return nl + 1 + return pos + + +# --------------------------------------------------------------------------- +# Disk persistence +# --------------------------------------------------------------------------- + +_EXT_MAP: dict[str, str] = { + "bash": "log", + "bash_tool": "log", + "web_fetch": "log", +} + + +def _sanitize_tool_name(name: str) -> str: + """Strip path separators and traversal components from a tool name.""" + base = os.path.basename(name) + safe = base.replace("..", "").replace("/", "_").replace("\\", "_") + return safe or "unknown" + + +def _externalize( + content: str, + *, + tool_name: str, + tool_call_id: str, + outputs_path: str, + storage_subdir: str, +) -> str | None: + """Write *content* to disk and return the virtual path, or ``None`` on failure.""" + if os.path.isabs(storage_subdir) or ".." in storage_subdir: + return None + storage_dir = os.path.join(outputs_path, storage_subdir) + try: + os.makedirs(storage_dir, exist_ok=True) + except OSError: + return None + + safe_name = _sanitize_tool_name(tool_name) + ext = _EXT_MAP.get(tool_name, "txt") + short_id = uuid.uuid4().hex[:12] + filename = f"{safe_name}-{short_id}.{ext}" + filepath = os.path.join(storage_dir, filename) + + if not os.path.abspath(filepath).startswith(os.path.abspath(storage_dir)): + return None + + try: + with open(filepath, "w", encoding="utf-8") as f: + f.write(content) + except OSError: + return None + + virtual_base = "/mnt/user-data/outputs" + return f"{virtual_base}/{storage_subdir}/{filename}" + + +# --------------------------------------------------------------------------- +# Preview / fallback builders +# --------------------------------------------------------------------------- + + +def _build_preview( + content: str, + *, + tool_name: str, + virtual_path: str, + head_chars: int, + tail_chars: int, +) -> str: + """Build a preview with a file reference for externalized output.""" + total = len(content) + head_end = _snap_to_line_boundary(content, min(head_chars, total)) + tail_start = max(head_end, total - tail_chars) + tail_start_snapped = _snap_to_line_boundary(content, tail_start) + if tail_start_snapped > head_end: + tail_start = tail_start_snapped + + head = content[:head_end] + tail = content[tail_start:] if tail_start < total else "" + + omitted = total - len(head) - len(tail) + ref = f"\n\n[Full {tool_name} output saved to {virtual_path} ({total} chars, ~{total // 4} tokens). Use read_file with start_line and end_line to access specific sections. {omitted} chars omitted from this preview.]\n\n" + + parts = [head, ref] + if tail: + parts.append(tail) + return "".join(parts) + + +def _build_fallback( + content: str, + *, + tool_name: str, + max_chars: int, + head_chars: int, + tail_chars: int, +) -> str: + """Build a head+tail truncation when disk persistence is unavailable. + + The returned string is guaranteed to be no longer than *max_chars*. + """ + total = len(content) + if max_chars <= 0 or total <= max_chars: + return content + + marker_template = "\n\n[... {n} chars omitted from {tn} output. Persistent storage unavailable. Consider narrowing the query or using more specific parameters.]\n\n" + marker_overhead = len(marker_template.format(n=total, tn=tool_name)) + + if marker_overhead >= max_chars: + return content[:max_chars] + + budget = max_chars - marker_overhead + effective_head = min(head_chars, budget) + effective_tail = min(tail_chars, max(0, budget - effective_head)) + + head_end = _snap_to_line_boundary(content, min(effective_head, total)) + tail_start = max(head_end, total - effective_tail) + tail_start_snapped = _snap_to_line_boundary(content, tail_start) + if tail_start_snapped > head_end: + tail_start = tail_start_snapped + + head = content[:head_end] + tail = content[tail_start:] if tail_start < total else "" + omitted = total - len(head) - len(tail) + + marker = marker_template.format(n=omitted, tn=tool_name) + + parts = [head, marker] + if tail: + parts.append(tail) + return "".join(parts) + + +# --------------------------------------------------------------------------- +# Core budget logic +# --------------------------------------------------------------------------- + + +def _resolve_outputs_path(request: ToolCallRequest) -> str | None: + """Best-effort extraction of the thread outputs path.""" + runtime = getattr(request, "runtime", None) + if runtime is None: + return None + state = getattr(runtime, "state", None) + if state is None: + return None + thread_data = state.get("thread_data") + if not isinstance(thread_data, dict): + return None + outputs_path = thread_data.get("outputs_path") + return outputs_path if isinstance(outputs_path, str) else None + + +def _budget_content( + content: str, + *, + tool_name: str, + tool_call_id: str, + outputs_path: str | None, + config: ToolOutputConfig, +) -> str | None: + """Apply budget to *content*. Returns ``None`` if no change needed.""" + threshold = config.tool_overrides.get(tool_name, config.externalize_min_chars) + if threshold <= 0 and config.fallback_max_chars <= 0: + return None + if len(content) <= threshold and len(content) <= config.fallback_max_chars: + return None + + if threshold > 0 and len(content) > threshold and outputs_path: + virtual_path = _externalize( + content, + tool_name=tool_name, + tool_call_id=tool_call_id, + outputs_path=outputs_path, + storage_subdir=config.storage_subdir, + ) + if virtual_path is not None: + logger.info( + "Externalized %s output (%d chars) to %s", + tool_name, + len(content), + virtual_path, + ) + return _build_preview( + content, + tool_name=tool_name, + virtual_path=virtual_path, + head_chars=config.preview_head_chars, + tail_chars=config.preview_tail_chars, + ) + + if config.fallback_max_chars > 0 and len(content) > config.fallback_max_chars: + logger.warning( + "Fallback-truncating %s output: %d chars → %d max", + tool_name, + len(content), + config.fallback_max_chars, + ) + return _build_fallback( + content, + tool_name=tool_name, + max_chars=config.fallback_max_chars, + head_chars=config.fallback_head_chars, + tail_chars=config.fallback_tail_chars, + ) + + return None + + +# --------------------------------------------------------------------------- +# Result patchers +# --------------------------------------------------------------------------- + + +def _patch_tool_message(msg: ToolMessage, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage: + """Apply budget to a single ToolMessage. Returns the original if unchanged.""" + tool_name = msg.name or "unknown" + if tool_name in config.exempt_tools: + return msg + + text = _message_text(msg.content) + if text is None: + return msg + + replacement = _budget_content( + text, + tool_name=tool_name, + tool_call_id=msg.tool_call_id or "", + outputs_path=outputs_path, + config=config, + ) + if replacement is None: + return msg + + update: dict[str, Any] = {"content": replacement} + if getattr(msg, "response_metadata", None): + update["response_metadata"] = dict(msg.response_metadata) + if getattr(msg, "additional_kwargs", None): + update["additional_kwargs"] = dict(msg.additional_kwargs) + return msg.model_copy(update=update) + + +def _effective_trigger(tool_name: str, config: ToolOutputConfig) -> int: + """Smallest content length that could trigger budgeting for *tool_name*. + + Mirrors the trigger conditions in :func:`_budget_content` (per-tool + externalize threshold OR global fallback), so the pre-scan never produces + a false negative. Returns ``-1`` when nothing could ever trigger. + """ + candidates: list[int] = [] + externalize = config.tool_overrides.get(tool_name, config.externalize_min_chars) + if externalize > 0: + candidates.append(externalize) + if config.fallback_max_chars > 0: + candidates.append(config.fallback_max_chars) + return min(candidates) if candidates else -1 + + +def _tool_message_over_budget(msg: ToolMessage, config: ToolOutputConfig) -> bool: + """Cheap, per-tool-aware check: is this ToolMessage non-exempt and over its trigger?""" + if (msg.name or "") in config.exempt_tools: + return False + trigger = _effective_trigger(msg.name or "", config) + if trigger < 0: + return False + text = _message_text(msg.content) + return text is not None and len(text) > trigger + + +def _needs_budget(result: ToolMessage | Command, config: ToolOutputConfig) -> bool: + """Fast check whether *result* could need budgeting (avoids thread offload for small outputs).""" + if isinstance(result, ToolMessage): + return _tool_message_over_budget(result, config) + update = getattr(result, "update", None) + if isinstance(update, dict): + for msg in update.get("messages", []): + if isinstance(msg, ToolMessage) and _tool_message_over_budget(msg, config): + return True + return False + + +def _patch_result(result: ToolMessage | Command, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage | Command: + """Apply budget to a tool call result (ToolMessage or Command).""" + if isinstance(result, ToolMessage): + return _patch_tool_message(result, config, outputs_path) + + update = getattr(result, "update", None) + if not isinstance(update, dict): + return result + + messages = update.get("messages") + if not isinstance(messages, list): + return result + + new_messages: list[Any] = [] + changed = False + for msg in messages: + if isinstance(msg, ToolMessage): + patched = _patch_tool_message(msg, config, outputs_path) + if patched is not msg: + changed = True + new_messages.append(patched) + else: + new_messages.append(msg) + + if not changed: + return result + + return dc_replace(result, update={**update, "messages": new_messages}) + + +def _patch_model_messages(messages: list[Any], config: ToolOutputConfig) -> list[Any] | None: + """Apply budget to historical ToolMessages in a model request. Returns ``None`` if unchanged. + + A cheap pre-scan bails out before allocating a new list when no historical + ToolMessage exceeds the budget — the common case once every result has + already been budgeted at tool-call time, so a long history is not rebuilt + on every model call. + """ + if not any(isinstance(msg, ToolMessage) and _tool_message_over_budget(msg, config) for msg in messages): + return None + + updated: list[Any] = [] + changed = False + for msg in messages: + if isinstance(msg, ToolMessage): + patched = _patch_tool_message(msg, config, outputs_path=None) + if patched is not msg: + changed = True + updated.append(patched) + else: + updated.append(msg) + return updated if changed else None + + +# --------------------------------------------------------------------------- +# Middleware class +# --------------------------------------------------------------------------- + + +class ToolOutputBudgetMiddleware(AgentMiddleware[AgentState]): + """Enforce per-result budget on tool outputs via externalization or truncation.""" + + def __init__(self, config: ToolOutputConfig | None = None) -> None: + super().__init__() + self._config = config if config is not None else _default_config() + + @classmethod + def from_app_config(cls, app_config: Any) -> ToolOutputBudgetMiddleware: + tool_output = getattr(app_config, "tool_output", None) + if isinstance(tool_output, ToolOutputConfig): + return cls(config=tool_output) + return cls() + + # -- tool call hooks --------------------------------------------------- + + @override + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + result = handler(request) + if not self._config.enabled: + return result + if not _needs_budget(result, self._config): + return result + outputs_path = _resolve_outputs_path(request) + return _patch_result(result, self._config, outputs_path) + + @override + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], + ) -> ToolMessage | Command: + result = await handler(request) + if not self._config.enabled: + return result + if not _needs_budget(result, self._config): + return result + outputs_path = _resolve_outputs_path(request) + return await asyncio.to_thread(_patch_result, result, self._config, outputs_path) + + # -- model call hooks (historical message truncation) ------------------ + + @override + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + if self._config.enabled: + messages = getattr(request, "messages", None) + if isinstance(messages, list): + patched = _patch_model_messages(messages, self._config) + if patched is not None: + request = request.override(messages=patched) + return handler(request) + + @override + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelCallResult: + if self._config.enabled: + messages = getattr(request, "messages", None) + if isinstance(messages, list): + patched = _patch_model_messages(messages, self._config) + if patched is not None: + request = request.override(messages=patched) + return await handler(request) diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index 931c95757..8fcc564a8 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -30,6 +30,7 @@ from deerflow.config.summarization_config import SummarizationConfig, load_summa from deerflow.config.title_config import TitleConfig, load_title_config_from_dict from deerflow.config.token_usage_config import TokenUsageConfig from deerflow.config.tool_config import ToolConfig, ToolGroupConfig +from deerflow.config.tool_output_config import ToolOutputConfig from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict load_dotenv() @@ -93,6 +94,7 @@ class AppConfig(BaseModel): skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration") skill_evolution: SkillEvolutionConfig = Field(default_factory=SkillEvolutionConfig, description="Agent-managed skill evolution configuration") extensions: ExtensionsConfig = Field(default_factory=ExtensionsConfig, description="Extensions configuration (MCP servers and skills state)") + tool_output: ToolOutputConfig = Field(default_factory=ToolOutputConfig, description="Tool output budget protection configuration") tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration") title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration") summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration") diff --git a/backend/packages/harness/deerflow/config/tool_output_config.py b/backend/packages/harness/deerflow/config/tool_output_config.py new file mode 100644 index 000000000..5e8c24342 --- /dev/null +++ b/backend/packages/harness/deerflow/config/tool_output_config.py @@ -0,0 +1,62 @@ +"""Configuration for tool output budget protection.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class ToolOutputConfig(BaseModel): + """Config section for tool-result output budget enforcement. + + When a tool returns more than ``externalize_min_chars`` characters, + the full output is persisted to disk and replaced with a compact + preview + file reference. If disk persistence is unavailable the + output falls back to head+tail truncation. + """ + + enabled: bool = Field( + default=True, + description="Enable the tool output budget middleware.", + ) + externalize_min_chars: int = Field( + default=12_000, + ge=0, + description="Character threshold to trigger disk externalization. Outputs below this pass through unchanged. Set to 0 to disable externalization (fallback truncation still applies when output exceeds fallback_max_chars).", + ) + preview_head_chars: int = Field( + default=2_000, + ge=0, + description="Characters to keep from the head of the output in the preview.", + ) + preview_tail_chars: int = Field( + default=1_000, + ge=0, + description="Characters to keep from the tail of the output in the preview.", + ) + fallback_max_chars: int = Field( + default=30_000, + ge=0, + description="Maximum characters when disk persistence is unavailable. 0 disables fallback truncation.", + ) + fallback_head_chars: int = Field( + default=8_000, + ge=0, + description="Head characters for fallback truncation.", + ) + fallback_tail_chars: int = Field( + default=3_000, + ge=0, + description="Tail characters for fallback truncation.", + ) + storage_subdir: str = Field( + default=".tool-results", + description="Subdirectory under the thread outputs path for persisted tool results.", + ) + exempt_tools: list[str] = Field( + default_factory=lambda: ["read_file", "read_file_tool"], + description="Tool names exempt from budget enforcement (prevents persist→read→persist loops).", + ) + tool_overrides: dict[str, int] = Field( + default_factory=dict, + description="Per-tool externalize_min_chars overrides. Keys are tool names, values are char thresholds. Use 0 to disable externalization for a specific tool.", + ) diff --git a/backend/tests/test_tool_error_handling_middleware.py b/backend/tests/test_tool_error_handling_middleware.py index 28c59a9ad..c9b835527 100644 --- a/backend/tests/test_tool_error_handling_middleware.py +++ b/backend/tests/test_tool_error_handling_middleware.py @@ -134,12 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False) assert captured["app_config"] is app_config - # 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling, - # SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware - # (enabled by default — see SafetyFinishReasonConfig). + # 7 baseline (ToolOutputBudget, ThreadData, Sandbox, DanglingToolCall, + # LLMErrorHandling, SandboxAudit, ToolErrorHandling) + # + 1 SafetyFinishReasonMiddleware (enabled by default). from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware + from deerflow.agents.middlewares.tool_output_budget_middleware import ToolOutputBudgetMiddleware - assert len(middlewares) == 7 + assert len(middlewares) == 8 + assert isinstance(middlewares[0], ToolOutputBudgetMiddleware) assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares) assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware) diff --git a/backend/tests/test_tool_output_budget_middleware.py b/backend/tests/test_tool_output_budget_middleware.py new file mode 100644 index 000000000..d6ec51052 --- /dev/null +++ b/backend/tests/test_tool_output_budget_middleware.py @@ -0,0 +1,890 @@ +"""Comprehensive tests for ToolOutputBudgetMiddleware. + +Covers: pass-through, disk externalization, fallback truncation, UTF-8 +boundaries, Command results, model-request history patching, config +variations, exempt tools, per-tool overrides, edge cases, and both +sync/async code paths. +""" + +from __future__ import annotations + +import os +import tempfile +from types import SimpleNamespace + +import pytest +from langchain.agents.middleware.types import ModelRequest +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langgraph.types import Command + +from deerflow.agents.middlewares.tool_output_budget_middleware import ( + ToolOutputBudgetMiddleware, + _build_fallback, + _build_preview, + _effective_trigger, + _externalize, + _message_text, + _needs_budget, + _patch_model_messages, + _sanitize_tool_name, + _snap_to_line_boundary, + _tool_message_over_budget, +) +from deerflow.config.app_config import AppConfig +from deerflow.config.sandbox_config import SandboxConfig +from deerflow.config.tool_output_config import ToolOutputConfig + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_request(tool_name: str = "remote_executor", tool_call_id: str = "tc-1", outputs_path: str | None = None) -> SimpleNamespace: + thread_data = {"outputs_path": outputs_path} if outputs_path else None + state = {"thread_data": thread_data} if thread_data else {} + runtime = SimpleNamespace(state=state) + return SimpleNamespace( + tool_call={"name": tool_name, "id": tool_call_id}, + runtime=runtime, + ) + + +def _tm(content: str = "ok", name: str = "tool", tool_call_id: str = "tc-1") -> ToolMessage: + return ToolMessage(content=content, name=name, tool_call_id=tool_call_id) + + +# =========================================================================== +# Unit tests for helper functions +# =========================================================================== + + +class TestMessageText: + def test_string_content(self): + assert _message_text("hello") == "hello" + + def test_none_content(self): + assert _message_text(None) is None + + def test_list_of_strings(self): + assert _message_text(["a", "b"]) == "a\nb" + + def test_list_of_text_dicts(self): + assert _message_text([{"text": "x"}, {"text": "y"}]) == "x\ny" + + def test_list_with_image_returns_none(self): + assert _message_text([{"type": "image", "data": "..."}]) is None + + def test_empty_list(self): + assert _message_text([]) is None + + def test_non_string_non_list(self): + assert _message_text(42) is None + + +class TestSnapToLineBoundary: + def test_snaps_to_newline(self): + text = "line1\nline2\nline3" + pos = 14 # inside "line3" + result = _snap_to_line_boundary(text, pos) + assert text[result - 1] == "\n" + + def test_no_snap_when_no_newline_in_range(self): + text = "abcdefghij" + assert _snap_to_line_boundary(text, 8) == 8 + + def test_zero_pos(self): + assert _snap_to_line_boundary("abc", 0) == 0 + + def test_pos_beyond_length(self): + assert _snap_to_line_boundary("abc", 10) == 10 + + +class TestExternalize: + def test_writes_file_and_returns_virtual_path(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = _externalize( + "full content here", + tool_name="bash", + tool_call_id="tc-1", + outputs_path=tmpdir, + storage_subdir=".tool-results", + ) + assert path is not None + assert path.startswith("/mnt/user-data/outputs/.tool-results/bash-") + assert path.endswith(".log") + + # Verify actual file on disk + storage_dir = os.path.join(tmpdir, ".tool-results") + files = os.listdir(storage_dir) + assert len(files) == 1 + with open(os.path.join(storage_dir, files[0]), encoding="utf-8") as f: + assert f.read() == "full content here" + + def test_returns_none_on_invalid_path(self): + path = _externalize( + "data", + tool_name="test", + tool_call_id="tc-1", + outputs_path="/nonexistent/path/that/should/not/exist", + storage_subdir=".tool-results", + ) + assert path is None + + def test_txt_extension_for_unknown_tool(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = _externalize( + "data", + tool_name="unknown_tool", + tool_call_id="tc-1", + outputs_path=tmpdir, + storage_subdir=".tool-results", + ) + assert path is not None + assert path.endswith(".txt") + + +class TestSanitizeToolName: + def test_strips_path_separators(self): + assert _sanitize_tool_name("../../etc/passwd") == "passwd" + + def test_strips_backslashes(self): + result = _sanitize_tool_name("..\\..\\windows\\system32") + assert ".." not in result + assert "/" not in result + + def test_normal_name_unchanged(self): + assert _sanitize_tool_name("bash") == "bash" + + def test_empty_becomes_unknown(self): + assert _sanitize_tool_name("") == "unknown" + + def test_dots_only_becomes_unknown(self): + assert _sanitize_tool_name("..") == "unknown" + + +class TestExternalizePathTraversal: + def test_traversal_tool_name_is_sanitized(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = _externalize( + "data", + tool_name="../../etc/passwd", + tool_call_id="tc-1", + outputs_path=tmpdir, + storage_subdir=".tool-results", + ) + assert path is not None + assert "passwd-" in path + assert "../" not in path + + def test_absolute_storage_subdir_rejected(self): + path = _externalize( + "data", + tool_name="tool", + tool_call_id="tc-1", + outputs_path="/tmp", + storage_subdir="/etc/evil", + ) + assert path is None + + def test_traversal_storage_subdir_rejected(self): + path = _externalize( + "data", + tool_name="tool", + tool_call_id="tc-1", + outputs_path="/tmp", + storage_subdir="../../../etc", + ) + assert path is None + + +class TestNeedsBudget: + def test_small_output_does_not_need_budget(self): + config = ToolOutputConfig(externalize_min_chars=1000) + msg = _tm("small", name="tool") + assert _needs_budget(msg, config) is False + + def test_large_output_needs_budget(self): + config = ToolOutputConfig(externalize_min_chars=50) + msg = _tm("x" * 100, name="tool") + assert _needs_budget(msg, config) is True + + def test_exempt_tool_does_not_need_budget(self): + config = ToolOutputConfig(externalize_min_chars=10) + msg = _tm("x" * 100, name="read_file") + assert _needs_budget(msg, config) is False + + def test_multimodal_does_not_need_budget(self): + config = ToolOutputConfig(externalize_min_chars=10) + msg = ToolMessage(content=[{"type": "image", "data": "x" * 100}], name="tool", tool_call_id="tc-1") + assert _needs_budget(msg, config) is False + + +class TestBuildPreview: + def test_contains_head_and_tail_and_reference(self): + content = "HEAD_" + "x" * 5000 + "_TAIL" + preview = _build_preview( + content, + tool_name="bash", + virtual_path="/mnt/test/bash-abc.log", + head_chars=100, + tail_chars=50, + ) + assert preview.startswith("HEAD_") + assert "_TAIL" in preview + assert "/mnt/test/bash-abc.log" in preview + assert "read_file" in preview + assert "start_line and end_line" in preview + + def test_reports_total_chars(self): + content = "a" * 10000 + preview = _build_preview( + content, + tool_name="web_search", + virtual_path="/mnt/test/file.txt", + head_chars=200, + tail_chars=100, + ) + assert "10000 chars" in preview + + +class TestBuildFallback: + def test_short_content_unchanged(self): + assert _build_fallback("short", tool_name="t", max_chars=100, head_chars=50, tail_chars=50) == "short" + + def test_zero_max_disables(self): + content = "a" * 1000 + assert _build_fallback(content, tool_name="t", max_chars=0, head_chars=50, tail_chars=50) == content + + def test_truncates_long_content(self): + content = "H" * 5000 + "M" * 20000 + "T" * 5000 + result = _build_fallback(content, tool_name="bash", max_chars=12000, head_chars=6000, tail_chars=3000) + assert len(result) < len(content) + assert "omitted from bash output" in result + assert "Persistent storage unavailable" in result + + def test_preserves_head_and_tail(self): + content = "HEADSTART" + "x" * 50000 + "TAILEND" + result = _build_fallback(content, tool_name="t", max_chars=20000, head_chars=10000, tail_chars=5000) + assert result.startswith("HEADSTART") + assert "TAILEND" in result + + def test_result_never_exceeds_max_chars(self): + """The marker itself has non-zero length; total must still respect max_chars.""" + for max_chars in [200, 500, 1000, 5000, 20000]: + content = "x" * 50000 + result = _build_fallback(content, tool_name="long_tool_name", max_chars=max_chars, head_chars=max_chars // 2, tail_chars=max_chars // 4) + assert len(result) <= max_chars, f"max_chars={max_chars}: got {len(result)}" + + def test_very_small_max_chars_does_not_crash(self): + content = "x" * 1000 + result = _build_fallback(content, tool_name="t", max_chars=50, head_chars=20, tail_chars=10) + assert len(result) <= 50 + + +# =========================================================================== +# Middleware integration tests — wrap_tool_call +# =========================================================================== + + +class TestWrapToolCallPassThrough: + def test_small_output_passes_through(self): + mw = ToolOutputBudgetMiddleware(config=ToolOutputConfig(externalize_min_chars=1000)) + msg = _tm("small output", name="bash") + result = mw.wrap_tool_call(_make_request(), lambda _: msg) + assert result is msg + + def test_disabled_middleware_passes_through(self): + mw = ToolOutputBudgetMiddleware(config=ToolOutputConfig(enabled=False, externalize_min_chars=10, fallback_max_chars=20)) + msg = _tm("x" * 50000, name="bash") + result = mw.wrap_tool_call(_make_request(), lambda _: msg) + assert result is msg + + +class TestWrapToolCallExternalize: + def test_oversized_output_externalized_to_disk(self): + with tempfile.TemporaryDirectory() as tmpdir: + config = ToolOutputConfig(externalize_min_chars=100, preview_head_chars=50, preview_tail_chars=30) + mw = ToolOutputBudgetMiddleware(config=config) + content = "x" * 500 + msg = _tm(content, name="remote_executor") + req = _make_request(outputs_path=tmpdir) + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert isinstance(result, ToolMessage) + assert result is not msg + assert "Full remote_executor output saved to" in result.content + assert "read_file" in result.content + assert result.tool_call_id == "tc-1" + + # Verify file was written + storage_dir = os.path.join(tmpdir, ".tool-results") + assert os.path.isdir(storage_dir) + files = os.listdir(storage_dir) + assert len(files) == 1 + with open(os.path.join(storage_dir, files[0]), encoding="utf-8") as f: + assert f.read() == content + + def test_preview_contains_head_and_tail(self): + with tempfile.TemporaryDirectory() as tmpdir: + config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10) + mw = ToolOutputBudgetMiddleware(config=config) + content = "HEADPART_" + "m" * 200 + "_TAILPART" + msg = _tm(content, name="web_search") + req = _make_request(outputs_path=tmpdir) + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert result.content.startswith("HEADPART_") + assert "_TAILPART" in result.content + + +class TestWrapToolCallFallback: + def test_fallback_when_no_outputs_path(self): + config = ToolOutputConfig( + externalize_min_chars=50, + fallback_max_chars=200, + fallback_head_chars=80, + fallback_tail_chars=40, + ) + mw = ToolOutputBudgetMiddleware(config=config) + content = "x" * 500 + msg = _tm(content, name="mcp_tool") + req = _make_request(outputs_path=None) + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert isinstance(result, ToolMessage) + assert result is not msg + assert "omitted from mcp_tool output" in result.content + assert "Persistent storage unavailable" in result.content + assert len(result.content) < len(content) + + def test_fallback_when_disk_write_fails(self): + config = ToolOutputConfig( + externalize_min_chars=50, + fallback_max_chars=200, + fallback_head_chars=80, + fallback_tail_chars=40, + ) + mw = ToolOutputBudgetMiddleware(config=config) + content = "x" * 500 + msg = _tm(content, name="tool") + req = _make_request(outputs_path="/nonexistent/impossible/path") + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert isinstance(result, ToolMessage) + assert "omitted from tool output" in result.content + + +class TestWrapToolCallExemption: + def test_read_file_exempt(self): + config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=50) + mw = ToolOutputBudgetMiddleware(config=config) + content = "x" * 100 + msg = _tm(content, name="read_file") + + result = mw.wrap_tool_call(_make_request(tool_name="read_file"), lambda _: msg) + + assert result is msg + + def test_read_file_tool_exempt(self): + config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=50) + mw = ToolOutputBudgetMiddleware(config=config) + content = "x" * 100 + msg = _tm(content, name="read_file_tool") + + result = mw.wrap_tool_call(_make_request(tool_name="read_file_tool"), lambda _: msg) + + assert result is msg + + def test_custom_exempt_tool(self): + config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=50, exempt_tools=["my_tool"]) + mw = ToolOutputBudgetMiddleware(config=config) + content = "x" * 100 + msg = _tm(content, name="my_tool") + + result = mw.wrap_tool_call(_make_request(tool_name="my_tool"), lambda _: msg) + + assert result is msg + + +class TestWrapToolCallPerToolOverride: + def test_per_tool_threshold(self): + with tempfile.TemporaryDirectory() as tmpdir: + config = ToolOutputConfig( + externalize_min_chars=50000, # global: high + tool_overrides={"sensitive_tool": 100}, # override: low + ) + mw = ToolOutputBudgetMiddleware(config=config) + content = "x" * 500 + msg = _tm(content, name="sensitive_tool") + req = _make_request(tool_name="sensitive_tool", outputs_path=tmpdir) + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert result is not msg + assert "Full sensitive_tool output saved to" in result.content + + def test_per_tool_zero_disables_externalization(self): + config = ToolOutputConfig( + externalize_min_chars=50, + tool_overrides={"bash": 0}, + fallback_max_chars=200, + fallback_head_chars=80, + fallback_tail_chars=40, + ) + mw = ToolOutputBudgetMiddleware(config=config) + content = "x" * 500 + msg = _tm(content, name="bash") + # Even with outputs_path, externalization disabled for bash + req = _make_request(tool_name="bash", outputs_path="/tmp/test") + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert isinstance(result, ToolMessage) + # Should use fallback instead of externalization + assert "Persistent storage unavailable" in result.content or "omitted" in result.content + + +class TestWrapToolCallCommand: + def test_command_messages_are_patched(self): + with tempfile.TemporaryDirectory() as tmpdir: + config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10) + mw = ToolOutputBudgetMiddleware(config=config) + tool_msg = _tm("x" * 200, name="present_files") + command = Command(update={"messages": [tool_msg], "artifacts": ["/mnt/report.html"]}) + req = _make_request(tool_name="present_files", outputs_path=tmpdir) + + result = mw.wrap_tool_call(req, lambda _: command) + + assert isinstance(result, Command) + assert result is not command + assert result.update["artifacts"] == ["/mnt/report.html"] + new_msg = result.update["messages"][0] + assert isinstance(new_msg, ToolMessage) + assert "Full present_files output saved to" in new_msg.content + + def test_command_without_messages_unchanged(self): + config = ToolOutputConfig(externalize_min_chars=10) + mw = ToolOutputBudgetMiddleware(config=config) + command = Command(update={"key": "value"}) + result = mw.wrap_tool_call(_make_request(), lambda _: command) + assert result is command + + +class TestWrapToolCallEdgeCases: + def test_none_content_passes_through(self): + config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20) + mw = ToolOutputBudgetMiddleware(config=config) + msg = ToolMessage(content=None, name="tool", tool_call_id="tc-1") + + result = mw.wrap_tool_call(_make_request(), lambda _: msg) + + assert result is msg + + def test_empty_string_passes_through(self): + config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20) + mw = ToolOutputBudgetMiddleware(config=config) + msg = _tm("", name="tool") + + result = mw.wrap_tool_call(_make_request(), lambda _: msg) + + assert result is msg + + def test_multimodal_content_skipped(self): + config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20) + mw = ToolOutputBudgetMiddleware(config=config) + content = [{"type": "image", "data": "x" * 100}] + msg = ToolMessage(content=content, name="tool", tool_call_id="tc-1") + + result = mw.wrap_tool_call(_make_request(), lambda _: msg) + + assert result is msg + + def test_exactly_at_threshold_passes_through(self): + config = ToolOutputConfig(externalize_min_chars=100, fallback_max_chars=100) + mw = ToolOutputBudgetMiddleware(config=config) + msg = _tm("x" * 100, name="tool") + + result = mw.wrap_tool_call(_make_request(), lambda _: msg) + + assert result is msg + + def test_one_char_over_threshold_triggers(self): + with tempfile.TemporaryDirectory() as tmpdir: + config = ToolOutputConfig(externalize_min_chars=100) + mw = ToolOutputBudgetMiddleware(config=config) + msg = _tm("x" * 101, name="tool") + req = _make_request(outputs_path=tmpdir) + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert result is not msg + + def test_chinese_content_preserved(self): + with tempfile.TemporaryDirectory() as tmpdir: + config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10) + mw = ToolOutputBudgetMiddleware(config=config) + content = "你好世界" * 50 + msg = _tm(content, name="tool") + req = _make_request(outputs_path=tmpdir) + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert isinstance(result, ToolMessage) + # File should contain the full Chinese content + storage_dir = os.path.join(tmpdir, ".tool-results") + files = os.listdir(storage_dir) + with open(os.path.join(storage_dir, files[0]), encoding="utf-8") as f: + assert f.read() == content + + def test_no_runtime_state_uses_fallback(self): + config = ToolOutputConfig( + externalize_min_chars=50, + fallback_max_chars=500, + fallback_head_chars=100, + fallback_tail_chars=50, + ) + mw = ToolOutputBudgetMiddleware(config=config) + content = "x" * 1000 + msg = _tm(content, name="tool") + req = SimpleNamespace( + tool_call={"name": "tool", "id": "tc-1"}, + runtime=None, + ) + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert isinstance(result, ToolMessage) + assert "omitted" in result.content + assert len(result.content) <= 500 + + +# =========================================================================== +# MCP content_and_artifact format tests +# =========================================================================== + + +class TestMCPContentAndArtifact: + """MCP tools return content as list of content blocks, not plain strings.""" + + def test_text_content_blocks_are_budgeted(self): + with tempfile.TemporaryDirectory() as tmpdir: + config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10) + mw = ToolOutputBudgetMiddleware(config=config) + content = [{"type": "text", "text": "x" * 200}] + msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-mcp") + req = _make_request(tool_name="mcp_tool", outputs_path=tmpdir) + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert result is not msg + assert isinstance(result.content, str) + assert "Full mcp_tool output saved to" in result.content + assert result.tool_call_id == "tc-mcp" + + def test_multiple_text_blocks_joined_and_budgeted(self): + config = ToolOutputConfig(externalize_min_chars=50, fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50) + mw = ToolOutputBudgetMiddleware(config=config) + content = [{"type": "text", "text": "a" * 300}, {"type": "text", "text": "b" * 300}] + msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-mcp2") + req = _make_request(tool_name="mcp_tool") + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert result is not msg + assert "omitted" in result.content + + def test_image_content_blocks_are_skipped(self): + config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20) + mw = ToolOutputBudgetMiddleware(config=config) + content = [{"type": "image", "data": "base64data" * 100}] + msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-img") + req = _make_request(tool_name="mcp_tool") + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert result is msg + + def test_mixed_text_and_image_blocks_are_skipped(self): + config = ToolOutputConfig(externalize_min_chars=10) + mw = ToolOutputBudgetMiddleware(config=config) + content = [{"type": "text", "text": "x" * 100}, {"type": "image", "data": "base64"}] + msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-mix") + req = _make_request(tool_name="mcp_tool") + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert result is msg + + def test_small_text_blocks_pass_through(self): + config = ToolOutputConfig(externalize_min_chars=1000) + mw = ToolOutputBudgetMiddleware(config=config) + content = [{"type": "text", "text": "small result"}] + msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-sm") + req = _make_request(tool_name="mcp_tool") + + result = mw.wrap_tool_call(req, lambda _: msg) + + assert result is msg + + +# =========================================================================== +# Async path tests +# =========================================================================== + + +class TestAsyncPaths: + @pytest.mark.anyio + async def test_async_tool_call_externalized(self): + with tempfile.TemporaryDirectory() as tmpdir: + config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10) + mw = ToolOutputBudgetMiddleware(config=config) + content = "x" * 200 + msg = _tm(content, name="async_tool") + req = _make_request(tool_name="async_tool", outputs_path=tmpdir) + + async def handler(_): + return msg + + result = await mw.awrap_tool_call(req, handler) + + assert isinstance(result, ToolMessage) + assert result is not msg + assert "Full async_tool output saved to" in result.content + + @pytest.mark.anyio + async def test_async_model_call_patches_history(self): + config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50) + mw = ToolOutputBudgetMiddleware(config=config) + oversized = _tm("h" * 1000, name="tool", tool_call_id="tc-h") + request = ModelRequest(model=None, messages=[oversized], tools=[], state={}) + captured: dict[str, ModelRequest] = {} + + async def handler(req): + captured["request"] = req + return [] + + await mw.awrap_model_call(request, handler) + + forwarded = captured["request"] + assert forwarded is not request + msg = forwarded.messages[0] + assert isinstance(msg, ToolMessage) + assert "omitted" in msg.content + + +# =========================================================================== +# wrap_model_call — historical message patching +# =========================================================================== + + +class TestWrapModelCall: + def test_oversized_historical_messages_truncated(self): + config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50) + mw = ToolOutputBudgetMiddleware(config=config) + oversized = _tm("q" * 1000, name="tool", tool_call_id="tc-q") + request = ModelRequest(model=None, messages=[oversized], tools=[], state={}) + captured: dict[str, ModelRequest] = {} + + def handler(req): + captured["request"] = req + return [] + + mw.wrap_model_call(request, handler) + + forwarded = captured["request"] + assert forwarded is not request + msg = forwarded.messages[0] + assert isinstance(msg, ToolMessage) + assert "omitted" in msg.content + assert len(msg.content) < len(oversized.content) + 150 + + def test_small_historical_messages_unchanged(self): + config = ToolOutputConfig(fallback_max_chars=1000) + mw = ToolOutputBudgetMiddleware(config=config) + small = _tm("small", name="tool") + request = ModelRequest(model=None, messages=[small], tools=[], state={}) + captured: dict[str, ModelRequest] = {} + + def handler(req): + captured["request"] = req + return [] + + mw.wrap_model_call(request, handler) + + assert captured["request"] is request + + def test_exempt_tools_in_history_unchanged(self): + config = ToolOutputConfig(fallback_max_chars=50) + mw = ToolOutputBudgetMiddleware(config=config) + read_msg = _tm("x" * 200, name="read_file", tool_call_id="tc-r") + request = ModelRequest(model=None, messages=[read_msg], tools=[], state={}) + captured: dict[str, ModelRequest] = {} + + def handler(req): + captured["request"] = req + return [] + + mw.wrap_model_call(request, handler) + + assert captured["request"] is request + + def test_non_tool_messages_preserved(self): + config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50) + mw = ToolOutputBudgetMiddleware(config=config) + human = HumanMessage(content="x" * 200) + ai = AIMessage(content="y" * 200) + oversized_tool = _tm("z" * 1000, name="tool") + request = ModelRequest(model=None, messages=[human, ai, oversized_tool], tools=[], state={}) + captured: dict[str, ModelRequest] = {} + + def handler(req): + captured["request"] = req + return [] + + mw.wrap_model_call(request, handler) + + msgs = captured["request"].messages + assert msgs[0] is human + assert msgs[1] is ai + assert isinstance(msgs[2], ToolMessage) + assert "omitted" in msgs[2].content + + +# =========================================================================== +# Config integration +# =========================================================================== + + +class TestFromAppConfig: + def test_from_app_config_with_tool_output(self): + config = AppConfig( + sandbox=SandboxConfig(use="test"), + tool_output={"externalize_min_chars": 5000, "preview_head_chars": 500}, + ) + mw = ToolOutputBudgetMiddleware.from_app_config(config) + assert mw._config.externalize_min_chars == 5000 + assert mw._config.preview_head_chars == 500 + + def test_from_app_config_defaults(self): + config = AppConfig(sandbox=SandboxConfig(use="test")) + mw = ToolOutputBudgetMiddleware.from_app_config(config) + assert mw._config.externalize_min_chars == 12000 + + +class TestPatchModelMessages: + def test_returns_none_when_no_changes(self): + config = ToolOutputConfig(fallback_max_chars=1000) + messages = [_tm("short", name="tool")] + assert _patch_model_messages(messages, config) is None + + def test_patches_oversized_messages(self): + config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50) + messages = [_tm("x" * 1000, name="tool")] + result = _patch_model_messages(messages, config) + assert result is not None + assert len(result) == 1 + assert "omitted" in result[0].content + + +# =========================================================================== +# Pre-scan helpers (_effective_trigger / _tool_message_over_budget / _needs_budget) +# These guard the fast-path optimization — a false negative here is a real bug +# (budgeting silently skipped), so per-tool overrides must be honored. +# =========================================================================== + + +class TestPreScanHelpers: + def test_effective_trigger_uses_global_externalize(self): + config = ToolOutputConfig(externalize_min_chars=12000, fallback_max_chars=30000) + # smallest of the two thresholds wins + assert _effective_trigger("any_tool", config) == 12000 + + def test_effective_trigger_respects_per_tool_override(self): + config = ToolOutputConfig(externalize_min_chars=50000, fallback_max_chars=0, tool_overrides={"sensitive": 100}) + assert _effective_trigger("sensitive", config) == 100 + # other tools fall back to the (high) global + assert _effective_trigger("other", config) == 50000 + + def test_effective_trigger_per_tool_zero_falls_to_fallback(self): + config = ToolOutputConfig(externalize_min_chars=50, tool_overrides={"bash": 0}, fallback_max_chars=200) + # externalize disabled for bash → only fallback can trigger + assert _effective_trigger("bash", config) == 200 + + def test_effective_trigger_returns_negative_when_fully_disabled(self): + config = ToolOutputConfig(externalize_min_chars=0, fallback_max_chars=0) + assert _effective_trigger("any", config) == -1 + + def test_pre_scan_does_not_short_circuit_per_tool_override(self): + """Regression: pre-scan must honor per-tool overrides, not just global threshold.""" + config = ToolOutputConfig(externalize_min_chars=50000, fallback_max_chars=0, tool_overrides={"sensitive": 100}) + msg = _tm("x" * 500, name="sensitive") + # 500 < global 50000 but > per-tool 100 → must still be flagged + assert _tool_message_over_budget(msg, config) is True + assert _needs_budget(msg, config) is True + + def test_exempt_tool_never_over_budget(self): + config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20, exempt_tools=["read_file"]) + msg = _tm("x" * 1000, name="read_file") + assert _tool_message_over_budget(msg, config) is False + + def test_model_call_pre_scan_skips_when_nothing_oversized(self): + """_patch_model_messages returns None (no list rebuild) when all messages are small.""" + config = ToolOutputConfig(externalize_min_chars=12000, fallback_max_chars=30000) + messages = [_tm("small", name="tool"), HumanMessage(content="hi"), _tm("also small", name="bash")] + assert _patch_model_messages(messages, config) is None + + +# =========================================================================== +# Middleware ordering in the chain +# =========================================================================== + + +class TestMiddlewareChainIntegration: + def test_budget_middleware_is_first_in_chain(self): + from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares + + app_config = AppConfig(sandbox=SandboxConfig(use="test")) + middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False) + + assert isinstance(middlewares[0], ToolOutputBudgetMiddleware) + + def test_budget_middleware_in_lead_chain(self): + from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares + + app_config = AppConfig(sandbox=SandboxConfig(use="test")) + middlewares = build_lead_runtime_middlewares(app_config=app_config, lazy_init=False) + + assert isinstance(middlewares[0], ToolOutputBudgetMiddleware) + + +# =========================================================================== +# Config version bump +# =========================================================================== + + +class TestConfigVersion: + def test_config_version_bumped(self): + import yaml + + example_path = os.path.join(os.path.dirname(__file__), "..", "..", "config.example.yaml") + if os.path.exists(example_path): + with open(example_path, encoding="utf-8") as f: + data = yaml.safe_load(f) + assert data.get("config_version", 0) >= 11 + + def test_config_example_has_tool_output_section(self): + import yaml + + example_path = os.path.join(os.path.dirname(__file__), "..", "..", "config.example.yaml") + if os.path.exists(example_path): + with open(example_path, encoding="utf-8") as f: + data = yaml.safe_load(f) + assert "tool_output" in data + tool_output = data["tool_output"] + assert tool_output["enabled"] is True + assert tool_output["externalize_min_chars"] == 12000 + assert "read_file" in tool_output["exempt_tools"] diff --git a/config.example.yaml b/config.example.yaml index 118b1be4d..e2d8dfb80 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -15,7 +15,7 @@ # ============================================================================ # Bump this number when the config schema changes. # Run `make config-upgrade` to merge new fields into your local config.yaml. -config_version: 10 +config_version: 11 # ============================================================================ # Logging @@ -544,6 +544,34 @@ tools: tool_search: enabled: false +# ============================================================================ +# Tool Output Budget Protection +# ============================================================================ +# Prevents oversized tool results from blowing the model context window. +# Outputs exceeding `externalize_min_chars` are persisted to disk and replaced +# with a compact preview + file reference. The model can read the full output +# via read_file. When disk persistence is unavailable, outputs exceeding +# `fallback_max_chars` are head+tail truncated instead. +# +# `exempt_tools` prevents persist→read→persist infinite loops for read tools. +# `tool_overrides` allows per-tool threshold customization. + +tool_output: + enabled: true + externalize_min_chars: 12000 + preview_head_chars: 2000 + preview_tail_chars: 1000 + fallback_max_chars: 30000 + fallback_head_chars: 8000 + fallback_tail_chars: 3000 + storage_subdir: ".tool-results" + exempt_tools: + - read_file + - read_file_tool + # tool_overrides: + # web_search: 8000 + # bash: 20000 + # ============================================================================ # Loop Detection Configuration # ============================================================================