mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
fix: gate deferred MCP tool execution (#2513)
* fix: gate deferred MCP tool execution
* style: format deferred tool middleware
* fix: address deferred tool review feedback
This commit is contained in:
parent
d78ed5c8f2
commit
ec8a8cae38
@ -16,6 +16,9 @@ from typing import override
|
|||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -35,7 +38,7 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
|||||||
if not registry:
|
if not registry:
|
||||||
return request
|
return request
|
||||||
|
|
||||||
deferred_names = {e.name for e in registry.entries}
|
deferred_names = registry.deferred_names
|
||||||
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
|
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
|
||||||
|
|
||||||
if len(active_tools) < len(request.tools):
|
if len(active_tools) < len(request.tools):
|
||||||
@ -43,6 +46,28 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
|||||||
|
|
||||||
return request.override(tools=active_tools)
|
return request.override(tools=active_tools)
|
||||||
|
|
||||||
|
def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None:
    """Build the error reply for a call to a tool that is still deferred.

    Args:
        request: The pending tool call. Only ``request.tool_call`` is read,
            via its ``"name"`` and ``"id"`` keys.

    Returns:
        An error-status ``ToolMessage`` telling the model to call
        ``tool_search`` first when the named tool is still present in the
        deferred registry. ``None`` when the registry is falsy, the call
        carries no tool name, or the tool is not (or no longer) deferred.
    """
    # Function-scope import, as in the original block.
    # NOTE(review): presumably avoids an import cycle with tool_search — confirm.
    from deerflow.tools.builtins.tool_search import get_deferred_registry

    registry = get_deferred_registry()
    if not registry:
        return None

    # A missing or empty name cannot match a registry entry; let it through
    # so the normal tool node reports the problem.
    tool_name = str(request.tool_call.get("name") or "")
    if not tool_name:
        return None

    if not registry.contains(tool_name):
        return None

    # ToolMessage needs a tool_call_id; fall back to a placeholder when the
    # call did not carry one.
    tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id")
    return ToolMessage(
        content=(f"Error: Tool '{tool_name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
        tool_call_id=tool_call_id,
        name=tool_name,
        status="error",
    )
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def wrap_model_call(
|
def wrap_model_call(
|
||||||
self,
|
self,
|
||||||
@ -51,6 +76,17 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
|||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
return handler(self._filter_tools(request))
|
return handler(self._filter_tools(request))
|
||||||
|
|
||||||
|
@override
def wrap_tool_call(
    self,
    request: ToolCallRequest,
    handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
    """Gate synchronous tool execution on the deferred-tool check.

    Args:
        request: The tool call about to be executed.
        handler: The downstream callable that actually runs the tool.

    Returns:
        The message from ``_blocked_tool_message`` when it returns one
        (``handler`` is then never invoked); otherwise whatever
        ``handler(request)`` returns.
    """
    blocked = self._blocked_tool_message(request)
    if blocked is not None:
        return blocked
    return handler(request)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def awrap_model_call(
|
async def awrap_model_call(
|
||||||
self,
|
self,
|
||||||
@ -58,3 +94,14 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
|||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
return await handler(self._filter_tools(request))
|
return await handler(self._filter_tools(request))
|
||||||
|
|
||||||
|
@override
async def awrap_tool_call(
    self,
    request: ToolCallRequest,
    handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
    """Gate asynchronous tool execution on the deferred-tool check.

    Args:
        request: The tool call about to be executed.
        handler: The downstream coroutine function that actually runs the tool.

    Returns:
        The message from ``_blocked_tool_message`` when it returns one
        (``handler`` is then never awaited); otherwise the awaited result
        of ``handler(request)``.
    """
    blocked = self._blocked_tool_message(request)
    if blocked is not None:
        return blocked
    return await handler(request)
|
||||||
|
|||||||
@ -112,6 +112,15 @@ class DeferredToolRegistry:
|
|||||||
def entries(self) -> list[DeferredToolEntry]:
|
def entries(self) -> list[DeferredToolEntry]:
|
||||||
return list(self._entries)
|
return list(self._entries)
|
||||||
|
|
||||||
|
@property
def deferred_names(self) -> set[str]:
    """Names of tools that are still hidden from model binding.

    Returns:
        A newly built set holding the ``name`` of every entry currently in
        ``self._entries``; duplicates collapse and an empty registry yields
        an empty set.
    """
    return {entry.name for entry in self._entries}
|
||||||
|
|
||||||
|
def contains(self, name: str) -> bool:
    """Return whether *name* is still deferred.

    Args:
        name: Tool name to look up.

    Returns:
        ``True`` if any entry in ``self._entries`` has exactly this name,
        ``False`` otherwise (including when there are no entries).
    """
    return any(entry.name == name for entry in self._entries)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self._entries)
|
return len(self._entries)
|
||||||
|
|
||||||
|
|||||||
@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
from langchain_core.tools import tool as langchain_tool
|
from langchain_core.tools import tool as langchain_tool
|
||||||
|
|
||||||
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
|
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
|
||||||
@ -83,6 +85,16 @@ class TestDeferredToolRegistry:
|
|||||||
assert "github_create_issue" in names
|
assert "github_create_issue" in names
|
||||||
assert "slack_send_message" in names
|
assert "slack_send_message" in names
|
||||||
|
|
||||||
|
def test_deferred_names(self, registry):
    """``deferred_names`` exposes the name of every registered entry."""
    names = registry.deferred_names
    assert "github_create_issue" in names
    assert "slack_send_message" in names
    # The expected count of 6 comes from the `registry` fixture, which is
    # defined outside this view.
    assert len(names) == 6
|
||||||
|
|
||||||
|
def test_contains(self, registry):
    """``contains`` is True for a registered name and False for an unknown one."""
    assert registry.contains("github_create_issue") is True
    assert registry.contains("not_registered") is False
|
||||||
|
|
||||||
def test_search_select_single(self, registry):
|
def test_search_select_single(self, registry):
|
||||||
results = registry.search("select:github_create_issue")
|
results = registry.search("select:github_create_issue")
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
@ -509,3 +521,89 @@ class TestToolSearchPromotion:
|
|||||||
assert "slack_send_message" not in remaining
|
assert "slack_send_message" not in remaining
|
||||||
assert "slack_list_channels" not in remaining
|
assert "slack_list_channels" not in remaining
|
||||||
assert len(registry) == 4
|
assert len(registry) == 4
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeferredToolExecutionGate:
    """Checks that ``DeferredToolFilterMiddleware`` gates tool execution.

    Each test installs the ``registry`` fixture as the deferred registry,
    sends a minimal stand-in request (only ``tool_call`` is populated) through
    the middleware, and records whether the downstream handler ran.
    """

    def test_unpromoted_deferred_tool_call_is_blocked(self, registry):
        """A still-deferred tool is answered with an error and never executed."""
        from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware

        set_deferred_registry(registry)
        middleware = DeferredToolFilterMiddleware()
        request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
        called = False

        def handler(_request):
            nonlocal called
            called = True
            return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")

        result = middleware.wrap_tool_call(request, handler)

        assert called is False
        assert isinstance(result, ToolMessage)
        assert result.status == "error"
        assert result.tool_call_id == "call-1"
        assert "tool_search" in result.content
        assert "github_create_issue" in result.content

    def test_promoted_deferred_tool_call_is_allowed(self, registry):
        """Once promoted, a formerly deferred tool reaches the handler."""
        from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware

        registry.promote({"github_create_issue"})
        set_deferred_registry(registry)
        middleware = DeferredToolFilterMiddleware()
        request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
        called = False

        def handler(_request):
            nonlocal called
            called = True
            return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")

        result = middleware.wrap_tool_call(request, handler)

        assert called is True
        assert isinstance(result, ToolMessage)
        assert result.content == "executed"

    def test_non_deferred_tool_call_is_allowed(self, registry):
        """A tool that was never in the registry reaches the handler."""
        from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware

        set_deferred_registry(registry)
        middleware = DeferredToolFilterMiddleware()
        request = SimpleNamespace(tool_call={"name": "local_tool", "id": "call-1"})
        called = False

        def handler(_request):
            nonlocal called
            called = True
            return ToolMessage(content="executed", tool_call_id="call-1", name="local_tool")

        result = middleware.wrap_tool_call(request, handler)

        assert called is True
        assert isinstance(result, ToolMessage)
        assert result.content == "executed"

    @pytest.mark.anyio
    async def test_unpromoted_deferred_tool_call_is_blocked_async(self, registry):
        """The async path blocks a still-deferred tool the same way as the sync path."""
        from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware

        set_deferred_registry(registry)
        middleware = DeferredToolFilterMiddleware()
        request = SimpleNamespace(tool_call={"name": "github_create_issue", "id": "call-1"})
        called = False

        async def handler(_request):
            nonlocal called
            called = True
            return ToolMessage(content="executed", tool_call_id="call-1", name="github_create_issue")

        result = await middleware.awrap_tool_call(request, handler)

        assert called is False
        assert isinstance(result, ToolMessage)
        assert result.status == "error"
        assert result.tool_call_id == "call-1"
        assert "tool_search" in result.content
        assert "github_create_issue" in result.content
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user