deer-flow/backend/tests/test_tool_error_handling_middleware.py
AochenShen99 3b6dd0a4e3
feat(subagents): extend deferred MCP tool loading to subagents (#3432)
* feat(subagents): extend deferred MCP tool loading to subagents (#3341)

Subagents now reuse the lead agent's deferred-tool path: when
tool_search.enabled, MCP tool schemas are withheld from the model and
surfaced by name in <available-deferred-tools>, fetched on demand via the
generated tool_search helper. DeferredToolFilterMiddleware deterministically
rewrites request.tools to hide the deferred schemas (the prompt section is
discovery only, not enforcement).

Consolidates the assembly into deerflow.tools.builtins.tool_search, now the
single home for both assemble_deferred_tools (centralized fail-closed guard,
replacing the lead-only private _assemble_deferred) and the relocated
get_deferred_tools_prompt_section. Shared by every build path: lead agent,
embedded client, and subagent executor.

tool_search is appended after the subagent's name-level tool policy and is
treated as infrastructure: its catalog is built from the already
policy-filtered list, so it can never surface a tool the policy denied.

Follow-up to #3370. Fixes #3341.

* test(subagents): assert the real middleware builder emits a working deferred filter (#3341)

The existing recipe test hand-constructs DeferredToolFilterMiddleware, so it
cannot catch a regression in how build_subagent_runtime_middlewares (the call
executor._create_agent actually makes) wires the deferred setup into the
filter. Add a test that sources the filter from the real builder given a real
setup and runs it through a graph: a wrong catalog hash would silently stop
promotion, a dropped filter would stop hiding — both now caught.

Running the full real middleware stack is intentionally avoided (the other
runtime middlewares need sandbox/thread infra to execute, which would make the
test flaky); their attachment + ordering before Safety stays locked in
test_tool_error_handling_middleware.py.

* test(subagents): keep executor tests config-free in CI

* chore: trigger ci

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-06-08 23:17:22 +08:00

298 lines
11 KiB
Python

import sys
from types import ModuleType, SimpleNamespace
import pytest
from langchain_core.messages import ToolMessage
from langgraph.errors import GraphInterrupt
from deerflow.agents.middlewares.tool_error_handling_middleware import (
ToolErrorHandlingMiddleware,
build_subagent_runtime_middlewares,
)
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.config.app_config import AppConfig, CircuitBreakerConfig
from deerflow.config.guardrails_config import GuardrailsConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
tool_call = {"name": name}
if tool_call_id is not None:
tool_call["id"] = tool_call_id
return SimpleNamespace(tool_call=tool_call)
def _module(name: str, **attrs):
    """Create an in-memory module called ``name`` and set each of ``attrs`` on it."""
    stub = ModuleType(name)
    for attr_name, attr_value in attrs.items():
        setattr(stub, attr_name, attr_value)
    return stub
def _make_app_config(*, supports_vision: bool = False) -> AppConfig:
    """Build a minimal single-model ``AppConfig`` for the tests in this module.

    ``supports_vision`` is forwarded to the only configured model
    ("test-model"). Guardrails are disabled and the circuit breaker is given
    fixed values (failure threshold 7, recovery timeout 11 seconds).
    """
    test_model = ModelConfig(
        name="test-model",
        display_name="test-model",
        description=None,
        use="langchain_openai:ChatOpenAI",
        model="test-model",
        supports_vision=supports_vision,
    )
    return AppConfig(
        models=[test_model],
        sandbox=SandboxConfig(use="test"),
        guardrails=GuardrailsConfig(enabled=False),
        circuit_breaker=CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11),
    )
def _stub_runtime_middleware_imports(monkeypatch: pytest.MonkeyPatch) -> None:
    """Register stub modules in ``sys.modules`` for five runtime middleware modules.

    Each stub module exposes one fake class under the name of the middleware
    class it stands in for. Entries are installed through ``monkeypatch``
    rather than by writing to ``sys.modules`` directly.
    NOTE(review): presumably the builder imports these modules lazily, so the
    stubs must be installed before it is called -- confirm against the builder.
    """

    # Generic fake: remembers whatever it was constructed with.
    class FakeMiddleware:
        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    # LLM error-handling fake: accepts only the keyword-only ``app_config``.
    class FakeLLMErrorHandlingMiddleware:
        def __init__(self, *, app_config):
            self.app_config = app_config

    # (module path, exported class name, replacement class), in install order.
    replacements = (
        (
            "deerflow.agents.middlewares.llm_error_handling_middleware",
            "LLMErrorHandlingMiddleware",
            FakeLLMErrorHandlingMiddleware,
        ),
        (
            "deerflow.agents.middlewares.thread_data_middleware",
            "ThreadDataMiddleware",
            FakeMiddleware,
        ),
        (
            "deerflow.sandbox.middleware",
            "SandboxMiddleware",
            FakeMiddleware,
        ),
        (
            "deerflow.agents.middlewares.dangling_tool_call_middleware",
            "DanglingToolCallMiddleware",
            FakeMiddleware,
        ),
        (
            "deerflow.agents.middlewares.sandbox_audit_middleware",
            "SandboxAuditMiddleware",
            FakeMiddleware,
        ),
    )
    for module_path, class_name, fake_cls in replacements:
        stub_module = _module(module_path, **{class_name: fake_cls})
        monkeypatch.setitem(sys.modules, module_path, stub_module)
def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch):
    """The builder passes its ``app_config`` to ``LLMErrorHandlingMiddleware`` and returns the expected stack.

    Checks that the exact ``app_config`` object reaches the LLM error-handling
    middleware, that eight middlewares are returned, that the first is
    ``ToolOutputBudgetMiddleware``, that a ``ToolErrorHandlingMiddleware`` is
    present, and that ``SafetyFinishReasonMiddleware`` comes last.
    """
    captured: dict[str, object] = {}

    # Capturing fake: records the app_config it was constructed with so the
    # identity assertion below can inspect it.
    class FakeLLMErrorHandlingMiddleware:
        def __init__(self, *, app_config):
            captured["app_config"] = app_config

    app_config = _make_app_config()

    # Install the shared stubs instead of repeating them here, then override
    # only the LLM error-handling module with the capturing fake.
    _stub_runtime_middleware_imports(monkeypatch)
    llm_module_name = "deerflow.agents.middlewares.llm_error_handling_middleware"
    monkeypatch.setitem(
        sys.modules,
        llm_module_name,
        _module(llm_module_name, LLMErrorHandlingMiddleware=FakeLLMErrorHandlingMiddleware),
    )

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)

    assert captured["app_config"] is app_config

    # 7 baseline (ToolOutputBudget, ThreadData, Sandbox, DanglingToolCall,
    # LLMErrorHandling, SandboxAudit, ToolErrorHandling)
    # + 1 SafetyFinishReasonMiddleware (enabled by default).
    from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
    from deerflow.agents.middlewares.tool_output_budget_middleware import ToolOutputBudgetMiddleware

    assert len(middlewares) == 8
    assert isinstance(middlewares[0], ToolOutputBudgetMiddleware)
    assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
    assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
