diff --git a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py index 0b161152c..9cfc4400f 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py @@ -31,6 +31,8 @@ _DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls _DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls _DEFAULT_WINDOW_SIZE = 20 # track last N tool calls _DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit +_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type +_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]: @@ -125,8 +127,14 @@ def _hash_tool_calls(tool_calls: list[dict]) -> str: _WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far." +_TOOL_FREQ_WARNING_MSG = ( + "[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far." +) + _HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far." +_TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far." + class LoopDetectionMiddleware(AgentMiddleware[AgentState]): """Detects and breaks repetitive tool call loops. @@ -140,6 +148,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): Default: 20. max_tracked_threads: Maximum number of threads to track before evicting the least recently used. Default: 100. + tool_freq_warn: Number of calls to the same tool *type* (regardless + of arguments) before injecting a frequency warning. Catches + cross-file read loops that hash-based detection misses. + Default: 30. + tool_freq_hard_limit: Number of calls to the same tool type before + forcing a stop. Default: 50. """ def __init__( @@ -148,16 +162,23 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): hard_limit: int = _DEFAULT_HARD_LIMIT, window_size: int = _DEFAULT_WINDOW_SIZE, max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS, + tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN, + tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT, ): super().__init__() self.warn_threshold = warn_threshold self.hard_limit = hard_limit self.window_size = window_size self.max_tracked_threads = max_tracked_threads + self.tool_freq_warn = tool_freq_warn + self.tool_freq_hard_limit = tool_freq_hard_limit self._lock = threading.Lock() # Per-thread tracking using OrderedDict for LRU eviction self._history: OrderedDict[str, list[str]] = OrderedDict() self._warned: dict[str, set[str]] = defaultdict(set) + # Per-thread, per-tool-type cumulative call counts + self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + self._tool_freq_warned: dict[str, set[str]] = defaultdict(set) def _get_thread_id(self, runtime: Runtime) -> str: """Extract thread_id from runtime context for per-thread tracking.""" @@ -174,11 +195,19 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): while len(self._history) > self.max_tracked_threads: evicted_id, _ = self._history.popitem(last=False) self._warned.pop(evicted_id, None) + self._tool_freq.pop(evicted_id, None) + self._tool_freq_warned.pop(evicted_id, None) logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id) def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]: """Track tool calls and check for loops. + Two detection layers: + 1. **Hash-based** (existing): catches identical tool call sets. + 2. **Frequency-based** (new): catches the same *tool type* being + called many times with varying arguments (e.g. ``read_file`` + on 40 different files). + Returns: (warning_message_or_none, should_hard_stop) """ @@ -213,6 +242,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): count = history.count(call_hash) tool_names = [tc.get("name", "?") for tc in tool_calls] + # --- Layer 1: hash-based (identical call sets) --- if count >= self.hard_limit: logger.error( "Loop hard limit reached — forcing stop", @@ -239,8 +269,40 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): }, ) return _WARNING_MSG, False - # Warning already injected for this hash — suppress - return None, False + + # --- Layer 2: per-tool-type frequency --- + freq = self._tool_freq[thread_id] + for tc in tool_calls: + name = tc.get("name", "") + if not name: + continue + freq[name] += 1 + tc_count = freq[name] + + if tc_count >= self.tool_freq_hard_limit: + logger.error( + "Tool frequency hard limit reached — forcing stop", + extra={ + "thread_id": thread_id, + "tool_name": name, + "count": tc_count, + }, + ) + return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True + + if tc_count >= self.tool_freq_warn: + warned = self._tool_freq_warned[thread_id] + if name not in warned: + warned.add(name) + logger.warning( + "Tool frequency warning — too many calls to same tool type", + extra={ + "thread_id": thread_id, + "tool_name": name, + "count": tc_count, + }, + ) + return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False return None, False @@ -271,7 +333,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): stripped_msg = last_msg.model_copy( update={ "tool_calls": [], - "content": self._append_text(last_msg.content, _HARD_STOP_MSG), + "content": self._append_text(last_msg.content, warning), } ) return {"messages": [stripped_msg]} @@ -301,6 +363,10 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): if thread_id: self._history.pop(thread_id, None) self._warned.pop(thread_id, None) + self._tool_freq.pop(thread_id, None) + self._tool_freq_warned.pop(thread_id, None) else: self._history.clear() self._warned.clear() + self._tool_freq.clear() + self._tool_freq_warned.clear() diff --git a/backend/tests/test_loop_detection_middleware.py b/backend/tests/test_loop_detection_middleware.py index c40e218ac..9accd60d3 100644 --- a/backend/tests/test_loop_detection_middleware.py +++ b/backend/tests/test_loop_detection_middleware.py @@ -280,6 +280,8 @@ class TestLoopDetection: mw._apply(_make_state(tool_calls=call), runtime_new) assert "thread-0" not in mw._history + assert "thread-0" not in mw._tool_freq + assert "thread-0" not in mw._tool_freq_warned assert "thread-new" in mw._history assert len(mw._history) == 3 @@ -410,3 +412,188 @@ class TestHardStopWithListContent: assert isinstance(msg.content, str) assert msg.content.startswith("thinking...") assert _HARD_STOP_MSG in msg.content + + +class TestToolFrequencyDetection: + """Tests for per-tool-type frequency detection (Layer 2). + + This catches the case where an agent calls the same tool type many times + with *different* arguments (e.g. read_file on 40 different files), which + bypasses hash-based detection. + """ + + def _read_call(self, path): + return {"name": "read_file", "id": f"call_read_{path}", "args": {"path": path}} + + def test_below_freq_warn_returns_none(self): + mw = LoopDetectionMiddleware(tool_freq_warn=5, tool_freq_hard_limit=10) + runtime = _make_runtime() + + for i in range(4): + result = mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) + assert result is None + + def test_freq_warn_at_threshold(self): + mw = LoopDetectionMiddleware(tool_freq_warn=5, tool_freq_hard_limit=10) + runtime = _make_runtime() + + for i in range(4): + mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) + + # 5th call to read_file (different file each time) triggers freq warning + result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime) + assert result is not None + msg = result["messages"][0] + assert isinstance(msg, HumanMessage) + assert "read_file" in msg.content + assert "LOOP DETECTED" in msg.content + + def test_freq_warn_only_injected_once(self): + mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) + runtime = _make_runtime() + + for i in range(2): + mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) + + # 3rd triggers warning + result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) + assert result is not None + assert "LOOP DETECTED" in result["messages"][0].content + + # 4th should not re-warn (already warned for read_file) + result = mw._apply(_make_state(tool_calls=[self._read_call("/file_3.py")]), runtime) + assert result is None + + def test_freq_hard_stop_at_limit(self): + mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=6) + runtime = _make_runtime() + + for i in range(5): + mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) + + # 6th call triggers hard stop + result = mw._apply(_make_state(tool_calls=[self._read_call("/file_5.py")]), runtime) + assert result is not None + msg = result["messages"][0] + assert isinstance(msg, AIMessage) + assert msg.tool_calls == [] + assert "FORCED STOP" in msg.content + assert "read_file" in msg.content + + def test_different_tools_tracked_independently(self): + """read_file and bash should have independent frequency counters.""" + mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) + runtime = _make_runtime() + + # 2 read_file calls + for i in range(2): + mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) + + # 2 bash calls — should not trigger (bash count = 2, read_file count = 2) + for i in range(2): + result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime) + assert result is None + + # 3rd read_file triggers (read_file count = 3) + result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) + assert result is not None + assert "read_file" in result["messages"][0].content + + def test_freq_reset_clears_state(self): + mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) + runtime = _make_runtime() + + for i in range(2): + mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) + + mw.reset() + + # After reset, count restarts — should not trigger + result = mw._apply(_make_state(tool_calls=[self._read_call("/file_new.py")]), runtime) + assert result is None + + def test_freq_reset_per_thread_clears_only_target(self): + """reset(thread_id=...) should clear frequency state for that thread only.""" + mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) + runtime_a = _make_runtime("thread-A") + runtime_b = _make_runtime("thread-B") + + # 2 calls on each thread + for i in range(2): + mw._apply(_make_state(tool_calls=[self._read_call(f"/a_{i}.py")]), runtime_a) + mw._apply(_make_state(tool_calls=[self._read_call(f"/b_{i}.py")]), runtime_b) + + # Reset only thread-A + mw.reset(thread_id="thread-A") + + assert "thread-A" not in mw._tool_freq + assert "thread-A" not in mw._tool_freq_warned + + # thread-B state should still be intact — 3rd call triggers warn + result = mw._apply(_make_state(tool_calls=[self._read_call("/b_2.py")]), runtime_b) + assert result is not None + assert "LOOP DETECTED" in result["messages"][0].content + + # thread-A restarted from 0 — should not trigger + result = mw._apply(_make_state(tool_calls=[self._read_call("/a_new.py")]), runtime_a) + assert result is None + + def test_freq_per_thread_isolation(self): + """Frequency counts should be independent per thread.""" + mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) + runtime_a = _make_runtime("thread-A") + runtime_b = _make_runtime("thread-B") + + # 2 calls on thread A + for i in range(2): + mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime_a) + + # 2 calls on thread B — should NOT push thread A over threshold + for i in range(2): + mw._apply(_make_state(tool_calls=[self._read_call(f"/other_{i}.py")]), runtime_b) + + # 3rd call on thread A — triggers (count=3 for thread A only) + result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime_a) + assert result is not None + assert "LOOP DETECTED" in result["messages"][0].content + + def test_multi_tool_single_response_counted(self): + """When a single response has multiple tool calls, each is counted.""" + mw = LoopDetectionMiddleware(tool_freq_warn=5, tool_freq_hard_limit=10) + runtime = _make_runtime() + + # Response 1: 2 read_file calls → count = 2 + call = [self._read_call("/a.py"), self._read_call("/b.py")] + result = mw._apply(_make_state(tool_calls=call), runtime) + assert result is None + + # Response 2: 2 more → count = 4 + call = [self._read_call("/c.py"), self._read_call("/d.py")] + result = mw._apply(_make_state(tool_calls=call), runtime) + assert result is None + + # Response 3: 1 more → count = 5 → triggers warn + result = mw._apply(_make_state(tool_calls=[self._read_call("/e.py")]), runtime) + assert result is not None + assert "read_file" in result["messages"][0].content + + def test_hash_detection_takes_priority(self): + """Hash-based hard stop fires before frequency check for identical calls.""" + mw = LoopDetectionMiddleware( + warn_threshold=2, + hard_limit=3, + tool_freq_warn=100, + tool_freq_hard_limit=200, + ) + runtime = _make_runtime() + call = [self._read_call("/same_file.py")] + + for _ in range(2): + mw._apply(_make_state(tool_calls=call), runtime) + + # 3rd identical call → hash hard_limit=3 fires (not freq) + result = mw._apply(_make_state(tool_calls=call), runtime) + assert result is not None + msg = result["messages"][0] + assert isinstance(msg, AIMessage) + assert _HARD_STOP_MSG in msg.content