diff --git a/backend/docs/MCP_SERVER.md b/backend/docs/MCP_SERVER.md index efe2ea0c4..b7320f8cc 100644 --- a/backend/docs/MCP_SERVER.md +++ b/backend/docs/MCP_SERVER.md @@ -45,6 +45,41 @@ Example: } ``` +## Custom Tool Interceptors + +You can register custom interceptors that run before every MCP tool call. This is useful for injecting per-request headers (e.g., user auth tokens from the LangGraph execution context), logging, or metrics. + +Declare interceptors in `extensions_config.json` using the `mcpInterceptors` field: + +```json +{ + "mcpInterceptors": [ + "my_package.mcp.auth:build_auth_interceptor" + ], + "mcpServers": { ... } +} +``` + +Each entry is a Python import path in `module:variable` format (resolved via `resolve_variable`). The variable must be a **no-arg builder function** that returns an async interceptor compatible with `MultiServerMCPClient`’s `tool_interceptors` interface, or `None` to skip. + +Example interceptor that injects auth headers from LangGraph metadata: + +```python +def build_auth_interceptor(): + async def interceptor(request, handler): + from langgraph.config import get_config + metadata = get_config().get("metadata", {}) + headers = dict(request.headers or {}) + if token := metadata.get("auth_token"): + headers["X-Auth-Token"] = token + return await handler(request.override(headers=headers)) + return interceptor +``` + +- A single string value is accepted and normalized to a one-element list. +- Invalid paths or builder failures are logged as warnings without blocking other interceptors. +- The builder return value must be `callable`; non-callable values are skipped with a warning. + ## How It Works MCP servers expose tools that are automatically discovered and integrated into DeerFlow’s agent system at runtime. Once enabled, these tools become available to agents without additional code changes. diff --git a/backend/packages/harness/deerflow/mcp/tools.py b/backend/packages/harness/deerflow/mcp/tools.py index 718ac2ba3..bcd50c645 100644 --- a/backend/packages/harness/deerflow/mcp/tools.py +++ b/backend/packages/harness/deerflow/mcp/tools.py @@ -12,6 +12,7 @@ from langchain_core.tools import BaseTool from deerflow.config.extensions_config import ExtensionsConfig from deerflow.mcp.client import build_servers_config from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers +from deerflow.reflection import resolve_variable logger = logging.getLogger(__name__) @@ -95,6 +96,27 @@ async def get_mcp_tools() -> list[BaseTool]: if oauth_interceptor is not None: tool_interceptors.append(oauth_interceptor) + # Load custom interceptors declared in extensions_config.json + # Format: "mcpInterceptors": ["pkg.module:builder_func", ...] + raw_interceptor_paths = (extensions_config.model_extra or {}).get("mcpInterceptors") + if isinstance(raw_interceptor_paths, str): + raw_interceptor_paths = [raw_interceptor_paths] + elif not isinstance(raw_interceptor_paths, list): + if raw_interceptor_paths is not None: + logger.warning(f"mcpInterceptors must be a list of strings, got {type(raw_interceptor_paths).__name__}; skipping") + raw_interceptor_paths = [] + for interceptor_path in raw_interceptor_paths: + try: + builder = resolve_variable(interceptor_path) + interceptor = builder() + if callable(interceptor): + tool_interceptors.append(interceptor) + logger.info(f"Loaded MCP interceptor: {interceptor_path}") + elif interceptor is not None: + logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping") + except Exception as e: + logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True) + client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True) # Get all tools from all servers diff --git a/backend/tests/test_mcp_custom_interceptors.py b/backend/tests/test_mcp_custom_interceptors.py new file mode 100644 index 000000000..08432de98 --- /dev/null +++ b/backend/tests/test_mcp_custom_interceptors.py @@ -0,0 +1,274 @@ +"""Tests for custom MCP tool interceptors loaded via extensions_config.json.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +from deerflow.mcp.tools import get_mcp_tools + + +def _make_patches(*, interceptor_paths=None): + """Set up mocks for get_mcp_tools() with optional custom interceptors. + + Returns a dict of patch context managers. + """ + mock_client = MagicMock() + mock_client.get_tools = AsyncMock(return_value=[]) + + extra = {} + if interceptor_paths is not None: + extra["mcpInterceptors"] = interceptor_paths + + return { + "client_cls": patch( + "langchain_mcp_adapters.client.MultiServerMCPClient", + return_value=mock_client, + ), + "from_file": patch( + "deerflow.config.extensions_config.ExtensionsConfig.from_file", + return_value=MagicMock( + model_extra=extra, + get_enabled_mcp_servers=MagicMock(return_value={}), + ), + ), + "build_servers": patch( + "deerflow.mcp.tools.build_servers_config", + return_value={"test-server": {}}, + ), + "oauth_headers": patch( + "deerflow.mcp.tools.get_initial_oauth_headers", + new_callable=AsyncMock, + return_value={}, + ), + "oauth_interceptor": patch( + "deerflow.mcp.tools.build_oauth_tool_interceptor", + return_value=None, + ), + } + + +def _get_interceptors(mock_cls): + """Extract the tool_interceptors list passed to MultiServerMCPClient.""" + kw = mock_cls.call_args + return kw.kwargs.get("tool_interceptors") or kw[1].get("tool_interceptors", []) + + +def test_custom_interceptor_loaded_and_appended(): + """A valid interceptor builder path is resolved, called, and appended to tool_interceptors.""" + + async def fake_interceptor(request, handler): + return await handler(request) + + def fake_builder(): + return fake_interceptor + + p = _make_patches(interceptor_paths=["my_package.auth:build_interceptor"]) + + with ( + p["client_cls"] as mock_cls, + p["from_file"], + p["build_servers"], + p["oauth_headers"], + p["oauth_interceptor"], + patch("deerflow.mcp.tools.resolve_variable", return_value=fake_builder), + ): + asyncio.run(get_mcp_tools()) + + interceptors = _get_interceptors(mock_cls) + assert len(interceptors) == 1 + assert interceptors[0] is fake_interceptor + + +def test_multiple_custom_interceptors(): + """Multiple interceptor paths are all loaded in order.""" + + async def interceptor_a(request, handler): + return await handler(request) + + async def interceptor_b(request, handler): + return await handler(request) + + builders = { + "pkg.a:build_a": lambda: interceptor_a, + "pkg.b:build_b": lambda: interceptor_b, + } + + p = _make_patches(interceptor_paths=["pkg.a:build_a", "pkg.b:build_b"]) + + with ( + p["client_cls"] as mock_cls, + p["from_file"], + p["build_servers"], + p["oauth_headers"], + p["oauth_interceptor"], + patch("deerflow.mcp.tools.resolve_variable", side_effect=lambda path: builders[path]), + ): + asyncio.run(get_mcp_tools()) + + interceptors = _get_interceptors(mock_cls) + assert len(interceptors) == 2 + assert interceptors[0] is interceptor_a + assert interceptors[1] is interceptor_b + + +def test_custom_interceptor_builder_returning_none_is_skipped(): + """If a builder returns None, it is not appended to the interceptor list.""" + p = _make_patches(interceptor_paths=["pkg.noop:build_noop"]) + + with ( + p["client_cls"] as mock_cls, + p["from_file"], + p["build_servers"], + p["oauth_headers"], + p["oauth_interceptor"], + patch("deerflow.mcp.tools.resolve_variable", return_value=lambda: None), + ): + asyncio.run(get_mcp_tools()) + + assert len(_get_interceptors(mock_cls)) == 0 + + +def test_custom_interceptor_resolve_error_logs_warning_and_continues(): + """A broken interceptor path logs a warning and does not block tool loading.""" + p = _make_patches(interceptor_paths=["broken.path:does_not_exist"]) + + with ( + p["client_cls"], + p["from_file"], + p["build_servers"], + p["oauth_headers"], + p["oauth_interceptor"], + patch("deerflow.mcp.tools.resolve_variable", side_effect=ImportError("no such module")), + patch("deerflow.mcp.tools.logger.warning") as mock_warn, + ): + tools = asyncio.run(get_mcp_tools()) + + assert tools == [] + mock_warn.assert_called_once() + assert "broken.path:does_not_exist" in mock_warn.call_args[0][0] + + +def test_custom_interceptor_builder_exception_logs_warning_and_continues(): + """If the builder function itself raises, the error is caught and logged.""" + + def exploding_builder(): + raise RuntimeError("builder exploded") + + p = _make_patches(interceptor_paths=["pkg.bad:exploding_builder"]) + + with ( + p["client_cls"], + p["from_file"], + p["build_servers"], + p["oauth_headers"], + p["oauth_interceptor"], + patch("deerflow.mcp.tools.resolve_variable", return_value=exploding_builder), + patch("deerflow.mcp.tools.logger.warning") as mock_warn, + ): + tools = asyncio.run(get_mcp_tools()) + + assert tools == [] + mock_warn.assert_called_once() + assert "pkg.bad:exploding_builder" in mock_warn.call_args[0][0] + + +def test_no_mcp_interceptors_field_is_safe(): + """When mcpInterceptors is absent from config, no interceptors are added.""" + p = _make_patches(interceptor_paths=None) + + with ( + p["client_cls"] as mock_cls, + p["from_file"], + p["build_servers"], + p["oauth_headers"], + p["oauth_interceptor"], + ): + asyncio.run(get_mcp_tools()) + + assert len(_get_interceptors(mock_cls)) == 0 + + +def test_custom_interceptor_coexists_with_oauth_interceptor(): + """Custom interceptors are appended after the OAuth interceptor.""" + + async def oauth_fn(request, handler): + return await handler(request) + + async def custom_fn(request, handler): + return await handler(request) + + p = _make_patches(interceptor_paths=["pkg.custom:build_custom"]) + + with ( + p["client_cls"] as mock_cls, + p["from_file"], + p["build_servers"], + p["oauth_headers"], + patch("deerflow.mcp.tools.build_oauth_tool_interceptor", return_value=oauth_fn), + patch("deerflow.mcp.tools.resolve_variable", return_value=lambda: custom_fn), + ): + asyncio.run(get_mcp_tools()) + + interceptors = _get_interceptors(mock_cls) + assert len(interceptors) == 2 + assert interceptors[0] is oauth_fn + assert interceptors[1] is custom_fn + + +def test_mcp_interceptors_single_string_is_normalized(): + """A single string value for mcpInterceptors is normalized to a list.""" + + async def fake_interceptor(request, handler): + return await handler(request) + + p = _make_patches(interceptor_paths="pkg.single:build_it") + + with ( + p["client_cls"] as mock_cls, + p["from_file"], + p["build_servers"], + p["oauth_headers"], + p["oauth_interceptor"], + patch("deerflow.mcp.tools.resolve_variable", return_value=lambda: fake_interceptor), + ): + asyncio.run(get_mcp_tools()) + + assert len(_get_interceptors(mock_cls)) == 1 + + +def test_mcp_interceptors_invalid_type_logs_warning(): + """A non-list, non-string value for mcpInterceptors logs a warning and is skipped.""" + p = _make_patches(interceptor_paths=42) + + with ( + p["client_cls"] as mock_cls, + p["from_file"], + p["build_servers"], + p["oauth_headers"], + p["oauth_interceptor"], + patch("deerflow.mcp.tools.logger.warning") as mock_warn, + ): + asyncio.run(get_mcp_tools()) + + assert len(_get_interceptors(mock_cls)) == 0 + mock_warn.assert_called_once() + assert "must be a list" in mock_warn.call_args[0][0] + + +def test_custom_interceptor_non_callable_return_logs_warning(): + """If a builder returns a non-callable value, it is skipped with a warning.""" + p = _make_patches(interceptor_paths=["pkg.bad:returns_string"]) + + with ( + p["client_cls"] as mock_cls, + p["from_file"], + p["build_servers"], + p["oauth_headers"], + p["oauth_interceptor"], + patch("deerflow.mcp.tools.resolve_variable", return_value=lambda: "not_a_callable"), + patch("deerflow.mcp.tools.logger.warning") as mock_warn, + ): + asyncio.run(get_mcp_tools()) + + assert len(_get_interceptors(mock_cls)) == 0 + mock_warn.assert_called_once() + assert "non-callable" in mock_warn.call_args[0][0] diff --git a/extensions_config.example.json b/extensions_config.example.json index dc0e224ea..118c5d6db 100644 --- a/extensions_config.example.json +++ b/extensions_config.example.json @@ -1,4 +1,7 @@ { + "mcpInterceptors": [ + "my_package.mcp.auth:build_auth_interceptor" + ], "mcpServers": { "filesystem": { "enabled": false,