feat(loop-detection): make loop detection configurable with per-tool frequency overrides (#2711)

* Make loop detection configurable

Expose LoopDetectionMiddleware thresholds through config.yaml while preserving existing defaults and allowing the middleware to be disabled.

Refs bytedance/deer-flow#2517

* feat(loop-detection): add per-tool tool_freq_overrides to Phase 1

Adds ToolFreqOverride model and tool_freq_overrides field to
LoopDetectionConfig, wires it through LoopDetectionMiddleware, and
documents the option in config.example.yaml.

Resolves the gap flagged in the #2586 review: without per-tool overrides,
users hit by #2510/#2511 (RNA-seq workflows exceeding the bash hard limit)
had no way to raise thresholds for one tool without loosening the global
limit for every tool.
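
A minimal sketch of the override lookup, assuming a plain dict of
`(warn, hard_limit)` tuples. The names `effective_thresholds` and `classify`
and the `bash` override values are illustrative only and are not the
middleware's actual API:

```python
# Global thresholds; the values mirror the defaults quoted in this change.
TOOL_FREQ_WARN = 30
TOOL_FREQ_HARD_LIMIT = 50

# Hypothetical override: let bash run far longer than every other tool.
TOOL_FREQ_OVERRIDES: dict[str, tuple[int, int]] = {"bash": (150, 300)}


def effective_thresholds(tool_name: str) -> tuple[int, int]:
    """Return (warn, hard_limit) for a tool, preferring its override."""
    return TOOL_FREQ_OVERRIDES.get(tool_name, (TOOL_FREQ_WARN, TOOL_FREQ_HARD_LIMIT))


def classify(tool_name: str, call_count: int) -> str:
    """Classify a cumulative call count as 'ok', 'warn', or 'stop'."""
    warn, hard = effective_thresholds(tool_name)
    if call_count >= hard:
        return "stop"
    if call_count >= warn:
        return "warn"
    return "ok"
```

With these values, 50 `bash` calls stay below the override's warn level, while
50 calls to any tool without an override still hit the global hard limit.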

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* docs(loop-detection): document tool_freq_overrides in LoopDetectionMiddleware docstring

Add the missing Args entry for tool_freq_overrides, explaining the
(warn, hard_limit) tuple structure and how per-tool thresholds supersede
the global tool_freq_warn / tool_freq_hard_limit for named tools.
Also run ruff format on the three files flagged by the lint check.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(loop-detection): validate LoopDetectionMiddleware __init__ params eagerly

Raise clear ValueError at construction time instead of crashing at
unpack-time inside _track_and_check when bad values are passed:
- tool_freq_overrides: must be 2-tuples of positive ints with hard_limit >= warn
- scalar thresholds: warn_threshold, hard_limit, tool_freq_warn,
  tool_freq_hard_limit must be >= 1, and each hard limit must be >= its
  paired warn threshold
- window_size, max_tracked_threads must be >= 1

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(test): isolate credential loader directory-path test from real ~/.claude

The test didn't monkeypatch HOME, so on any machine with real Claude Code
credentials at ~/.claude/.credentials.json the function fell through to
those credentials and the assertion failed. Adding a HOME redirect ensures
the default credential path doesn't exist during the test.
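
The isolation idea can be sketched without pytest, assuming a hypothetical
loader that falls back to `~/.claude/.credentials.json`. `load_credential`
and `default_credentials_path` are stand-ins, not the project's real
functions; `patch.dict` plays the role of `monkeypatch.setenv`:

```python
import os
import tempfile
from pathlib import Path
from unittest.mock import patch


def default_credentials_path() -> Path:
    """Hypothetical helper: default credential file under the home directory."""
    return Path(os.path.expanduser("~")) / ".claude" / ".credentials.json"


def load_credential(override_path):
    """Hypothetical loader: use the override if it is a file, else the default."""
    for candidate in (Path(override_path), default_credentials_path()):
        if candidate.is_file():
            return candidate.read_text()
    return None


def directory_override_is_ignored() -> bool:
    """Return True when a directory override yields no credential at all."""
    with tempfile.TemporaryDirectory() as tmp:
        cred_dir = Path(tmp) / "claude-creds-dir"
        cred_dir.mkdir()
        # Redirect the home directory so no real credentials can be picked up.
        # USERPROFILE is set as well because expanduser reads it on Windows.
        with patch.dict(os.environ, {"HOME": tmp, "USERPROFILE": tmp}):
            return load_credential(cred_dir) is None
```

Without the redirect, the fallback branch could read a developer's real
credential file and the check would depend on the machine it runs on.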

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style(test): add blank lines after import pytest in TestInitValidation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor(loop-detection): collapse dual validation to LoopDetectionConfig

Modifications
  - LoopDetectionMiddleware.__init__: stripped of all ValueError raises;
    becomes a plain field-assignment constructor.
  - LoopDetectionMiddleware.from_config: classmethod that builds the
    middleware from a Pydantic-validated LoopDetectionConfig and handles
    the ToolFreqOverride -> tuple[int, int] conversion.
  - agents/factory.py: SDK construction routed through
    LoopDetectionMiddleware.from_config(LoopDetectionConfig()) so the
    defaults path is Pydantic-validated too.
  - agents/lead_agent/agent.py: uses from_config instead of unpacking
    config fields by hand.
  - tests/test_loop_detection_middleware.py: deleted TestInitValidation
    (16 methods exercising the removed __init__ checks); added
    TestFromConfig (4 tests: scalar field mapping, override tuple
    conversion, empty overrides, behavioral smoke test).

Result: one validation layer (Pydantic), zero duplication, no __new__
hacks. Both production construction sites flow through LoopDetectionConfig.

Test results
  make test   -> 2977 passed, 18 skipped, 0 failed (137s)
  make format -> All checks passed; 411 files left unchanged
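
A dependency-free sketch of the single-validation-layer shape described above.
It uses dataclasses in place of Pydantic, and the `Sketch*` classes are
stand-ins reduced to the frequency fields, not the real models:

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class SketchOverride:
    """Stand-in for ToolFreqOverride; validates itself on construction."""

    warn: int
    hard_limit: int

    def __post_init__(self) -> None:
        if self.warn < 1 or self.hard_limit < 1:
            raise ValueError("warn and hard_limit must be >= 1")
        if self.hard_limit < self.warn:
            raise ValueError("hard_limit must be >= warn")


@dataclass(frozen=True)
class SketchConfig:
    """Stand-in for LoopDetectionConfig, reduced to the frequency fields."""

    tool_freq_warn: int = 30
    tool_freq_hard_limit: int = 50
    tool_freq_overrides: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        if self.tool_freq_warn < 1 or self.tool_freq_hard_limit < 1:
            raise ValueError("thresholds must be >= 1")
        if self.tool_freq_hard_limit < self.tool_freq_warn:
            raise ValueError("tool_freq_hard_limit must be >= tool_freq_warn")


class SketchMiddleware:
    """Plain field-assignment constructor with no checks of its own."""

    def __init__(self, tool_freq_warn, tool_freq_hard_limit, tool_freq_overrides=None):
        self.tool_freq_warn = tool_freq_warn
        self.tool_freq_hard_limit = tool_freq_hard_limit
        self.tool_freq_overrides = tool_freq_overrides or {}

    @classmethod
    def from_config(cls, config: SketchConfig) -> "SketchMiddleware":
        # Convert each validated override object into a (warn, hard_limit) tuple.
        overrides = {
            name: (o.warn, o.hard_limit)
            for name, o in config.tool_freq_overrides.items()
        }
        return cls(config.tool_freq_warn, config.tool_freq_hard_limit, overrides)
```

Bad values fail when the config object is built, so the middleware
constructor never sees them as long as callers go through `from_config`.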

* feat(agents): make loop_detection configurable in create_deerflow_agent

Adds a `loop_detection: bool | AgentMiddleware = True` field to
RuntimeFeatures, mirroring the existing pattern used by `sandbox`,
`memory`, and `vision`. SDK users can now disable LoopDetectionMiddleware
or replace it with a custom instance built from their own
LoopDetectionConfig — e.g.
`LoopDetectionMiddleware.from_config(my_cfg)` — instead of being stuck
with the hardcoded defaults previously installed by the SDK factory.

The lead-agent path (which already reads AppConfig.loop_detection) is
unchanged, and the default `True` preserves prior always-on behavior for
all existing callers.
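
The three accepted values (`True`, `False`, or a middleware instance) can be
sketched as follows. All classes here are stand-ins and `assemble_tail` is a
hypothetical helper covering only the last two slots of the chain:

```python
from typing import Union


class AgentMiddleware:
    """Stand-in for the real middleware base class."""


class DefaultLoopDetection(AgentMiddleware):
    """Stand-in for the default LoopDetectionMiddleware."""


class Clarification(AgentMiddleware):
    """Stand-in for ClarificationMiddleware, which always comes last."""


def assemble_tail(loop_detection: Union[bool, AgentMiddleware] = True) -> list:
    """Build the final slots of the chain from the loop_detection value."""
    chain = []
    if loop_detection is not False:
        if isinstance(loop_detection, AgentMiddleware):
            # Caller-supplied replacement takes the loop-detection slot.
            chain.append(loop_detection)
        else:
            # True: install the default implementation.
            chain.append(DefaultLoopDetection())
    chain.append(Clarification())
    return chain
```

A custom instance takes the same slot the default would have occupied, so it
still sits immediately before the clarification step.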

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

---------

Co-authored-by: knight0940 <631532668@qq.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: Amorend <142649913+knight0940@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Tao Liu 2026-05-07 16:15:15 +08:00 committed by GitHub
parent 27559f3675
commit daa3ffc29b
13 changed files with 406 additions and 12 deletions

View File

@ -173,7 +173,7 @@ def _assemble_from_features(
9. MemoryMiddleware (memory feature)
10. ViewImageMiddleware (vision feature)
11. SubagentLimitMiddleware (subagent feature)
- 12. LoopDetectionMiddleware (always)
+ 12. LoopDetectionMiddleware (loop_detection feature)
13. ClarificationMiddleware (always last)
Two-phase ordering:
@ -272,10 +272,15 @@ def _assemble_from_features(
extra_tools.append(task_tool)
- # --- [12] LoopDetection (always) ---
- from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
+ # --- [12] LoopDetection ---
+ if feat.loop_detection is not False:
+     if isinstance(feat.loop_detection, AgentMiddleware):
+         chain.append(feat.loop_detection)
+     else:
+         from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
+         from deerflow.config.loop_detection_config import LoopDetectionConfig
- chain.append(LoopDetectionMiddleware())
+         chain.append(LoopDetectionMiddleware.from_config(LoopDetectionConfig()))
# --- [13] Clarification (always last among built-ins) ---
chain.append(ClarificationMiddleware())

View File

@ -31,6 +31,7 @@ class RuntimeFeatures:
vision: bool | AgentMiddleware = False
auto_title: bool | AgentMiddleware = False
guardrail: Literal[False] | AgentMiddleware = False
loop_detection: bool | AgentMiddleware = True
# ---------------------------------------------------------------------------

View File

@ -299,7 +299,9 @@ def _build_middlewares(
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
# LoopDetectionMiddleware — detect and break repetitive tool call loops
- middlewares.append(LoopDetectionMiddleware())
+ loop_detection_config = resolved_app_config.loop_detection
+ if loop_detection_config.enabled:
+     middlewares.append(LoopDetectionMiddleware.from_config(loop_detection_config))
# Inject custom middlewares before ClarificationMiddleware
if custom_middlewares:

View File

@ -12,18 +12,23 @@ Detection strategy:
response so the agent is forced to produce a final text answer.
"""
from __future__ import annotations
import hashlib
import json
import logging
import threading
from collections import OrderedDict, defaultdict
from copy import deepcopy
- from typing import override
+ from typing import TYPE_CHECKING, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
if TYPE_CHECKING:
    from deerflow.config.loop_detection_config import LoopDetectionConfig
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
@ -139,6 +144,9 @@ _TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Detects and breaks repetitive tool call loops.
Threshold parameters are validated upstream by :class:`LoopDetectionConfig`;
construct via :meth:`from_config` to ensure values pass Pydantic validation.
Args:
warn_threshold: Number of identical tool call sets before injecting
a warning message. Default: 3.
@ -154,6 +162,14 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
Default: 30.
tool_freq_hard_limit: Number of calls to the same tool type before
forcing a stop. Default: 50.
tool_freq_overrides: Per-tool overrides for frequency thresholds,
keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple
that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for
that specific tool. Tools not listed here fall back to the global
thresholds. Useful for raising limits on intentionally
high-frequency tools (e.g. ``bash`` in batch pipelines) without
weakening protection on all other tools. Default: ``None``
(no overrides).
"""
def __init__(
@ -164,6 +180,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
tool_freq_overrides: dict[str, tuple[int, int]] | None = None,
):
super().__init__()
self.warn_threshold = warn_threshold
@ -172,14 +189,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self.max_tracked_threads = max_tracked_threads
self.tool_freq_warn = tool_freq_warn
self.tool_freq_hard_limit = tool_freq_hard_limit
self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {}
self._lock = threading.Lock()
# Per-thread tracking using OrderedDict for LRU eviction
self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set)
# Per-thread, per-tool-type cumulative call counts
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
@classmethod
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
    """Construct from a Pydantic-validated config, trusting its validation."""
    return cls(
        warn_threshold=config.warn_threshold,
        hard_limit=config.hard_limit,
        window_size=config.window_size,
        max_tracked_threads=config.max_tracked_threads,
        tool_freq_warn=config.tool_freq_warn,
        tool_freq_hard_limit=config.tool_freq_hard_limit,
        tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()},
    )
def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
@ -279,7 +308,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
freq[name] += 1
tc_count = freq[name]
- if tc_count >= self.tool_freq_hard_limit:
+ if name in self._tool_freq_overrides:
+     eff_warn, eff_hard = self._tool_freq_overrides[name]
+ else:
+     eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit
+ if tc_count >= eff_hard:
logger.error(
"Tool frequency hard limit reached — forcing stop",
extra={
@ -290,7 +324,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
)
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
- if tc_count >= self.tool_freq_warn:
+ if tc_count >= eff_warn:
warned = self._tool_freq_warned[thread_id]
if name not in warned:
warned.add(name)

View File

@ -1,5 +1,6 @@
from .app_config import get_app_config
from .extensions_config import ExtensionsConfig, get_extensions_config
from .loop_detection_config import LoopDetectionConfig
from .memory_config import MemoryConfig, get_memory_config
from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig
@ -20,6 +21,7 @@ __all__ = [
"SkillsConfig",
"ExtensionsConfig",
"get_extensions_config",
"LoopDetectionConfig",
"MemoryConfig",
"get_memory_config",
"get_tracing_config",

View File

@ -15,6 +15,7 @@ from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpo
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
@ -100,6 +101,7 @@ class AppConfig(BaseModel):
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")

View File

@ -0,0 +1,73 @@
"""Configuration for loop detection middleware."""

from pydantic import BaseModel, Field, model_validator


class ToolFreqOverride(BaseModel):
    """Per-tool frequency threshold override.

    Can be higher or lower than the global defaults. Commonly used to raise
    thresholds for high-frequency tools like bash in batch workflows (e.g.
    RNA-seq pipelines) without weakening protection on every other tool.
    """

    warn: int = Field(ge=1)
    hard_limit: int = Field(ge=1)

    @model_validator(mode="after")
    def _validate(self) -> "ToolFreqOverride":
        if self.hard_limit < self.warn:
            raise ValueError("hard_limit must be >= warn")
        return self


class LoopDetectionConfig(BaseModel):
    """Configuration for repetitive tool-call loop detection."""

    enabled: bool = Field(
        default=True,
        description="Whether to enable repetitive tool-call loop detection",
    )
    warn_threshold: int = Field(
        default=3,
        ge=1,
        description="Number of identical tool-call sets before injecting a warning",
    )
    hard_limit: int = Field(
        default=5,
        ge=1,
        description="Number of identical tool-call sets before forcing a stop",
    )
    window_size: int = Field(
        default=20,
        ge=1,
        description="Number of recent tool-call sets to track per thread",
    )
    max_tracked_threads: int = Field(
        default=100,
        ge=1,
        description="Maximum number of thread histories to keep in memory",
    )
    tool_freq_warn: int = Field(
        default=30,
        ge=1,
        description="Number of calls to the same tool type before injecting a frequency warning",
    )
    tool_freq_hard_limit: int = Field(
        default=50,
        ge=1,
        description="Number of calls to the same tool type before forcing a stop",
    )
    tool_freq_overrides: dict[str, ToolFreqOverride] = Field(
        default_factory=dict,
        description=("Per-tool overrides for tool_freq_warn / tool_freq_hard_limit, keyed by tool name. Values can be higher or lower than the global defaults. Commonly used to raise thresholds for high-frequency tools like bash."),
    )

    @model_validator(mode="after")
    def validate_thresholds(self) -> "LoopDetectionConfig":
        """Ensure hard stop cannot happen before the warning threshold."""
        if self.hard_limit < self.warn_threshold:
            raise ValueError("hard_limit must be greater than or equal to warn_threshold")
        if self.tool_freq_hard_limit < self.tool_freq_warn:
            raise ValueError("tool_freq_hard_limit must be greater than or equal to tool_freq_warn")
        return self

View File

@ -192,6 +192,7 @@ def test_agent_features_defaults():
assert f.vision is False
assert f.auto_title is False
assert f.guardrail is False
assert f.loop_detection is True
# ---------------------------------------------------------------------------
@ -630,6 +631,51 @@ def test_loop_detection_before_clarification(mock_create_agent):
assert loop_idx == clar_idx - 1
# ---------------------------------------------------------------------------
# 30b. loop_detection=False skips LoopDetectionMiddleware
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_loop_detection_disabled(mock_create_agent):
mock_create_agent.return_value = MagicMock()
create_deerflow_agent(
_make_mock_model(),
features=RuntimeFeatures(sandbox=False, loop_detection=False),
)
call_kwargs = mock_create_agent.call_args[1]
mw_types = [type(m).__name__ for m in call_kwargs["middleware"]]
assert "LoopDetectionMiddleware" not in mw_types
# ---------------------------------------------------------------------------
# 30c. loop_detection=<custom AgentMiddleware> replaces the default
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_loop_detection_custom_middleware(mock_create_agent):
from langchain.agents.middleware import AgentMiddleware as AM
mock_create_agent.return_value = MagicMock()
class MyLoopDetection(AM):
pass
custom = MyLoopDetection()
create_deerflow_agent(
_make_mock_model(),
features=RuntimeFeatures(sandbox=False, loop_detection=custom),
)
call_kwargs = mock_create_agent.call_args[1]
middleware = call_kwargs["middleware"]
assert custom in middleware
mw_types = [type(m).__name__ for m in middleware]
# Default LoopDetectionMiddleware must not also appear.
assert "LoopDetectionMiddleware" not in mw_types
# Custom replacement still sits immediately before ClarificationMiddleware.
assert mw_types[-1] == "ClarificationMiddleware"
assert mw_types[-2] == "MyLoopDetection"
# ---------------------------------------------------------------------------
# 31. plan_mode=True adds TodoMiddleware
# ---------------------------------------------------------------------------

View File

@ -85,6 +85,8 @@ def test_load_claude_code_credential_from_override_path(tmp_path, monkeypatch):
def test_load_claude_code_credential_ignores_directory_path(tmp_path, monkeypatch):
_clear_claude_code_env(monkeypatch)
# Redirect HOME so the default ~/.claude/.credentials.json doesn't exist
monkeypatch.setenv("HOME", str(tmp_path))
cred_dir = tmp_path / "claude-creds-dir"
cred_dir.mkdir()
monkeypatch.setenv("CLAUDE_CODE_CREDENTIALS_PATH", str(cred_dir))

View File

@ -8,17 +8,20 @@ from unittest.mock import MagicMock
import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.config.app_config import AppConfig
from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.summarization_config import SummarizationConfig
- def _make_app_config(models: list[ModelConfig]) -> AppConfig:
+ def _make_app_config(models: list[ModelConfig], loop_detection: LoopDetectionConfig | None = None) -> AppConfig:
return AppConfig(
models=models,
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider"),
loop_detection=loop_detection or LoopDetectionConfig(),
)
@ -340,6 +343,59 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypa
assert middlewares[0] == "base-middleware"
def test_build_middlewares_uses_loop_detection_config(monkeypatch):
app_config = _make_app_config(
[_make_model("safe-model", supports_thinking=False)],
loop_detection=LoopDetectionConfig(
warn_threshold=7,
hard_limit=9,
window_size=30,
max_tracked_threads=40,
tool_freq_warn=50,
tool_freq_hard_limit=60,
),
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
loop_detection = next(m for m in middlewares if isinstance(m, LoopDetectionMiddleware))
assert loop_detection.warn_threshold == 7
assert loop_detection.hard_limit == 9
assert loop_detection.window_size == 30
assert loop_detection.max_tracked_threads == 40
assert loop_detection.tool_freq_warn == 50
assert loop_detection.tool_freq_hard_limit == 60
def test_build_middlewares_omits_loop_detection_when_disabled(monkeypatch):
app_config = _make_app_config(
[_make_model("safe-model", supports_thinking=False)],
loop_detection=LoopDetectionConfig(enabled=False),
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
assert not any(isinstance(m, LoopDetectionMiddleware) for m in middlewares)
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)])
app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork")

View File

@ -0,0 +1,72 @@
"""Tests for loop detection configuration."""
import pytest
from deerflow.config.loop_detection_config import LoopDetectionConfig
class TestLoopDetectionConfig:
def test_defaults_match_middleware_defaults(self):
config = LoopDetectionConfig()
assert config.enabled is True
assert config.warn_threshold == 3
assert config.hard_limit == 5
assert config.window_size == 20
assert config.max_tracked_threads == 100
assert config.tool_freq_warn == 30
assert config.tool_freq_hard_limit == 50
def test_accepts_custom_values(self):
config = LoopDetectionConfig(
enabled=False,
warn_threshold=10,
hard_limit=20,
window_size=50,
max_tracked_threads=200,
tool_freq_warn=60,
tool_freq_hard_limit=80,
)
assert config.enabled is False
assert config.warn_threshold == 10
assert config.hard_limit == 20
assert config.window_size == 50
assert config.max_tracked_threads == 200
assert config.tool_freq_warn == 60
assert config.tool_freq_hard_limit == 80
def test_rejects_zero_thresholds(self):
with pytest.raises(ValueError):
LoopDetectionConfig(warn_threshold=0)
with pytest.raises(ValueError):
LoopDetectionConfig(hard_limit=0)
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_warn=0)
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_hard_limit=0)
def test_rejects_hard_limit_below_warn_threshold(self):
with pytest.raises(ValueError, match="hard_limit"):
LoopDetectionConfig(warn_threshold=5, hard_limit=4)
def test_rejects_tool_freq_hard_limit_below_warn_threshold(self):
with pytest.raises(ValueError, match="tool_freq_hard_limit"):
LoopDetectionConfig(tool_freq_warn=5, tool_freq_hard_limit=4)
def test_tool_freq_override_valid(self):
config = LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 150, "hard_limit": 300}})
override = config.tool_freq_overrides["bash"]
assert override.warn == 150
assert override.hard_limit == 300
def test_tool_freq_override_rejects_zero_warn(self):
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 0, "hard_limit": 10}})
def test_tool_freq_override_rejects_hard_limit_below_warn(self):
with pytest.raises(ValueError, match="hard_limit"):
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 100, "hard_limit": 50}})

View File

@ -648,6 +648,37 @@ class TestToolFrequencyDetection:
assert result is not None
assert "read_file" in result["messages"][0].content
def test_override_tool_uses_override_thresholds(self):
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
mw = LoopDetectionMiddleware(
tool_freq_warn=5,
tool_freq_hard_limit=10,
tool_freq_overrides={"bash": (50, 100)},
)
runtime = _make_runtime()
# 10 bash calls — would hit global hard_limit=10, but bash override is 100
for i in range(10):
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
assert result is None, f"unexpected trigger on call {i + 1}"
def test_non_override_tool_falls_back_to_global(self):
"""A tool NOT in tool_freq_overrides uses the global warn/hard_limit."""
mw = LoopDetectionMiddleware(
tool_freq_warn=3,
tool_freq_hard_limit=6,
tool_freq_overrides={"bash": (50, 100)},
)
runtime = _make_runtime()
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 3rd read_file call hits global warn=3 (read_file has no override)
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
def test_hash_detection_takes_priority(self):
"""Hash-based hard stop fires before frequency check for identical calls."""
mw = LoopDetectionMiddleware(
@ -668,3 +699,48 @@ class TestToolFrequencyDetection:
msg = result["messages"][0]
assert isinstance(msg, AIMessage)
assert _HARD_STOP_MSG in msg.content
class TestFromConfig:
"""Tests for LoopDetectionMiddleware.from_config — the sole validated construction path."""
@staticmethod
def _config(**kwargs):
from deerflow.config.loop_detection_config import LoopDetectionConfig
return LoopDetectionConfig(**kwargs)
def test_scalar_fields_mapped(self):
config = self._config(
warn_threshold=4,
hard_limit=8,
window_size=15,
max_tracked_threads=50,
tool_freq_warn=20,
tool_freq_hard_limit=40,
)
mw = LoopDetectionMiddleware.from_config(config)
assert mw.warn_threshold == 4
assert mw.hard_limit == 8
assert mw.window_size == 15
assert mw.max_tracked_threads == 50
assert mw.tool_freq_warn == 20
assert mw.tool_freq_hard_limit == 40
def test_overrides_converted_to_tuples(self):
config = self._config(tool_freq_overrides={"bash": {"warn": 50, "hard_limit": 100}})
mw = LoopDetectionMiddleware.from_config(config)
assert mw._tool_freq_overrides == {"bash": (50, 100)}
def test_empty_overrides(self):
mw = LoopDetectionMiddleware.from_config(self._config())
assert mw._tool_freq_overrides == {}
def test_constructed_middleware_detects_loops(self):
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
runtime = _make_runtime()
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content

View File

@ -15,7 +15,7 @@
# ============================================================================
# Bump this number when the config schema changes.
# Run `make config-upgrade` to merge new fields into your local config.yaml.
- config_version: 8
+ config_version: 9
# ============================================================================
# Logging
@ -506,6 +506,29 @@ tools:
tool_search:
enabled: false
# ============================================================================
# Loop Detection Configuration
# ============================================================================
# Detect and interrupt repeated identical tool-call loops.
# Frequency thresholds are safety limits for repeated use of the same tool type.
loop_detection:
  enabled: true
  warn_threshold: 3
  hard_limit: 5
  window_size: 20
  max_tracked_threads: 100
  tool_freq_warn: 30
  tool_freq_hard_limit: 50
  # Per-tool overrides for tool_freq_warn / tool_freq_hard_limit. Values can be
  # higher or lower than the global defaults. Commonly used to raise thresholds
  # for high-frequency tools like bash in batch workflows (e.g. RNA-seq pipelines)
  # without weakening protection on every other tool.
  # tool_freq_overrides:
  #   bash:
  #     warn: 150
  #     hard_limit: 300
# ============================================================================
# Sandbox Configuration
# ============================================================================