refactor(tool-search): consolidate MCP metadata tag and harden deferred-tool setup (#3370)

Follow-up to #3342 (deferred MCP tool loading). Maintainability cleanup plus
hardening of malformed/empty tool_search queries; no change to the deferral
mechanism or search ranking.

- Add deerflow/tools/mcp_metadata.py as the single source of truth for the
  "deerflow_mcp" tag (MCP_TOOL_METADATA_KEY + tag_mcp_tool + public
  is_mcp_tool). Removes the duplicated magic string and the private,
  cross-module _is_mcp_tool import.
- tool_search.search: never raise on model-generated input. Extract
  _compile_catalog_regex (shared compile-with-literal-fallback); return empty
  for empty/whitespace queries and a bare "+" instead of matching everything
  or raising IndexError.
- DeferredToolSetup: document the empty-vs-populated invariant.
- build_deferred_tool_setup: comment the two distinct empty-return branches.
- _assemble_deferred: add return type, rename local to deferred_setup, build
  the final list with an explicit append.
- Tests: use tag_mcp_tool instead of per-file tag helpers; cover empty and
  bare-"+" queries.
This commit is contained in:
AochenShen99 2026-06-05 15:21:41 +08:00 committed by GitHub
parent 28b1da2172
commit 2bbc7879fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 123 additions and 47 deletions

View File

@ -18,7 +18,10 @@ middleware, and the async path inside ``TitleMiddleware``. Any new in-graph
``create_chat_model`` call must add to this list and pass the flag.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
@ -45,6 +48,11 @@ from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill
from deerflow.tracing import build_tracing_callbacks
if TYPE_CHECKING:
from langchain.tools import BaseTool
from deerflow.tools.builtins.tool_search import DeferredToolSetup
logger = logging.getLogger(__name__)
@ -356,7 +364,7 @@ def _build_middlewares(
return middlewares
def _assemble_deferred(filtered_tools, *, enabled: bool):
def _assemble_deferred(filtered_tools: list[BaseTool], *, enabled: bool) -> tuple[list[BaseTool], DeferredToolSetup]:
"""Build the final tool list + deferred setup from a policy-filtered list.
Call AFTER tool-policy filtering so the deferred catalog never exposes a
@ -364,13 +372,16 @@ def _assemble_deferred(filtered_tools, *, enabled: bool):
and MCP tools survived filtering but no deferred set was recovered, raise
rather than silently binding their full schemas to the model.
"""
from deerflow.tools.builtins.tool_search import _is_mcp_tool, build_deferred_tool_setup
from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
from deerflow.tools.mcp_metadata import is_mcp_tool
setup = build_deferred_tool_setup(filtered_tools, enabled=enabled)
if enabled and not setup.deferred_names and any(_is_mcp_tool(t) for t in filtered_tools):
deferred_setup = build_deferred_tool_setup(filtered_tools, enabled=enabled)
if enabled and not deferred_setup.deferred_names and any(is_mcp_tool(t) for t in filtered_tools):
raise RuntimeError("tool_search enabled and MCP tools survived policy filtering, but no deferred set was recovered — refusing to bind MCP schemas (fail-closed).")
final_tools = list(filtered_tools) + ([setup.tool_search_tool] if setup.tool_search_tool else [])
return final_tools, setup
final_tools = list(filtered_tools)
if deferred_setup.tool_search_tool:
final_tools.append(deferred_setup.tool_search_tool)
return final_tools, deferred_setup
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:

View File

@ -28,11 +28,25 @@ from langchain_core.tools import InjectedToolCallId, tool
from langchain_core.utils.function_calling import convert_to_openai_function
from langgraph.types import Command
from deerflow.tools.mcp_metadata import is_mcp_tool
logger = logging.getLogger(__name__)
MAX_RESULTS = 5 # Max tools returned per search
def _compile_catalog_regex(pattern: str) -> re.Pattern[str]:
"""Compile ``pattern`` case-insensitively, falling back to a literal match.
Search queries come from the model, so an invalid regex (e.g. an unbalanced
paren) must degrade to a literal substring match rather than raise.
"""
try:
return re.compile(pattern, re.IGNORECASE)
except re.error:
return re.compile(re.escape(pattern), re.IGNORECASE)
# ── Catalog ──
@ -56,22 +70,25 @@ class DeferredToolCatalog:
return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16]
def search(self, query: str) -> list[BaseTool]:
query = query.strip()
if not query:
return []
if query.startswith("select:"):
wanted = {n.strip() for n in query[7:].split(",")}
return [t for t in self.tools if t.name in wanted][:MAX_RESULTS]
if query.startswith("+"):
parts = query[1:].split(None, 1)
if not parts:
return [] # bare "+" with no required token — nothing to require
required = parts[0].lower()
candidates = [t for t in self.tools if required in t.name.lower()]
if len(parts) > 1:
candidates.sort(key=lambda t: _catalog_regex_score(parts[1], t), reverse=True)
return candidates[:MAX_RESULTS]
try:
regex = re.compile(query, re.IGNORECASE)
except re.error:
regex = re.compile(re.escape(query), re.IGNORECASE)
regex = _compile_catalog_regex(query)
scored: list[tuple[int, BaseTool]] = []
for t in self.tools:
searchable = f"{t.name} {t.description or ''}"
@ -82,10 +99,7 @@ class DeferredToolCatalog:
def _catalog_regex_score(pattern: str, t: BaseTool) -> int:
try:
regex = re.compile(pattern, re.IGNORECASE)
except re.error:
regex = re.compile(re.escape(pattern), re.IGNORECASE)
regex = _compile_catalog_regex(pattern)
return len(regex.findall(f"{t.name} {t.description or ''}"))
@ -94,15 +108,25 @@ def _catalog_regex_score(pattern: str, t: BaseTool) -> int:
@dataclass(frozen=True)
class DeferredToolSetup:
"""Result of assembling deferred-tool support for one agent build.
The three fields move as a unit, so callers branch on ``tool_search_tool``:
- **Empty** ``(None, frozenset(), None)``: deferral is disabled, or no MCP
tool survived policy filtering. Nothing is deferred bind tools as-is.
- **Populated**: ``tool_search_tool`` is appended to the agent's tools,
``deferred_names`` are withheld from the model until promoted, and
``catalog_hash`` scopes those promotions in graph state.
Invariant: ``tool_search_tool is None`` ``deferred_names`` is empty
``catalog_hash is None``.
"""
tool_search_tool: BaseTool | None
deferred_names: frozenset[str]
catalog_hash: str | None
def _is_mcp_tool(t: BaseTool) -> bool:
return (getattr(t, "metadata", None) or {}).get("deerflow_mcp") is True
def build_tool_search_tool(catalog: DeferredToolCatalog) -> BaseTool:
catalog_hash = catalog.hash
@ -141,11 +165,17 @@ def build_deferred_tool_setup(filtered_tools: list[BaseTool], *, enabled: bool)
Must be called after skill/agent tool-policy filtering so the catalog never
exposes a tool the current agent is not allowed to use.
Returns an empty setup (see :class:`DeferredToolSetup`) in two distinct
cases: deferral is disabled, or it is enabled but no MCP tool survived
filtering.
"""
if not enabled:
# Deferral disabled: defer nothing; the model binds every tool as before.
return DeferredToolSetup(None, frozenset(), None)
deferred = [t for t in filtered_tools if _is_mcp_tool(t)]
deferred = [t for t in filtered_tools if is_mcp_tool(t)]
if not deferred:
# Enabled, but no MCP tool to defer: same empty result, different reason.
return DeferredToolSetup(None, frozenset(), None)
catalog = DeferredToolCatalog(tuple(deferred))
return DeferredToolSetup(build_tool_search_tool(catalog), catalog.names, catalog.hash)

View File

@ -0,0 +1,29 @@
"""Single source of truth for the MCP-tool metadata tag.
A tool is "MCP-sourced" when it carries the ``deerflow_mcp`` metadata flag.
The tag is *written* where MCP tools are loaded (``tools.py``) and *read* by
deferred-tool assembly (``tool_search.py``) and the agent build site
(``agent.py``). Keeping the key, the tagger, and the predicate here means the
magic string lives in exactly one place, and readers import a public predicate
instead of a private cross-module helper.
This is a leaf module by design: it depends only on ``BaseTool`` so that any
module (including the tool loader) can import it without an import cycle.
"""
from __future__ import annotations
from langchain.tools import BaseTool
MCP_TOOL_METADATA_KEY = "deerflow_mcp"
def tag_mcp_tool(tool: BaseTool) -> BaseTool:
"""Mark ``tool`` as MCP-sourced. Mutates in place and returns it for chaining."""
tool.metadata = {**(tool.metadata or {}), MCP_TOOL_METADATA_KEY: True}
return tool
def is_mcp_tool(tool: BaseTool) -> bool:
"""True when ``tool`` carries the MCP-source tag written by :func:`tag_mcp_tool`."""
return (getattr(tool, "metadata", None) or {}).get(MCP_TOOL_METADATA_KEY) is True

View File

@ -7,6 +7,7 @@ from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
from deerflow.tools.mcp_metadata import tag_mcp_tool
from deerflow.tools.sync import make_sync_tool_wrapper
logger = logging.getLogger(__name__)
@ -132,7 +133,7 @@ def get_available_tools(
# the deferred catalog + tool_search tool are assembled per
# agent from the policy-filtered tool list.
for t in mcp_tools:
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
tag_mcp_tool(t)
except ImportError:
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
except Exception as e:

View File

@ -54,6 +54,23 @@ def test_search_invalid_regex_falls_back_to_literal():
assert cat.search("zzz(") == []
def test_search_empty_query_returns_empty(catalog):
# An empty / whitespace-only query is meaningless; rather than let the empty
# regex match every tool, search() returns nothing so the model gets a clear
# "no match" signal and re-queries instead of acting on noise.
assert catalog.search("") == []
assert catalog.search(" ") == []
def test_search_bare_plus_returns_empty(catalog):
# A "+" prefix with no required token is malformed model input. It must
# return no matches, not raise IndexError on parts[0]. " + " strips to "+",
# so it routes here too and must be handled the same way.
assert catalog.search("+") == []
assert catalog.search(" + ") == []
assert catalog.search("+ ") == []
def test_hash_stable_across_instances():
c1 = DeferredToolCatalog((alpha_search, beta_translate))
c2 = DeferredToolCatalog((beta_translate, alpha_search))

View File

@ -20,6 +20,7 @@ from langchain_core.tools import tool as as_tool
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
from deerflow.tools.mcp_metadata import tag_mcp_tool
@as_tool
@ -40,11 +41,6 @@ def mcp_other(x: str) -> str:
return x
def _tag(t):
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
return t
def test_tool_search_promotes_into_next_turn():
bound: list[list[str]] = []
@ -53,7 +49,7 @@ def test_tool_search_promotes_into_next_turn():
bound.append([getattr(t, "name", None) for t in tools])
return self
setup = build_deferred_tool_setup([active_tool, _tag(mcp_calc), _tag(mcp_other)], enabled=True)
setup = build_deferred_tool_setup([active_tool, tag_mcp_tool(mcp_calc), tag_mcp_tool(mcp_other)], enabled=True)
turn1 = AIMessage(content="", tool_calls=[{"name": "tool_search", "args": {"query": "select:mcp_calc"}, "id": "c1", "type": "tool_call"}])
turn2 = AIMessage(content="done")
model = RecordingModel(messages=iter([turn1, turn2]))

View File

@ -1,7 +1,8 @@
from langchain_core.tools import tool as as_tool
from langgraph.types import Command
from deerflow.tools.builtins.tool_search import DeferredToolCatalog, _is_mcp_tool, build_deferred_tool_setup, build_tool_search_tool
from deerflow.tools.builtins.tool_search import DeferredToolCatalog, build_deferred_tool_setup, build_tool_search_tool
from deerflow.tools.mcp_metadata import is_mcp_tool, tag_mcp_tool
@as_tool
@ -16,18 +17,13 @@ def local_echo(text: str) -> str:
return text
def _tag_mcp(t):
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
return t
def test_is_mcp_tool_reads_metadata():
assert _is_mcp_tool(_tag_mcp(mcp_calc)) is True
assert _is_mcp_tool(local_echo) is False
assert is_mcp_tool(tag_mcp_tool(mcp_calc)) is True
assert is_mcp_tool(local_echo) is False
def test_setup_disabled_returns_empty():
setup = build_deferred_tool_setup([_tag_mcp(mcp_calc), local_echo], enabled=False)
setup = build_deferred_tool_setup([tag_mcp_tool(mcp_calc), local_echo], enabled=False)
assert setup.tool_search_tool is None
assert setup.deferred_names == frozenset()
assert setup.catalog_hash is None
@ -40,7 +36,7 @@ def test_setup_no_mcp_returns_empty():
def test_setup_builds_from_mcp_survivors():
setup = build_deferred_tool_setup([_tag_mcp(mcp_calc), local_echo], enabled=True)
setup = build_deferred_tool_setup([tag_mcp_tool(mcp_calc), local_echo], enabled=True)
assert setup.deferred_names == frozenset({"mcp_calc"})
assert setup.tool_search_tool is not None
assert setup.tool_search_tool.name == "tool_search"

View File

@ -23,6 +23,7 @@ from deerflow.agents.middlewares.deferred_tool_filter_middleware import Deferred
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill
from deerflow.tools.builtins.tool_search import DeferredToolSetup, build_deferred_tool_setup
from deerflow.tools.mcp_metadata import tag_mcp_tool
@as_tool
@ -37,11 +38,6 @@ def mcp_secret(x: str) -> str:
return x
def _tag(t):
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
return t
_BOUND: list[list[str]] = []
@ -52,7 +48,7 @@ class _RecordingModel(GenericFakeChatModel):
def _build_graph():
filtered = [active_tool, _tag(mcp_secret)]
filtered = [active_tool, tag_mcp_tool(mcp_secret)]
setup = build_deferred_tool_setup(filtered, enabled=True)
final = [*filtered, setup.tool_search_tool]
model = _RecordingModel(messages=iter([AIMessage(content="done")] * 4))
@ -107,18 +103,18 @@ def test_fail_closed_when_mcp_survives_without_setup(monkeypatch):
lambda tools, *, enabled: DeferredToolSetup(None, frozenset(), None),
)
with pytest.raises(RuntimeError, match="fail-closed"):
agentmod._assemble_deferred([_tag(mcp_secret)], enabled=True)
agentmod._assemble_deferred([tag_mcp_tool(mcp_secret)], enabled=True)
def test_subagent_reentry_does_not_touch_lead_state():
"""#2884: building a second (subagent) setup must not affect the lead's
middleware. With no shared registry/ContextVar, the lead middleware depends
only on its own deferred_names + the passed state."""
lead_setup = build_deferred_tool_setup([active_tool, _tag(mcp_secret)], enabled=True)
lead_setup = build_deferred_tool_setup([active_tool, tag_mcp_tool(mcp_secret)], enabled=True)
mw = DeferredToolFilterMiddleware(lead_setup.deferred_names, lead_setup.catalog_hash)
# Simulate a subagent build re-entering tool assembly with its own setup.
_ = build_deferred_tool_setup([_tag(mcp_secret)], enabled=True)
_ = build_deferred_tool_setup([tag_mcp_tool(mcp_secret)], enabled=True)
class _Req:
def __init__(self):
@ -154,7 +150,7 @@ def test_policy_denied_mcp_yields_no_tool_search_end_to_end():
tool_search (and does not fail-closed, because no MCP tool leaked through)."""
from deerflow.agents.lead_agent import agent as agentmod
filtered = filter_tools_by_skill_allowed_tools([active_tool, _tag(mcp_secret)], [_make_skill(["active_tool"])])
filtered = filter_tools_by_skill_allowed_tools([active_tool, tag_mcp_tool(mcp_secret)], [_make_skill(["active_tool"])])
final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True)
assert [t.name for t in final_tools] == ["active_tool"]
@ -174,7 +170,7 @@ def test_tool_search_appended_after_policy_but_never_exposes_denied_tool():
from deerflow.agents.lead_agent import agent as agentmod
allowed = ["active_tool", "mcp_secret"] # permits the MCP tool, does NOT list tool_search
filtered = filter_tools_by_skill_allowed_tools([active_tool, _tag(mcp_secret)], [_make_skill(allowed)])
filtered = filter_tools_by_skill_allowed_tools([active_tool, tag_mcp_tool(mcp_secret)], [_make_skill(allowed)])
final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True)
names = {t.name for t in final_tools}