def test_wrap_tool_call_passthrough_on_success():
    """When the handler succeeds, ``wrap_tool_call`` returns the handler's own message object."""
    success_message = ToolMessage(content="ok", tool_call_id="tc-1", name="web_search")

    def _succeeding_handler(_req):
        return success_message

    outcome = ToolErrorHandlingMiddleware().wrap_tool_call(_request(), _succeeding_handler)

    assert outcome is success_message
def test_wrap_tool_call_returns_error_tool_message_on_exception():
    """A raising handler yields an error ``ToolMessage`` carrying the call's id, tool name and failure text."""

    def _failing_handler(_req):
        raise RuntimeError("network down")

    outcome = ToolErrorHandlingMiddleware().wrap_tool_call(
        _request(name="web_search", tool_call_id="tc-42"),
        _failing_handler,
    )

    assert isinstance(outcome, ToolMessage)
    assert (outcome.tool_call_id, outcome.name, outcome.status) == ("tc-42", "web_search", "error")
    # Both the generic failure banner and the original exception text must be present.
    for fragment in ("Tool 'web_search' failed", "network down"):
        assert fragment in outcome.text
def test_wrap_tool_call_uses_fallback_tool_call_id_when_missing():
    """If the tool call has no ``id``, the error message uses the ``missing_tool_call_id`` placeholder."""

    def _failing_handler(_req):
        raise ValueError("bad request")

    # tool_call_id=None makes the helper omit the "id" key entirely.
    request_without_id = _request(name="mcp_tool", tool_call_id=None)
    outcome = ToolErrorHandlingMiddleware().wrap_tool_call(request_without_id, _failing_handler)

    assert isinstance(outcome, ToolMessage)
    assert (outcome.tool_call_id, outcome.name, outcome.status) == ("missing_tool_call_id", "mcp_tool", "error")
def test_wrap_tool_call_reraises_graph_interrupt():
    """``GraphInterrupt`` raised by the handler propagates out of ``wrap_tool_call``."""

    def _interrupting_handler(_req):
        raise GraphInterrupt(())

    handler_under_test = ToolErrorHandlingMiddleware()
    interrupt_request = _request(name="ask_clarification", tool_call_id="tc-int")

    with pytest.raises(GraphInterrupt):
        handler_under_test.wrap_tool_call(interrupt_request, _interrupting_handler)
