mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-09 17:12:01 +00:00
refactor(tool-search): consolidate MCP metadata tag and harden deferred-tool setup (#3370)
Follow-up to #3342 (deferred MCP tool loading). Maintainability cleanup plus hardening of malformed/empty tool_search queries; no change to the deferral mechanism or search ranking. - Add deerflow/tools/mcp_metadata.py as the single source of truth for the "deerflow_mcp" tag (MCP_TOOL_METADATA_KEY + tag_mcp_tool + public is_mcp_tool). Removes the duplicated magic string and the private, cross-module _is_mcp_tool import. - tool_search.search: never raise on model-generated input. Extract _compile_catalog_regex (shared compile-with-literal-fallback); return empty for empty/whitespace queries and a bare "+" instead of matching everything or raising IndexError. - DeferredToolSetup: document the empty-vs-populated invariant. - build_deferred_tool_setup: comment the two distinct empty-return branches. - _assemble_deferred: add return type, rename local to deferred_setup, build the final list with an explicit append. - Tests: use tag_mcp_tool instead of per-file tag helpers; cover empty and bare-"+" queries.
This commit is contained in:
parent
28b1da2172
commit
2bbc7879fa
@ -18,7 +18,10 @@ middleware, and the async path inside ``TitleMiddleware``. Any new in-graph
|
||||
``create_chat_model`` call must add to this list and pass the flag.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
@ -45,6 +48,11 @@ from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.tracing import build_tracing_callbacks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from deerflow.tools.builtins.tool_search import DeferredToolSetup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -356,7 +364,7 @@ def _build_middlewares(
|
||||
return middlewares
|
||||
|
||||
|
||||
def _assemble_deferred(filtered_tools, *, enabled: bool):
|
||||
def _assemble_deferred(filtered_tools: list[BaseTool], *, enabled: bool) -> tuple[list[BaseTool], DeferredToolSetup]:
|
||||
"""Build the final tool list + deferred setup from a policy-filtered list.
|
||||
|
||||
Call AFTER tool-policy filtering so the deferred catalog never exposes a
|
||||
@ -364,13 +372,16 @@ def _assemble_deferred(filtered_tools, *, enabled: bool):
|
||||
and MCP tools survived filtering but no deferred set was recovered, raise
|
||||
rather than silently binding their full schemas to the model.
|
||||
"""
|
||||
from deerflow.tools.builtins.tool_search import _is_mcp_tool, build_deferred_tool_setup
|
||||
from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
|
||||
from deerflow.tools.mcp_metadata import is_mcp_tool
|
||||
|
||||
setup = build_deferred_tool_setup(filtered_tools, enabled=enabled)
|
||||
if enabled and not setup.deferred_names and any(_is_mcp_tool(t) for t in filtered_tools):
|
||||
deferred_setup = build_deferred_tool_setup(filtered_tools, enabled=enabled)
|
||||
if enabled and not deferred_setup.deferred_names and any(is_mcp_tool(t) for t in filtered_tools):
|
||||
raise RuntimeError("tool_search enabled and MCP tools survived policy filtering, but no deferred set was recovered — refusing to bind MCP schemas (fail-closed).")
|
||||
final_tools = list(filtered_tools) + ([setup.tool_search_tool] if setup.tool_search_tool else [])
|
||||
return final_tools, setup
|
||||
final_tools = list(filtered_tools)
|
||||
if deferred_setup.tool_search_tool:
|
||||
final_tools.append(deferred_setup.tool_search_tool)
|
||||
return final_tools, deferred_setup
|
||||
|
||||
|
||||
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
|
||||
|
||||
@ -28,11 +28,25 @@ from langchain_core.tools import InjectedToolCallId, tool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.tools.mcp_metadata import is_mcp_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_RESULTS = 5 # Max tools returned per search
|
||||
|
||||
|
||||
def _compile_catalog_regex(pattern: str) -> re.Pattern[str]:
|
||||
"""Compile ``pattern`` case-insensitively, falling back to a literal match.
|
||||
|
||||
Search queries come from the model, so an invalid regex (e.g. an unbalanced
|
||||
paren) must degrade to a literal substring match rather than raise.
|
||||
"""
|
||||
try:
|
||||
return re.compile(pattern, re.IGNORECASE)
|
||||
except re.error:
|
||||
return re.compile(re.escape(pattern), re.IGNORECASE)
|
||||
|
||||
|
||||
# ── Catalog ──
|
||||
|
||||
|
||||
@ -56,22 +70,25 @@ class DeferredToolCatalog:
|
||||
return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16]
|
||||
|
||||
def search(self, query: str) -> list[BaseTool]:
|
||||
query = query.strip()
|
||||
if not query:
|
||||
return []
|
||||
|
||||
if query.startswith("select:"):
|
||||
wanted = {n.strip() for n in query[7:].split(",")}
|
||||
return [t for t in self.tools if t.name in wanted][:MAX_RESULTS]
|
||||
|
||||
if query.startswith("+"):
|
||||
parts = query[1:].split(None, 1)
|
||||
if not parts:
|
||||
return [] # bare "+" with no required token — nothing to require
|
||||
required = parts[0].lower()
|
||||
candidates = [t for t in self.tools if required in t.name.lower()]
|
||||
if len(parts) > 1:
|
||||
candidates.sort(key=lambda t: _catalog_regex_score(parts[1], t), reverse=True)
|
||||
return candidates[:MAX_RESULTS]
|
||||
|
||||
try:
|
||||
regex = re.compile(query, re.IGNORECASE)
|
||||
except re.error:
|
||||
regex = re.compile(re.escape(query), re.IGNORECASE)
|
||||
regex = _compile_catalog_regex(query)
|
||||
scored: list[tuple[int, BaseTool]] = []
|
||||
for t in self.tools:
|
||||
searchable = f"{t.name} {t.description or ''}"
|
||||
@ -82,10 +99,7 @@ class DeferredToolCatalog:
|
||||
|
||||
|
||||
def _catalog_regex_score(pattern: str, t: BaseTool) -> int:
|
||||
try:
|
||||
regex = re.compile(pattern, re.IGNORECASE)
|
||||
except re.error:
|
||||
regex = re.compile(re.escape(pattern), re.IGNORECASE)
|
||||
regex = _compile_catalog_regex(pattern)
|
||||
return len(regex.findall(f"{t.name} {t.description or ''}"))
|
||||
|
||||
|
||||
@ -94,15 +108,25 @@ def _catalog_regex_score(pattern: str, t: BaseTool) -> int:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeferredToolSetup:
|
||||
"""Result of assembling deferred-tool support for one agent build.
|
||||
|
||||
The three fields move as a unit, so callers branch on ``tool_search_tool``:
|
||||
|
||||
- **Empty** ``(None, frozenset(), None)``: deferral is disabled, or no MCP
|
||||
tool survived policy filtering. Nothing is deferred — bind tools as-is.
|
||||
- **Populated**: ``tool_search_tool`` is appended to the agent's tools,
|
||||
``deferred_names`` are withheld from the model until promoted, and
|
||||
``catalog_hash`` scopes those promotions in graph state.
|
||||
|
||||
Invariant: ``tool_search_tool is None`` ⟺ ``deferred_names`` is empty ⟺
|
||||
``catalog_hash is None``.
|
||||
"""
|
||||
|
||||
tool_search_tool: BaseTool | None
|
||||
deferred_names: frozenset[str]
|
||||
catalog_hash: str | None
|
||||
|
||||
|
||||
def _is_mcp_tool(t: BaseTool) -> bool:
|
||||
return (getattr(t, "metadata", None) or {}).get("deerflow_mcp") is True
|
||||
|
||||
|
||||
def build_tool_search_tool(catalog: DeferredToolCatalog) -> BaseTool:
|
||||
catalog_hash = catalog.hash
|
||||
|
||||
@ -141,11 +165,17 @@ def build_deferred_tool_setup(filtered_tools: list[BaseTool], *, enabled: bool)
|
||||
|
||||
Must be called after skill/agent tool-policy filtering so the catalog never
|
||||
exposes a tool the current agent is not allowed to use.
|
||||
|
||||
Returns an empty setup (see :class:`DeferredToolSetup`) in two distinct
|
||||
cases: deferral is disabled, or it is enabled but no MCP tool survived
|
||||
filtering.
|
||||
"""
|
||||
if not enabled:
|
||||
# Deferral disabled: defer nothing; the model binds every tool as before.
|
||||
return DeferredToolSetup(None, frozenset(), None)
|
||||
deferred = [t for t in filtered_tools if _is_mcp_tool(t)]
|
||||
deferred = [t for t in filtered_tools if is_mcp_tool(t)]
|
||||
if not deferred:
|
||||
# Enabled, but no MCP tool to defer: same empty result, different reason.
|
||||
return DeferredToolSetup(None, frozenset(), None)
|
||||
catalog = DeferredToolCatalog(tuple(deferred))
|
||||
return DeferredToolSetup(build_tool_search_tool(catalog), catalog.names, catalog.hash)
|
||||
|
||||
29
backend/packages/harness/deerflow/tools/mcp_metadata.py
Normal file
29
backend/packages/harness/deerflow/tools/mcp_metadata.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Single source of truth for the MCP-tool metadata tag.
|
||||
|
||||
A tool is "MCP-sourced" when it carries the ``deerflow_mcp`` metadata flag.
|
||||
The tag is *written* where MCP tools are loaded (``tools.py``) and *read* by
|
||||
deferred-tool assembly (``tool_search.py``) and the agent build site
|
||||
(``agent.py``). Keeping the key, the tagger, and the predicate here means the
|
||||
magic string lives in exactly one place, and readers import a public predicate
|
||||
instead of a private cross-module helper.
|
||||
|
||||
This is a leaf module by design: it depends only on ``BaseTool`` so that any
|
||||
module (including the tool loader) can import it without an import cycle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
MCP_TOOL_METADATA_KEY = "deerflow_mcp"
|
||||
|
||||
|
||||
def tag_mcp_tool(tool: BaseTool) -> BaseTool:
|
||||
"""Mark ``tool`` as MCP-sourced. Mutates in place and returns it for chaining."""
|
||||
tool.metadata = {**(tool.metadata or {}), MCP_TOOL_METADATA_KEY: True}
|
||||
return tool
|
||||
|
||||
|
||||
def is_mcp_tool(tool: BaseTool) -> bool:
|
||||
"""True when ``tool`` carries the MCP-source tag written by :func:`tag_mcp_tool`."""
|
||||
return (getattr(tool, "metadata", None) or {}).get(MCP_TOOL_METADATA_KEY) is True
|
||||
@ -7,6 +7,7 @@ from deerflow.config.app_config import AppConfig
|
||||
from deerflow.reflection import resolve_variable
|
||||
from deerflow.sandbox.security import is_host_bash_allowed
|
||||
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
|
||||
from deerflow.tools.mcp_metadata import tag_mcp_tool
|
||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -132,7 +133,7 @@ def get_available_tools(
|
||||
# the deferred catalog + tool_search tool are assembled per
|
||||
# agent from the policy-filtered tool list.
|
||||
for t in mcp_tools:
|
||||
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
|
||||
tag_mcp_tool(t)
|
||||
except ImportError:
|
||||
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
|
||||
except Exception as e:
|
||||
|
||||
@ -54,6 +54,23 @@ def test_search_invalid_regex_falls_back_to_literal():
|
||||
assert cat.search("zzz(") == []
|
||||
|
||||
|
||||
def test_search_empty_query_returns_empty(catalog):
|
||||
# An empty / whitespace-only query is meaningless; rather than let the empty
|
||||
# regex match every tool, search() returns nothing so the model gets a clear
|
||||
# "no match" signal and re-queries instead of acting on noise.
|
||||
assert catalog.search("") == []
|
||||
assert catalog.search(" ") == []
|
||||
|
||||
|
||||
def test_search_bare_plus_returns_empty(catalog):
|
||||
# A "+" prefix with no required token is malformed model input. It must
|
||||
# return no matches, not raise IndexError on parts[0]. " + " strips to "+",
|
||||
# so it routes here too and must be handled the same way.
|
||||
assert catalog.search("+") == []
|
||||
assert catalog.search(" + ") == []
|
||||
assert catalog.search("+ ") == []
|
||||
|
||||
|
||||
def test_hash_stable_across_instances():
|
||||
c1 = DeferredToolCatalog((alpha_search, beta_translate))
|
||||
c2 = DeferredToolCatalog((beta_translate, alpha_search))
|
||||
|
||||
@ -20,6 +20,7 @@ from langchain_core.tools import tool as as_tool
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
|
||||
from deerflow.tools.mcp_metadata import tag_mcp_tool
|
||||
|
||||
|
||||
@as_tool
|
||||
@ -40,11 +41,6 @@ def mcp_other(x: str) -> str:
|
||||
return x
|
||||
|
||||
|
||||
def _tag(t):
|
||||
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
|
||||
return t
|
||||
|
||||
|
||||
def test_tool_search_promotes_into_next_turn():
|
||||
bound: list[list[str]] = []
|
||||
|
||||
@ -53,7 +49,7 @@ def test_tool_search_promotes_into_next_turn():
|
||||
bound.append([getattr(t, "name", None) for t in tools])
|
||||
return self
|
||||
|
||||
setup = build_deferred_tool_setup([active_tool, _tag(mcp_calc), _tag(mcp_other)], enabled=True)
|
||||
setup = build_deferred_tool_setup([active_tool, tag_mcp_tool(mcp_calc), tag_mcp_tool(mcp_other)], enabled=True)
|
||||
turn1 = AIMessage(content="", tool_calls=[{"name": "tool_search", "args": {"query": "select:mcp_calc"}, "id": "c1", "type": "tool_call"}])
|
||||
turn2 = AIMessage(content="done")
|
||||
model = RecordingModel(messages=iter([turn1, turn2]))
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from langchain_core.tools import tool as as_tool
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.tools.builtins.tool_search import DeferredToolCatalog, _is_mcp_tool, build_deferred_tool_setup, build_tool_search_tool
|
||||
from deerflow.tools.builtins.tool_search import DeferredToolCatalog, build_deferred_tool_setup, build_tool_search_tool
|
||||
from deerflow.tools.mcp_metadata import is_mcp_tool, tag_mcp_tool
|
||||
|
||||
|
||||
@as_tool
|
||||
@ -16,18 +17,13 @@ def local_echo(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def _tag_mcp(t):
|
||||
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
|
||||
return t
|
||||
|
||||
|
||||
def test_is_mcp_tool_reads_metadata():
|
||||
assert _is_mcp_tool(_tag_mcp(mcp_calc)) is True
|
||||
assert _is_mcp_tool(local_echo) is False
|
||||
assert is_mcp_tool(tag_mcp_tool(mcp_calc)) is True
|
||||
assert is_mcp_tool(local_echo) is False
|
||||
|
||||
|
||||
def test_setup_disabled_returns_empty():
|
||||
setup = build_deferred_tool_setup([_tag_mcp(mcp_calc), local_echo], enabled=False)
|
||||
setup = build_deferred_tool_setup([tag_mcp_tool(mcp_calc), local_echo], enabled=False)
|
||||
assert setup.tool_search_tool is None
|
||||
assert setup.deferred_names == frozenset()
|
||||
assert setup.catalog_hash is None
|
||||
@ -40,7 +36,7 @@ def test_setup_no_mcp_returns_empty():
|
||||
|
||||
|
||||
def test_setup_builds_from_mcp_survivors():
|
||||
setup = build_deferred_tool_setup([_tag_mcp(mcp_calc), local_echo], enabled=True)
|
||||
setup = build_deferred_tool_setup([tag_mcp_tool(mcp_calc), local_echo], enabled=True)
|
||||
assert setup.deferred_names == frozenset({"mcp_calc"})
|
||||
assert setup.tool_search_tool is not None
|
||||
assert setup.tool_search_tool.name == "tool_search"
|
||||
|
||||
@ -23,6 +23,7 @@ from deerflow.agents.middlewares.deferred_tool_filter_middleware import Deferred
|
||||
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.tools.builtins.tool_search import DeferredToolSetup, build_deferred_tool_setup
|
||||
from deerflow.tools.mcp_metadata import tag_mcp_tool
|
||||
|
||||
|
||||
@as_tool
|
||||
@ -37,11 +38,6 @@ def mcp_secret(x: str) -> str:
|
||||
return x
|
||||
|
||||
|
||||
def _tag(t):
|
||||
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
|
||||
return t
|
||||
|
||||
|
||||
_BOUND: list[list[str]] = []
|
||||
|
||||
|
||||
@ -52,7 +48,7 @@ class _RecordingModel(GenericFakeChatModel):
|
||||
|
||||
|
||||
def _build_graph():
|
||||
filtered = [active_tool, _tag(mcp_secret)]
|
||||
filtered = [active_tool, tag_mcp_tool(mcp_secret)]
|
||||
setup = build_deferred_tool_setup(filtered, enabled=True)
|
||||
final = [*filtered, setup.tool_search_tool]
|
||||
model = _RecordingModel(messages=iter([AIMessage(content="done")] * 4))
|
||||
@ -107,18 +103,18 @@ def test_fail_closed_when_mcp_survives_without_setup(monkeypatch):
|
||||
lambda tools, *, enabled: DeferredToolSetup(None, frozenset(), None),
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="fail-closed"):
|
||||
agentmod._assemble_deferred([_tag(mcp_secret)], enabled=True)
|
||||
agentmod._assemble_deferred([tag_mcp_tool(mcp_secret)], enabled=True)
|
||||
|
||||
|
||||
def test_subagent_reentry_does_not_touch_lead_state():
|
||||
"""#2884: building a second (subagent) setup must not affect the lead's
|
||||
middleware. With no shared registry/ContextVar, the lead middleware depends
|
||||
only on its own deferred_names + the passed state."""
|
||||
lead_setup = build_deferred_tool_setup([active_tool, _tag(mcp_secret)], enabled=True)
|
||||
lead_setup = build_deferred_tool_setup([active_tool, tag_mcp_tool(mcp_secret)], enabled=True)
|
||||
mw = DeferredToolFilterMiddleware(lead_setup.deferred_names, lead_setup.catalog_hash)
|
||||
|
||||
# Simulate a subagent build re-entering tool assembly with its own setup.
|
||||
_ = build_deferred_tool_setup([_tag(mcp_secret)], enabled=True)
|
||||
_ = build_deferred_tool_setup([tag_mcp_tool(mcp_secret)], enabled=True)
|
||||
|
||||
class _Req:
|
||||
def __init__(self):
|
||||
@ -154,7 +150,7 @@ def test_policy_denied_mcp_yields_no_tool_search_end_to_end():
|
||||
tool_search (and does not fail-closed, because no MCP tool leaked through)."""
|
||||
from deerflow.agents.lead_agent import agent as agentmod
|
||||
|
||||
filtered = filter_tools_by_skill_allowed_tools([active_tool, _tag(mcp_secret)], [_make_skill(["active_tool"])])
|
||||
filtered = filter_tools_by_skill_allowed_tools([active_tool, tag_mcp_tool(mcp_secret)], [_make_skill(["active_tool"])])
|
||||
final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True)
|
||||
|
||||
assert [t.name for t in final_tools] == ["active_tool"]
|
||||
@ -174,7 +170,7 @@ def test_tool_search_appended_after_policy_but_never_exposes_denied_tool():
|
||||
from deerflow.agents.lead_agent import agent as agentmod
|
||||
|
||||
allowed = ["active_tool", "mcp_secret"] # permits the MCP tool, does NOT list tool_search
|
||||
filtered = filter_tools_by_skill_allowed_tools([active_tool, _tag(mcp_secret)], [_make_skill(allowed)])
|
||||
filtered = filter_tools_by_skill_allowed_tools([active_tool, tag_mcp_tool(mcp_secret)], [_make_skill(allowed)])
|
||||
final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True)
|
||||
|
||||
names = {t.name for t in final_tools}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user