deer-flow/backend/tests/test_tool_error_handling_middleware.py
Xinmin Zeng ca487578a4
feat(agent): add ToolOutputBudgetMiddleware for oversized tool output protection (#3303)
* feat(agent): add ToolOutputBudgetMiddleware for oversized tool output protection

Closes #3289. Adds a unified middleware that enforces per-result budgets
on ALL tool outputs (MCP, sandbox, community, custom), preventing
oversized external tool results from blowing the model context window.

Design informed by claude-code (persistToolResult), hermes-agent
(tool_result_storage), and pi (OutputAccumulator) — the three most
mature implementations in production coding-agent frameworks.

Key features:
- Disk externalization: oversized outputs written to thread-local
  .tool-results/ directory, replaced with compact preview + file
  reference. Model can read full output via read_file with offset/limit.
- Fallback truncation: head+tail truncation when disk is unavailable
  (no thread_data, write failure), ensuring the context is always
  protected.
- read_file exemption: prevents persist-read-persist infinite loops
  (independently discovered by claude-code, hermes-agent, and pi).
- Per-tool threshold overrides via config.
- Line-boundary-aware truncation (no partial lines in previews).
- Multimodal content passthrough (images/structured blocks skip budget).
- Historical ToolMessage patching in wrap_model_call for checkpoint
  recovery scenarios.
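
The disk externalization and line-boundary-aware preview listed above can be sketched roughly as follows. This is a minimal illustration, not the middleware's actual code: the function names, the thresholds, and the wording of the file reference are all assumptions, and the tool name is assumed to be already safe to use as a file name.

```python
import tempfile
from pathlib import Path


def line_aligned_preview(text: str, max_chars: int) -> str:
    """Return at most max_chars of text, cut back to the last complete line."""
    if len(text) <= max_chars:
        return text
    head = text[:max_chars]
    cut = head.rfind("\n")
    # Keep whole lines only; if the first line alone exceeds the budget,
    # fall back to a hard slice so the preview is never empty.
    return head[: cut + 1] if cut != -1 else head


def externalize(
    text: str,
    storage_dir: Path,
    name: str,
    *,
    threshold: int = 2000,      # hypothetical trigger size
    preview_chars: int = 200,   # hypothetical preview size
) -> str:
    """Write oversized text to disk and return a preview plus a file reference."""
    if len(text) <= threshold:
        return text  # small outputs pass through unchanged
    storage_dir.mkdir(parents=True, exist_ok=True)
    path = storage_dir / f"{name}.txt"  # name is assumed to be sanitized already
    path.write_text(text, encoding="utf-8")
    preview = line_aligned_preview(text, preview_chars).rstrip("\n")
    return f"{preview}\n[output truncated: {len(text)} chars total, full output saved to {path}]"


with tempfile.TemporaryDirectory() as tmp:
    big = "\n".join(f"row {i}" for i in range(1000))
    result = externalize(big, Path(tmp) / ".tool-results", "web_search")
    saved = (Path(tmp) / ".tool-results" / "web_search.txt").read_text(encoding="utf-8")
```

The model sees only `result`, while the complete output stays on disk and can be read back in pieces.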

Related: #3222 (design RFC), #1844 (comprehensive context management),
#3137 (write_file args compaction), #1677 (sandbox tool truncation).

* test: add MCP content_and_artifact format coverage

Add 5 tests for MCP tool output format (list of content blocks):
- text content blocks are extracted and budgeted
- multiple text blocks are joined and budgeted
- image content blocks are skipped (multimodal passthrough)
- mixed text+image blocks are skipped
- small text blocks pass through unchanged

Total test count: 59 (was 54).

* fix(agent): address Codex review findings for ToolOutputBudgetMiddleware

Three issues identified by Codex code review, all fixed:

1. `enabled` config field was unused — middleware now checks
   `config.enabled` and skips all processing when disabled.

2. `_build_fallback` could exceed `fallback_max_chars` — the marker
   text itself (~139 chars) was not deducted from the budget. Now
   pre-computes marker overhead and falls back to hard slice when
   max_chars is smaller than the marker.

3. Sync file I/O in async path — `awrap_tool_call` now delegates
   `_patch_result` to `asyncio.to_thread` to avoid blocking the
   event loop during disk writes.

Tests updated to use realistic fallback_max_chars values (500+)
that can accommodate the marker overhead, plus two new tests:
- `test_result_never_exceeds_max_chars` (parametric across sizes)
- `test_very_small_max_chars_does_not_crash`
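
The marker-overhead fix in item 2 can be illustrated with a small sketch. The function name and marker text below are hypothetical (the real marker is described above as about 139 chars), and line-boundary alignment is left out to keep the example short.

```python
def build_fallback(text: str, max_chars: int) -> str:
    """Head+tail truncate text so the result never exceeds max_chars."""
    if len(text) <= max_chars:
        return text
    # Hypothetical marker; its length is deducted from the budget up front.
    marker = f"\n\n[... output truncated: original was {len(text)} chars ...]\n\n"
    budget = max_chars - len(marker)
    if budget <= 0:
        # The marker alone would exceed the budget: hard slice, no marker.
        return text[:max_chars]
    head_len = (budget + 1) // 2
    tail_len = budget // 2
    # text[-0:] would return the whole string, so guard the zero case.
    tail = text[-tail_len:] if tail_len else ""
    return text[:head_len] + marker + tail


sizes = (0, 1, 10, 50, 100, 500, 5000)
lengths = [len(build_fallback("x" * 10_000, n)) for n in sizes]
```

Because head, marker, and tail are sized from the same budget, every entry in `lengths` stays at or below the corresponding entry in `sizes`, which is the property the parametric test above locks in.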

* fix(agent): address Copilot review — path traversal, async perf, shared config

1. Path traversal defense: sanitize tool_name via _sanitize_tool_name()
   (strips separators, .., absolute paths), validate storage_subdir is
   relative, and verify resolved filepath stays inside storage_dir.

2. Async hot-path optimization: add _needs_budget() cheap check before
   asyncio.to_thread offload — small outputs (99% of calls) skip the
   thread overhead entirely.

3. Replace shared module-level _DEFAULT_CONFIG with _default_config()
   factory to prevent cross-instance mutation of mutable fields.
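
The path traversal defense in item 1 can be sketched along these lines. This is an assumption-laden illustration: `sanitize_tool_name`, `validate_subdir`, and `resolve_output_path` are hypothetical names, and the allowlist regex is one possible choice rather than the rule the middleware actually applies.

```python
import re
from pathlib import Path


def sanitize_tool_name(name: str) -> str:
    """Reduce an arbitrary tool name to a safe single path component."""
    # Replace everything outside a conservative allowlist, which removes
    # path separators and drive prefixes.
    cleaned = re.sub(r"[^A-Za-z0-9_.-]", "_", name)
    # Collapse runs of dots so ".." cannot survive, then trim edge dots.
    cleaned = re.sub(r"\.{2,}", "_", cleaned).strip(".")
    return cleaned or "tool"


def validate_subdir(subdir: str) -> Path:
    """Accept only relative subdirectories with no parent references."""
    path = Path(subdir)
    if path.is_absolute() or ".." in path.parts:
        raise ValueError(f"storage subdir must be a plain relative path: {subdir!r}")
    return path


def resolve_output_path(storage_dir: Path, tool_name: str, suffix: str = ".txt") -> Path:
    """Build the output path and verify it stays inside storage_dir."""
    root = storage_dir.resolve()
    candidate = (root / f"{sanitize_tool_name(tool_name)}{suffix}").resolve()
    # Final containment check, independent of the sanitizer above.
    if not candidate.is_relative_to(root):
        raise ValueError(f"refusing to write outside {root}")
    return candidate


hostile = sanitize_tool_name("../../etc/passwd")
```

Sanitizing the name and then checking containment on the resolved path gives two independent layers, so a gap in one does not by itself allow a write outside the storage directory.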

12 new tests: TestSanitizeToolName (5), TestExternalizePathTraversal (3),
TestNeedsBudget (4).
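
The async hot-path gate from item 2 might look roughly like the sketch below. The names, the threshold, and the stand-in `patch_result` are hypothetical; the point is only that the cheap check runs on the event loop and the thread offload happens solely for oversized string content.

```python
import asyncio

THRESHOLD = 1000  # hypothetical budget in characters


def needs_budget(content: object, threshold: int = THRESHOLD) -> bool:
    """Cheap check: only oversized plain-text content needs any work."""
    return isinstance(content, str) and len(content) > threshold


def patch_result(content: str, threshold: int = THRESHOLD) -> str:
    """Stand-in for the blocking path (disk write plus preview construction)."""
    return content[:threshold] + "\n[truncated]"


async def budget_async(content: object) -> object:
    if not needs_budget(content):
        # Common case: return immediately without touching the thread pool.
        return content
    # Rare case: run the blocking work off the event loop.
    return await asyncio.to_thread(patch_result, content)


small = asyncio.run(budget_async("ok"))
large = asyncio.run(budget_async("x" * 5000))
```

Non-string content such as a list of content blocks also takes the fast path in this sketch, which loosely mirrors the multimodal passthrough described earlier.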

* fix(agent): correct preview hint to match read_file actual API

read_file uses start_line/end_line (1-indexed line numbers), not
offset/limit. The previous wording was copied from hermes-agent
which has a different read_file interface.

* perf(agent): hoist hot-path imports, add model-call pre-scan (review #3303)

Address maintainer review feedback:

1. Hoist inline imports to module level — `import asyncio` (was in
   awrap_tool_call hot path) and `from dataclasses import replace`
   (was in _patch_result) now live at module top.

2. Add a cheap pre-scan to _patch_model_messages so the historical
   message list is not rebuilt on every model call when nothing is
   oversized (the common case once results are budgeted at tool-call
   time). Also adds the same _needs_budget gate to the sync
   wrap_tool_call for symmetry with awrap_tool_call.

The pre-scan is refactored into per-tool-aware helpers
(_effective_trigger / _tool_message_over_budget) that mirror the exact
trigger conditions in _budget_content — including tool_overrides — so
the fast-path can never produce a false negative (silently skipping
budgeting for a tool with a low per-tool threshold).
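
A rough sketch of the per-tool-aware pre-scan follows. `BudgetConfig`, `Msg`, and the helper names are hypothetical stand-ins for the real config and `ToolMessage` types, and the plain slice stands in for the actual budgeting step. What it shows is the pre-scan and the patching step sharing one trigger helper, so the fast path cannot disagree with the slow path.

```python
from dataclasses import dataclass, field


@dataclass
class BudgetConfig:
    max_chars: int = 1000
    # Per-tool thresholds; a factory avoids sharing one dict across instances.
    tool_overrides: dict[str, int] = field(default_factory=dict)


@dataclass
class Msg:
    name: str
    content: object


def effective_trigger(config: BudgetConfig, tool_name: str) -> int:
    """Threshold for this tool: the override if present, else the default."""
    return config.tool_overrides.get(tool_name, config.max_chars)


def over_budget(msg: Msg, config: BudgetConfig) -> bool:
    return isinstance(msg.content, str) and len(msg.content) > effective_trigger(config, msg.name)


def patch_messages(messages: list[Msg], config: BudgetConfig) -> list[Msg]:
    # Pre-scan: return the same list object when nothing is oversized.
    if not any(over_budget(m, config) for m in messages):
        return messages
    return [
        Msg(m.name, m.content[: effective_trigger(config, m.name)]) if over_budget(m, config) else m
        for m in messages
    ]


history = [Msg("web_search", "x" * 500), Msg("grep", "y" * 50)]
untouched = patch_messages(history, BudgetConfig())
patched = patch_messages(history, BudgetConfig(tool_overrides={"grep": 10}))
```

With the default config both messages are under budget and the original list comes back unchanged. With a low `grep` override, only the `grep` message is rebuilt.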

7 new regression tests lock the per-tool-override-through-pre-scan path
and the model-call early return.

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-29 22:59:26 +08:00


import sys
from types import ModuleType, SimpleNamespace

import pytest
from langchain_core.messages import ToolMessage
from langgraph.errors import GraphInterrupt

from deerflow.agents.middlewares.tool_error_handling_middleware import (
    ToolErrorHandlingMiddleware,
    build_subagent_runtime_middlewares,
)
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.config.app_config import AppConfig, CircuitBreakerConfig
from deerflow.config.guardrails_config import GuardrailsConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig


def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
    """Build a minimal fake tool-call request; omit the id when it is None."""
    tool_call = {"name": name}
    if tool_call_id is not None:
        tool_call["id"] = tool_call_id
    return SimpleNamespace(tool_call=tool_call)


def _module(name: str, **attrs):
    """Create a stub module exposing the given attributes."""
    module = ModuleType(name)
    for key, value in attrs.items():
        setattr(module, key, value)
    return module


def _make_app_config(*, supports_vision: bool = False) -> AppConfig:
    return AppConfig(
        models=[
            ModelConfig(
                name="test-model",
                display_name="test-model",
                description=None,
                use="langchain_openai:ChatOpenAI",
                model="test-model",
                supports_vision=supports_vision,
            )
        ],
        sandbox=SandboxConfig(use="test"),
        guardrails=GuardrailsConfig(enabled=False),
        circuit_breaker=CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11),
    )


def _stub_runtime_middleware_imports(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace the runtime middleware modules in sys.modules with lightweight fakes."""

    class FakeMiddleware:
        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    class FakeLLMErrorHandlingMiddleware:
        def __init__(self, *, app_config):
            self.app_config = app_config

    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.llm_error_handling_middleware",
        _module(
            "deerflow.agents.middlewares.llm_error_handling_middleware",
            LLMErrorHandlingMiddleware=FakeLLMErrorHandlingMiddleware,
        ),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.thread_data_middleware",
        _module("deerflow.agents.middlewares.thread_data_middleware", ThreadDataMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.sandbox.middleware",
        _module("deerflow.sandbox.middleware", SandboxMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.dangling_tool_call_middleware",
        _module("deerflow.agents.middlewares.dangling_tool_call_middleware", DanglingToolCallMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.sandbox_audit_middleware",
        _module("deerflow.agents.middlewares.sandbox_audit_middleware", SandboxAuditMiddleware=FakeMiddleware),
    )


def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch):
    captured: dict[str, object] = {}

    class FakeMiddleware:
        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    class FakeLLMErrorHandlingMiddleware:
        def __init__(self, *, app_config):
            captured["app_config"] = app_config

    app_config = _make_app_config()

    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.llm_error_handling_middleware",
        _module(
            "deerflow.agents.middlewares.llm_error_handling_middleware",
            LLMErrorHandlingMiddleware=FakeLLMErrorHandlingMiddleware,
        ),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.thread_data_middleware",
        _module("deerflow.agents.middlewares.thread_data_middleware", ThreadDataMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.sandbox.middleware",
        _module("deerflow.sandbox.middleware", SandboxMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.dangling_tool_call_middleware",
        _module("deerflow.agents.middlewares.dangling_tool_call_middleware", DanglingToolCallMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.sandbox_audit_middleware",
        _module("deerflow.agents.middlewares.sandbox_audit_middleware", SandboxAuditMiddleware=FakeMiddleware),
    )

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)

    assert captured["app_config"] is app_config
    # 7 baseline (ToolOutputBudget, ThreadData, Sandbox, DanglingToolCall,
    # LLMErrorHandling, SandboxAudit, ToolErrorHandling)
    # + 1 SafetyFinishReasonMiddleware (enabled by default).
    from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
    from deerflow.agents.middlewares.tool_output_budget_middleware import ToolOutputBudgetMiddleware

    assert len(middlewares) == 8
    assert isinstance(middlewares[0], ToolOutputBudgetMiddleware)
    assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
    assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)


def test_wrap_tool_call_passthrough_on_success():
    middleware = ToolErrorHandlingMiddleware()
    req = _request()
    expected = ToolMessage(content="ok", tool_call_id="tc-1", name="web_search")

    result = middleware.wrap_tool_call(req, lambda _req: expected)

    assert result is expected


def test_wrap_tool_call_returns_error_tool_message_on_exception():
    middleware = ToolErrorHandlingMiddleware()
    req = _request(name="web_search", tool_call_id="tc-42")

    def _boom(_req):
        raise RuntimeError("network down")

    result = middleware.wrap_tool_call(req, _boom)

    assert isinstance(result, ToolMessage)
    assert result.tool_call_id == "tc-42"
    assert result.name == "web_search"
    assert result.status == "error"
    assert "Tool 'web_search' failed" in result.text
    assert "network down" in result.text


def test_wrap_tool_call_uses_fallback_tool_call_id_when_missing():
    middleware = ToolErrorHandlingMiddleware()
    req = _request(name="mcp_tool", tool_call_id=None)

    def _boom(_req):
        raise ValueError("bad request")

    result = middleware.wrap_tool_call(req, _boom)

    assert isinstance(result, ToolMessage)
    assert result.tool_call_id == "missing_tool_call_id"
    assert result.name == "mcp_tool"
    assert result.status == "error"


def test_wrap_tool_call_reraises_graph_interrupt():
    middleware = ToolErrorHandlingMiddleware()
    req = _request(name="ask_clarification", tool_call_id="tc-int")

    def _interrupt(_req):
        raise GraphInterrupt(())

    with pytest.raises(GraphInterrupt):
        middleware.wrap_tool_call(req, _interrupt)


@pytest.mark.anyio
async def test_awrap_tool_call_returns_error_tool_message_on_exception():
    middleware = ToolErrorHandlingMiddleware()
    req = _request(name="mcp_tool", tool_call_id="tc-async")

    async def _boom(_req):
        raise TimeoutError("request timed out")

    result = await middleware.awrap_tool_call(req, _boom)

    assert isinstance(result, ToolMessage)
    assert result.tool_call_id == "tc-async"
    assert result.name == "mcp_tool"
    assert result.status == "error"
    assert "request timed out" in result.text


@pytest.mark.anyio
async def test_awrap_tool_call_reraises_graph_interrupt():
    middleware = ToolErrorHandlingMiddleware()
    req = _request(name="ask_clarification", tool_call_id="tc-int-async")

    async def _interrupt(_req):
        raise GraphInterrupt(())

    with pytest.raises(GraphInterrupt):
        await middleware.awrap_tool_call(req, _interrupt)


def test_subagent_runtime_middlewares_include_view_image_for_vision_model(monkeypatch):
    app_config = _make_app_config(supports_vision=True)
    _stub_runtime_middleware_imports(monkeypatch)

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model")

    assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)


def test_subagent_runtime_middlewares_include_view_image_for_default_vision_model(monkeypatch):
    app_config = _make_app_config(supports_vision=True)
    _stub_runtime_middleware_imports(monkeypatch)

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=None)

    assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)


def test_subagent_runtime_middlewares_skip_view_image_for_text_model(monkeypatch):
    app_config = _make_app_config(supports_vision=False)
    _stub_runtime_middleware_imports(monkeypatch)

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model")

    assert not any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)