@pytest.mark.anyio
async def test_awrap_tool_call_returns_error_tool_message_on_exception():
    """Async path: a raising handler yields an error ``ToolMessage`` with the call's id, name and failure text."""

    async def _timing_out_handler(_req):
        raise TimeoutError("request timed out")

    outcome = await ToolErrorHandlingMiddleware().awrap_tool_call(
        _request(name="mcp_tool", tool_call_id="tc-async"),
        _timing_out_handler,
    )

    assert isinstance(outcome, ToolMessage)
    assert (outcome.tool_call_id, outcome.name, outcome.status) == ("tc-async", "mcp_tool", "error")
    assert "request timed out" in outcome.text
@pytest.mark.anyio
async def test_awrap_tool_call_reraises_graph_interrupt():
    """Async path: ``GraphInterrupt`` raised by the handler propagates out of ``awrap_tool_call``."""

    async def _interrupting_handler(_req):
        raise GraphInterrupt(())

    handler_under_test = ToolErrorHandlingMiddleware()
    interrupt_request = _request(name="ask_clarification", tool_call_id="tc-int-async")

    with pytest.raises(GraphInterrupt):
        await handler_under_test.awrap_tool_call(interrupt_request, _interrupting_handler)
def test_subagent_runtime_middlewares_include_view_image_for_vision_model(monkeypatch):
    """Naming a vision-capable model explicitly yields a stack containing ``ViewImageMiddleware``."""
    vision_config = _make_app_config(supports_vision=True)
    _stub_runtime_middleware_imports(monkeypatch)

    stack = build_subagent_runtime_middlewares(app_config=vision_config, model_name="test-model")

    view_image_entries = [entry for entry in stack if isinstance(entry, ViewImageMiddleware)]
    assert view_image_entries
def test_subagent_runtime_middlewares_include_view_image_for_default_vision_model(monkeypatch):
    """With ``model_name=None`` and a vision-capable configured model, ``ViewImageMiddleware`` is still included."""
    vision_config = _make_app_config(supports_vision=True)
    _stub_runtime_middleware_imports(monkeypatch)

    stack = build_subagent_runtime_middlewares(app_config=vision_config, model_name=None)

    view_image_entries = [entry for entry in stack if isinstance(entry, ViewImageMiddleware)]
    assert view_image_entries
def test_subagent_runtime_middlewares_skip_view_image_for_text_model(monkeypatch):
    """A model configured without vision support yields a stack with no ``ViewImageMiddleware``."""
    text_only_config = _make_app_config(supports_vision=False)
    _stub_runtime_middleware_imports(monkeypatch)

    stack = build_subagent_runtime_middlewares(app_config=text_only_config, model_name="test-model")

    view_image_entries = [entry for entry in stack if isinstance(entry, ViewImageMiddleware)]
    assert not view_image_entries
def test_subagent_runtime_middlewares_attach_deferred_filter_when_setup_has_names(monkeypatch):
    """A populated deferred setup adds exactly one ``DeferredToolFilterMiddleware``.

    The filter must also sit earlier in the returned list than
    ``SafetyFinishReasonMiddleware``.
    """
    from langchain_core.tools import tool as as_tool
    from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
    from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
    from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
    from deerflow.tools.mcp_metadata import tag_mcp_tool

    app_config = _make_app_config()
    _stub_runtime_middleware_imports(monkeypatch)

    @as_tool
    def mcp_thing(x: str) -> str:
        "deferred mcp tool"
        return x

    tagged_tools = [tag_mcp_tool(mcp_thing)]
    setup = build_deferred_tool_setup(tagged_tools, enabled=True)
    assert setup.deferred_names  # sanity: populated setup

    stack = build_subagent_runtime_middlewares(app_config=app_config, deferred_setup=setup)

    # Exactly one filter instance is attached.
    attached_filters = [entry for entry in stack if isinstance(entry, DeferredToolFilterMiddleware)]
    assert len(attached_filters) == 1

    # The filter precedes the safety middleware in the list.
    filter_idx = next(pos for pos, entry in enumerate(stack) if isinstance(entry, DeferredToolFilterMiddleware))
    safety_idx = next(pos for pos, entry in enumerate(stack) if isinstance(entry, SafetyFinishReasonMiddleware))
    assert filter_idx < safety_idx
def test_subagent_runtime_middlewares_skip_deferred_filter_without_names(monkeypatch):
    """Neither ``deferred_setup=None`` nor an empty ``DeferredToolSetup`` attaches ``DeferredToolFilterMiddleware``."""
    from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
    from deerflow.tools.builtins.tool_search import DeferredToolSetup

    app_config = _make_app_config()
    _stub_runtime_middleware_imports(monkeypatch)

    # Two "nothing deferred" shapes: no setup at all, and a setup with an empty name set.
    empty_setups = (None, DeferredToolSetup(None, frozenset(), None))
    for setup in empty_setups:
        stack = build_subagent_runtime_middlewares(app_config=app_config, deferred_setup=setup)
        attached_filters = [entry for entry in stack if isinstance(entry, DeferredToolFilterMiddleware)]
        assert not attached_filters