From ec8a8cae38456ece2b0f9a6b32c42382127c5f0e Mon Sep 17 00:00:00 2001 From: DanielWalnut <45447813+hetaoBackend@users.noreply.github.com> Date: Fri, 24 Apr 2026 22:45:41 +0800 Subject: [PATCH] fix: gate deferred MCP tool execution (#2513) * fix: gate deferred MCP tool execution * style: format deferred tool middleware * fix: address deferred tool review feedback --- .../deferred_tool_filter_middleware.py | 49 +++++++++- .../deerflow/tools/builtins/tool_search.py | 9 ++ backend/tests/test_tool_search.py | 98 +++++++++++++++++++ 3 files changed, 155 insertions(+), 1 deletion(-) diff --git a/backend/packages/harness/deerflow/agents/middlewares/deferred_tool_filter_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/deferred_tool_filter_middleware.py index 604cdf37c..f92d90158 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/deferred_tool_filter_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/deferred_tool_filter_middleware.py @@ -16,6 +16,9 @@ from typing import override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse +from langchain_core.messages import ToolMessage +from langgraph.prebuilt.tool_node import ToolCallRequest +from langgraph.types import Command logger = logging.getLogger(__name__) @@ -35,7 +38,7 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]): if not registry: return request - deferred_names = {e.name for e in registry.entries} + deferred_names = registry.deferred_names active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names] if len(active_tools) < len(request.tools): @@ -43,6 +46,28 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]): return request.override(tools=active_tools) + def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None: + from deerflow.tools.builtins.tool_search import get_deferred_registry + + registry = get_deferred_registry() + if not registry: + return None + + tool_name = str(request.tool_call.get("name") or "") + if not tool_name: + return None + + if not registry.contains(tool_name): + return None + + tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id") + return ToolMessage( + content=(f"Error: Tool '{tool_name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."), + tool_call_id=tool_call_id, + name=tool_name, + status="error", + ) + @override def wrap_model_call( self, @@ -51,6 +76,17 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]): ) -> ModelCallResult: return handler(self._filter_tools(request)) + @override + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + blocked = self._blocked_tool_message(request) + if blocked is not None: + return blocked + return handler(request) + @override async def awrap_model_call( self, @@ -58,3 +94,14 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]): handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: return await handler(self._filter_tools(request)) + + @override + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], + ) -> ToolMessage | Command: + blocked = self._blocked_tool_message(request) + if blocked is not None: + return blocked + return await handler(request) diff --git a/backend/packages/harness/deerflow/tools/builtins/tool_search.py b/backend/packages/harness/deerflow/tools/builtins/tool_search.py index ffbe2060f..88f4e3112 100644 --- a/backend/packages/harness/deerflow/tools/builtins/tool_search.py +++ b/backend/packages/harness/deerflow/tools/builtins/tool_search.py @@ -112,6 +112,15 @@ class DeferredToolRegistry: def entries(self) -> list[DeferredToolEntry]: return list(self._entries) + @property + def deferred_names(self) -> set[str]: + """Names of tools that are still hidden from model binding.""" + return {entry.name for entry in self._entries} + + def contains(self, name: str) -> bool: + """Return whether *name* is still deferred.""" + return any(entry.name == name for entry in self._entries) + def __len__(self) -> int: return len(self._entries) diff --git a/backend/tests/test_tool_search.py b/backend/tests/test_tool_search.py index 8f71144c5..428bfec3d 100644 --- a/backend/tests/test_tool_search.py +++ b/backend/tests/test_tool_search.py @@ -2,8 +2,10 @@ import json import sys +from types import SimpleNamespace import pytest +from langchain_core.messages import ToolMessage from langchain_core.tools import tool as langchain_tool from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict @@ -83,6 +85,16 @@ class TestDeferredToolRegistry: assert "github_create_issue" in names assert "slack_send_message" in names + def test_deferred_names(self, registry): + names = registry.deferred_names + assert "github_create_issue" in names + assert "slack_send_message" in names + assert len(names) == 6 + + def test_contains(self, registry): + assert registry.contains("github_create_issue") is True + assert registry.contains("not_registered") is False + def test_search_select_single(self, registry): results = registry.search("select:github_create_issue") assert len(results) == 1 @@ -509,3 +521,89 @@ class TestToolSearchPromotion: assert "slack_send_message" not in remaining assert "slack_list_channels" not in remaining assert len(registry) == 4 + + +class TestDeferredToolExecutionGate: + def test_unpromoted_deferred_tool_call_is_blocked(self, registry): + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + + set_deferred_registry(registry) + middleware = DeferredToolFilterMiddleware() + request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"}) + called = False + + def handler(_request): + nonlocal called + called = True + return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue") + + result = middleware.wrap_tool_call(request, handler) + + assert called is False + assert isinstance(result, ToolMessage) + assert result.status == "error" + assert result.tool_call_id == "call-1" + assert "tool_search" in result.content + assert "github_create_issue" in result.content + + def test_promoted_deferred_tool_call_is_allowed(self, registry): + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + + registry.promote({"github_create_issue"}) + set_deferred_registry(registry) + middleware = DeferredToolFilterMiddleware() + request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"}) + called = False + + def handler(_request): + nonlocal called + called = True + return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue") + + result = middleware.wrap_tool_call(request, handler) + + assert called is True + assert isinstance(result, ToolMessage) + assert result.content == "executed" + + def test_non_deferred_tool_call_is_allowed(self, registry): + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + + set_deferred_registry(registry) + middleware = DeferredToolFilterMiddleware() + request = SimpleNamespace(tool_call={"name": "local_tool", "id": "call-1"}) + called = False + + def handler(_request): + nonlocal called + called = True + return ToolMessage(content="executed", tool_call_id="call-1", name="local_tool") + + result = middleware.wrap_tool_call(request, handler) + + assert called is True + assert isinstance(result, ToolMessage) + assert result.content == "executed" + + @pytest.mark.anyio + async def test_unpromoted_deferred_tool_call_is_blocked_async(self, registry): + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + + set_deferred_registry(registry) + middleware = DeferredToolFilterMiddleware() + request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"}) + called = False + + async def handler(_request): + nonlocal called + called = True + return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue") + + result = await middleware.awrap_tool_call(request, handler) + + assert called is False + assert isinstance(result, ToolMessage) + assert result.status == "error" + assert result.tool_call_id == "call-1" + assert "tool_search" in result.content + assert "github_create_issue" in result.content