mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
fix: promote deferred tools after tool_search returns schema (#1570)
* fix: promote matched tools from deferred registry after tool_search returns schema After tool_search returns a tool's full schema, the tool is promoted (removed from the deferred registry) so DeferredToolFilterMiddleware stops filtering it from bind_tools on subsequent LLM calls. Without this, deferred tools are permanently filtered — the LLM gets the schema from tool_search but can never invoke the tool because the middleware keeps stripping it. Fixes #1554 * test: add promote() and tool_search promotion tests Tests cover: - promote removes tools from registry - promote nonexistent/empty is no-op - search returns nothing after promote - middleware passes promoted tools through - tool_search auto-promotes matched tools (select + keyword) * fix: address review — lint blank line + empty registry guard - Add missing blank line between FakeRequest methods (E301) - Use 'if not registry' to handle empty registries consistently --------- Co-authored-by: d 🔹 <258577966+voidborne-d@users.noreply.github.com> Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
parent
ef58bb8d3c
commit
9bcdba6038
@ -51,6 +51,21 @@ class DeferredToolRegistry:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def promote(self, names: set[str]) -> None:
|
||||||
|
"""Remove tools from the deferred registry so they pass through the filter.
|
||||||
|
|
||||||
|
Called after tool_search returns a tool's schema — the LLM now knows
|
||||||
|
the full definition, so the DeferredToolFilterMiddleware should stop
|
||||||
|
stripping it from bind_tools on subsequent calls.
|
||||||
|
"""
|
||||||
|
if not names:
|
||||||
|
return
|
||||||
|
before = len(self._entries)
|
||||||
|
self._entries = [e for e in self._entries if e.name not in names]
|
||||||
|
promoted = before - len(self._entries)
|
||||||
|
if promoted:
|
||||||
|
logger.debug(f"Promoted {promoted} tool(s) from deferred to active: {names}")
|
||||||
|
|
||||||
def search(self, query: str) -> list[BaseTool]:
|
def search(self, query: str) -> list[BaseTool]:
|
||||||
"""Search deferred tools by regex pattern against name + description.
|
"""Search deferred tools by regex pattern against name + description.
|
||||||
|
|
||||||
@ -160,7 +175,7 @@ def tool_search(query: str) -> str:
|
|||||||
Matched tool definitions as JSON array.
|
Matched tool definitions as JSON array.
|
||||||
"""
|
"""
|
||||||
registry = get_deferred_registry()
|
registry = get_deferred_registry()
|
||||||
if registry is None:
|
if not registry:
|
||||||
return "No deferred tools available."
|
return "No deferred tools available."
|
||||||
|
|
||||||
matched_tools = registry.search(query)
|
matched_tools = registry.search(query)
|
||||||
@ -171,4 +186,8 @@ def tool_search(query: str) -> str:
|
|||||||
# This is model-agnostic: all LLMs understand this standard schema.
|
# This is model-agnostic: all LLMs understand this standard schema.
|
||||||
tool_defs = [convert_to_openai_function(t) for t in matched_tools[:MAX_RESULTS]]
|
tool_defs = [convert_to_openai_function(t) for t in matched_tools[:MAX_RESULTS]]
|
||||||
|
|
||||||
|
# Promote matched tools so the DeferredToolFilterMiddleware stops filtering
|
||||||
|
# them from bind_tools — the LLM now has the full schema and can invoke them.
|
||||||
|
registry.promote({t.name for t in matched_tools[:MAX_RESULTS]})
|
||||||
|
|
||||||
return json.dumps(tool_defs, indent=2, ensure_ascii=False)
|
return json.dumps(tool_defs, indent=2, ensure_ascii=False)
|
||||||
|
|||||||
@ -392,3 +392,120 @@ class TestDeferredToolFilterMiddleware:
|
|||||||
|
|
||||||
# dict_tool has no .name attr → getattr returns None → not in deferred_names → kept
|
# dict_tool has no .name attr → getattr returns None → not in deferred_names → kept
|
||||||
assert len(filtered.tools) == 2
|
assert len(filtered.tools) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── Promote Tests ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeferredToolRegistryPromote:
|
||||||
|
def test_promote_removes_tools(self, registry):
|
||||||
|
assert len(registry) == 6
|
||||||
|
registry.promote({"github_create_issue", "slack_send_message"})
|
||||||
|
assert len(registry) == 4
|
||||||
|
remaining = {e.name for e in registry.entries}
|
||||||
|
assert "github_create_issue" not in remaining
|
||||||
|
assert "slack_send_message" not in remaining
|
||||||
|
assert "github_list_repos" in remaining
|
||||||
|
|
||||||
|
def test_promote_nonexistent_is_noop(self, registry):
|
||||||
|
assert len(registry) == 6
|
||||||
|
registry.promote({"nonexistent_tool"})
|
||||||
|
assert len(registry) == 6
|
||||||
|
|
||||||
|
def test_promote_empty_set_is_noop(self, registry):
|
||||||
|
assert len(registry) == 6
|
||||||
|
registry.promote(set())
|
||||||
|
assert len(registry) == 6
|
||||||
|
|
||||||
|
def test_promote_all(self, registry):
|
||||||
|
all_names = {e.name for e in registry.entries}
|
||||||
|
registry.promote(all_names)
|
||||||
|
assert len(registry) == 0
|
||||||
|
|
||||||
|
def test_search_after_promote_excludes_promoted(self, registry):
|
||||||
|
"""After promoting github tools, searching 'github' returns nothing."""
|
||||||
|
registry.promote({"github_create_issue", "github_list_repos"})
|
||||||
|
results = registry.search("github")
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
def test_filter_after_promote_passes_through(self, registry):
|
||||||
|
"""After tool_search promotes a tool, the middleware lets it through."""
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
# Clear any mock entries
|
||||||
|
mock_keys = [
|
||||||
|
"deerflow.agents",
|
||||||
|
"deerflow.agents.middlewares",
|
||||||
|
"deerflow.agents.middlewares.deferred_tool_filter_middleware",
|
||||||
|
]
|
||||||
|
for key in mock_keys:
|
||||||
|
if isinstance(sys.modules.get(key), MagicMock):
|
||||||
|
del sys.modules[key]
|
||||||
|
|
||||||
|
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||||
|
|
||||||
|
set_deferred_registry(registry)
|
||||||
|
middleware = DeferredToolFilterMiddleware()
|
||||||
|
|
||||||
|
target_tool = registry.entries[0].tool # github_create_issue
|
||||||
|
active_tool = _make_mock_tool("my_active_tool", "Active")
|
||||||
|
|
||||||
|
class FakeRequest:
|
||||||
|
def __init__(self, tools):
|
||||||
|
self.tools = tools
|
||||||
|
|
||||||
|
def override(self, **kwargs):
|
||||||
|
return FakeRequest(kwargs.get("tools", self.tools))
|
||||||
|
|
||||||
|
# Before promote: deferred tool is filtered
|
||||||
|
request = FakeRequest(tools=[active_tool, target_tool])
|
||||||
|
filtered = middleware._filter_tools(request)
|
||||||
|
assert len(filtered.tools) == 1
|
||||||
|
assert filtered.tools[0].name == "my_active_tool"
|
||||||
|
|
||||||
|
# Promote the tool
|
||||||
|
registry.promote({"github_create_issue"})
|
||||||
|
|
||||||
|
# After promote: tool passes through the filter
|
||||||
|
request2 = FakeRequest(tools=[active_tool, target_tool])
|
||||||
|
filtered2 = middleware._filter_tools(request2)
|
||||||
|
assert len(filtered2.tools) == 2
|
||||||
|
tool_names = {t.name for t in filtered2.tools}
|
||||||
|
assert "github_create_issue" in tool_names
|
||||||
|
assert "my_active_tool" in tool_names
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolSearchPromotion:
|
||||||
|
def test_tool_search_promotes_matched_tools(self, registry):
|
||||||
|
"""tool_search should promote matched tools so they become callable."""
|
||||||
|
from deerflow.tools.builtins.tool_search import tool_search
|
||||||
|
|
||||||
|
set_deferred_registry(registry)
|
||||||
|
assert len(registry) == 6
|
||||||
|
|
||||||
|
# Search for github tools — should return schemas AND promote them
|
||||||
|
result = tool_search.invoke({"query": "select:github_create_issue"})
|
||||||
|
parsed = json.loads(result)
|
||||||
|
assert len(parsed) == 1
|
||||||
|
assert parsed[0]["name"] == "github_create_issue"
|
||||||
|
|
||||||
|
# The tool should now be promoted (removed from registry)
|
||||||
|
assert len(registry) == 5
|
||||||
|
remaining = {e.name for e in registry.entries}
|
||||||
|
assert "github_create_issue" not in remaining
|
||||||
|
|
||||||
|
def test_tool_search_keyword_promotes_all_matches(self, registry):
|
||||||
|
"""Keyword search promotes all matched tools."""
|
||||||
|
from deerflow.tools.builtins.tool_search import tool_search
|
||||||
|
|
||||||
|
set_deferred_registry(registry)
|
||||||
|
result = tool_search.invoke({"query": "slack"})
|
||||||
|
parsed = json.loads(result)
|
||||||
|
assert len(parsed) == 2
|
||||||
|
|
||||||
|
# Both slack tools promoted
|
||||||
|
remaining = {e.name for e in registry.entries}
|
||||||
|
assert "slack_send_message" not in remaining
|
||||||
|
assert "slack_list_channels" not in remaining
|
||||||
|
assert len(registry) == 4
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user