fix: gate deferred MCP tool execution (#2513)

* fix: gate deferred MCP tool execution

* style: format deferred tool middleware

* fix: address deferred tool review feedback
This commit is contained in:
DanielWalnut 2026-04-24 22:45:41 +08:00 committed by GitHub
parent d78ed5c8f2
commit ec8a8cae38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 155 additions and 1 deletions

View File

@ -16,6 +16,9 @@ from typing import override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,7 +38,7 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
if not registry: if not registry:
return request return request
deferred_names = {e.name for e in registry.entries} deferred_names = registry.deferred_names
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names] active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
if len(active_tools) < len(request.tools): if len(active_tools) < len(request.tools):
@ -43,6 +46,28 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
return request.override(tools=active_tools) return request.override(tools=active_tools)
def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None:
from deerflow.tools.builtins.tool_search import get_deferred_registry
registry = get_deferred_registry()
if not registry:
return None
tool_name = str(request.tool_call.get("name") or "")
if not tool_name:
return None
if not registry.contains(tool_name):
return None
tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id")
return ToolMessage(
content=(f"Error: Tool '{tool_name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
@override @override
def wrap_model_call( def wrap_model_call(
self, self,
@ -51,6 +76,17 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
) -> ModelCallResult: ) -> ModelCallResult:
return handler(self._filter_tools(request)) return handler(self._filter_tools(request))
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
blocked = self._blocked_tool_message(request)
if blocked is not None:
return blocked
return handler(request)
@override @override
async def awrap_model_call( async def awrap_model_call(
self, self,
@ -58,3 +94,14 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ModelRequest], Awaitable[ModelResponse]], handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult: ) -> ModelCallResult:
return await handler(self._filter_tools(request)) return await handler(self._filter_tools(request))
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
blocked = self._blocked_tool_message(request)
if blocked is not None:
return blocked
return await handler(request)

View File

@ -112,6 +112,15 @@ class DeferredToolRegistry:
def entries(self) -> list[DeferredToolEntry]: def entries(self) -> list[DeferredToolEntry]:
return list(self._entries) return list(self._entries)
@property
def deferred_names(self) -> set[str]:
"""Names of tools that are still hidden from model binding."""
return {entry.name for entry in self._entries}
def contains(self, name: str) -> bool:
"""Return whether *name* is still deferred."""
return any(entry.name == name for entry in self._entries)
def __len__(self) -> int: def __len__(self) -> int:
return len(self._entries) return len(self._entries)

View File

@ -2,8 +2,10 @@
import json import json
import sys import sys
from types import SimpleNamespace
import pytest import pytest
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool as langchain_tool from langchain_core.tools import tool as langchain_tool
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
@ -83,6 +85,16 @@ class TestDeferredToolRegistry:
assert "github_create_issue" in names assert "github_create_issue" in names
assert "slack_send_message" in names assert "slack_send_message" in names
def test_deferred_names(self, registry):
names = registry.deferred_names
assert "github_create_issue" in names
assert "slack_send_message" in names
assert len(names) == 6
def test_contains(self, registry):
assert registry.contains("github_create_issue") is True
assert registry.contains("not_registered") is False
def test_search_select_single(self, registry): def test_search_select_single(self, registry):
results = registry.search("select:github_create_issue") results = registry.search("select:github_create_issue")
assert len(results) == 1 assert len(results) == 1
@ -509,3 +521,89 @@ class TestToolSearchPromotion:
assert "slack_send_message" not in remaining assert "slack_send_message" not in remaining
assert "slack_list_channels" not in remaining assert "slack_list_channels" not in remaining
assert len(registry) == 4 assert len(registry) == 4
class TestDeferredToolExecutionGate:
def test_unpromoted_deferred_tool_call_is_blocked(self, registry):
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
set_deferred_registry(registry)
middleware = DeferredToolFilterMiddleware()
request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
called = False
def handler(_request):
nonlocal called
called = True
return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")
result = middleware.wrap_tool_call(request, handler)
assert called is False
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert result.tool_call_id == "call-1"
assert "tool_search" in result.content
assert "github_create_issue" in result.content
def test_promoted_deferred_tool_call_is_allowed(self, registry):
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
registry.promote({"github_create_issue"})
set_deferred_registry(registry)
middleware = DeferredToolFilterMiddleware()
request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
called = False
def handler(_request):
nonlocal called
called = True
return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")
result = middleware.wrap_tool_call(request, handler)
assert called is True
assert isinstance(result, ToolMessage)
assert result.content == "executed"
def test_non_deferred_tool_call_is_allowed(self, registry):
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
set_deferred_registry(registry)
middleware = DeferredToolFilterMiddleware()
request = SimpleNamespace(tool_call={"name": "local_tool", "id": "call-1"})
called = False
def handler(_request):
nonlocal called
called = True
return ToolMessage(content="executed", tool_call_id="call-1", name="local_tool")
result = middleware.wrap_tool_call(request, handler)
assert called is True
assert isinstance(result, ToolMessage)
assert result.content == "executed"
@pytest.mark.anyio
async def test_unpromoted_deferred_tool_call_is_blocked_async(self, registry):
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
set_deferred_registry(registry)
middleware = DeferredToolFilterMiddleware()
request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
called = False
async def handler(_request):
nonlocal called
called = True
return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")
result = await middleware.awrap_tool_call(request, handler)
assert called is False
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert result.tool_call_id == "call-1"
assert "tool_search" in result.content
assert "github_create_issue" in result.content