From daa3ffc29b45f26c48531696b25d874884ec9e74 Mon Sep 17 00:00:00 2001
From: Tao Liu
Date: Thu, 7 May 2026 16:15:15 +0800
Subject: [PATCH] feat(loop-detection): make loop detection configurable with per-tool frequency overrides (#2711)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Make loop detection configurable

Expose LoopDetectionMiddleware thresholds through config.yaml while
preserving existing defaults and allowing the middleware to be disabled.

Refs bytedance/deer-flow#2517

* feat(loop-detection): add per-tool tool_freq_overrides to Phase 1

Adds ToolFreqOverride model and tool_freq_overrides field to
LoopDetectionConfig, wires it through LoopDetectionMiddleware, and
documents the option in config.example.yaml.

Resolves the gap flagged in the #2586 review: without per-tool overrides,
users hit by #2510/#2511 (RNA-seq workflows exceeding the bash hard limit)
had no way to raise thresholds for one tool without loosening the global
limit for every tool.

Co-Authored-By: Claude Opus 4.7

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* docs(loop-detection): document tool_freq_overrides in LoopDetectionMiddleware docstring

Add the missing Args entry for tool_freq_overrides, explaining the
(warn, hard_limit) tuple structure and how per-tool thresholds supersede
the global tool_freq_warn / tool_freq_hard_limit for named tools.

Also run ruff format on the three files flagged by the lint check.

Co-Authored-By: Claude Sonnet 4.6

* fix(loop-detection): validate LoopDetectionMiddleware __init__ params eagerly

Raise clear ValueError at construction time instead of crashing at
unpack-time inside _track_and_check when bad values are passed:
- tool_freq_overrides: must be 2-tuples of positive ints with
  hard_limit >= warn
- scalar thresholds: warn_threshold, hard_limit, tool_freq_warn,
  tool_freq_hard_limit must be >= 1 and hard limits must be >= their
  warn pairs
- window_size, max_tracked_threads must be >= 1

Co-Authored-By: Claude Sonnet 4.6

* fix(test): isolate credential loader directory-path test from real ~/.claude

The test didn't monkeypatch HOME, so on any machine with real Claude Code
credentials at ~/.claude/.credentials.json the function fell through to
those credentials and the assertion failed. Adding HOME redirect ensures
the default credential path doesn't exist during the test.

Co-Authored-By: Claude Sonnet 4.6

* style(test): add blank lines after import pytest in TestInitValidation

Co-Authored-By: Claude Sonnet 4.6

* refactor(loop-detection): collapse dual validation to LoopDetectionConfig

Modifications
- LoopDetectionMiddleware.__init__: stripped of all ValueError raises;
  becomes a plain field-assignment constructor.
- LoopDetectionMiddleware.from_config: classmethod that builds the
  middleware from a Pydantic-validated LoopDetectionConfig and handles the
  ToolFreqOverride -> tuple[int, int] conversion.
- agents/factory.py: SDK construction routed through
  LoopDetectionMiddleware.from_config(LoopDetectionConfig()) so the
  defaults path is Pydantic-validated too.
- agents/lead_agent/agent.py: uses from_config instead of unpacking
  config fields by hand.
- tests/test_loop_detection_middleware.py: deleted TestInitValidation
  (16 methods exercising the removed __init__ checks); added
  TestFromConfig (4 tests: scalar field mapping, override tuple
  conversion, empty overrides, behavioral smoke test).

Result: one validation layer (Pydantic), zero duplication, no __new__
hacks. Both production construction sites flow through
LoopDetectionConfig.

Test results
  make test   -> 2977 passed, 18 skipped, 0 failed (137s)
  make format -> All checks passed; 411 files left unchanged

* feat(agents): make loop_detection configurable in create_deerflow_agent

Adds a `loop_detection: bool | AgentMiddleware = True` field to
RuntimeFeatures, mirroring the existing pattern used by `sandbox`,
`memory`, and `vision`. SDK users can now disable LoopDetectionMiddleware
or replace it with a custom instance built from their own
LoopDetectionConfig (e.g. `LoopDetectionMiddleware.from_config(my_cfg)`)
instead of being stuck with the hardcoded defaults previously installed
by the SDK factory.

The lead-agent path (which already reads AppConfig.loop_detection) is
unchanged, and the default `True` preserves prior always-on behavior for
all existing callers.

Co-Authored-By: Claude Opus 4.7

---------

Co-authored-by: knight0940 <631532668@qq.com>
Co-authored-by: Claude Opus 4.7
Co-authored-by: Amorend <142649913+knight0940@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Willem Jiang
---
 .../harness/deerflow/agents/factory.py        | 13 +++-
 .../harness/deerflow/agents/features.py       |  1 +
 .../deerflow/agents/lead_agent/agent.py       |  4 +-
 .../middlewares/loop_detection_middleware.py  | 44 +++++++++--
 .../harness/deerflow/config/__init__.py       |  2 +
 .../harness/deerflow/config/app_config.py     |  2 +
 .../deerflow/config/loop_detection_config.py  | 73 ++++++++++++++++++
 backend/tests/test_create_deerflow_agent.py   | 46 +++++++++++
 backend/tests/test_credential_loader.py       |  2 +
 .../tests/test_lead_agent_model_resolution.py | 58 +++++++++++++-
 backend/tests/test_loop_detection_config.py   | 72 ++++++++++++++++++
 .../tests/test_loop_detection_middleware.py   | 76 +++++++++++++++++++
 config.example.yaml                           | 25 +++++-
 13 files changed, 406
insertions(+), 12 deletions(-) create mode 100644 backend/packages/harness/deerflow/config/loop_detection_config.py create mode 100644 backend/tests/test_loop_detection_config.py diff --git a/backend/packages/harness/deerflow/agents/factory.py b/backend/packages/harness/deerflow/agents/factory.py index bd57d733d..e847023b9 100644 --- a/backend/packages/harness/deerflow/agents/factory.py +++ b/backend/packages/harness/deerflow/agents/factory.py @@ -173,7 +173,7 @@ def _assemble_from_features( 9. MemoryMiddleware (memory feature) 10. ViewImageMiddleware (vision feature) 11. SubagentLimitMiddleware (subagent feature) - 12. LoopDetectionMiddleware (always) + 12. LoopDetectionMiddleware (loop_detection feature) 13. ClarificationMiddleware (always last) Two-phase ordering: @@ -272,10 +272,15 @@ def _assemble_from_features( extra_tools.append(task_tool) - # --- [12] LoopDetection (always) --- - from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware + # --- [12] LoopDetection --- + if feat.loop_detection is not False: + if isinstance(feat.loop_detection, AgentMiddleware): + chain.append(feat.loop_detection) + else: + from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware + from deerflow.config.loop_detection_config import LoopDetectionConfig - chain.append(LoopDetectionMiddleware()) + chain.append(LoopDetectionMiddleware.from_config(LoopDetectionConfig())) # --- [13] Clarification (always last among built-ins) --- chain.append(ClarificationMiddleware()) diff --git a/backend/packages/harness/deerflow/agents/features.py b/backend/packages/harness/deerflow/agents/features.py index 0fc485a3d..db2bc9580 100644 --- a/backend/packages/harness/deerflow/agents/features.py +++ b/backend/packages/harness/deerflow/agents/features.py @@ -31,6 +31,7 @@ class RuntimeFeatures: vision: bool | AgentMiddleware = False auto_title: bool | AgentMiddleware = False guardrail: Literal[False] | AgentMiddleware = False + 
loop_detection: bool | AgentMiddleware = True # --------------------------------------------------------------------------- diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py index 5704b47cf..7540cdcdc 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py @@ -299,7 +299,9 @@ def _build_middlewares( middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents)) # LoopDetectionMiddleware — detect and break repetitive tool call loops - middlewares.append(LoopDetectionMiddleware()) + loop_detection_config = resolved_app_config.loop_detection + if loop_detection_config.enabled: + middlewares.append(LoopDetectionMiddleware.from_config(loop_detection_config)) # Inject custom middlewares before ClarificationMiddleware if custom_middlewares: diff --git a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py index 0c17d95c4..db83051e9 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py @@ -12,18 +12,23 @@ Detection strategy: response so the agent is forced to produce a final text answer. 
""" +from __future__ import annotations + import hashlib import json import logging import threading from collections import OrderedDict, defaultdict from copy import deepcopy -from typing import override +from typing import TYPE_CHECKING, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langgraph.runtime import Runtime +if TYPE_CHECKING: + from deerflow.config.loop_detection_config import LoopDetectionConfig + logger = logging.getLogger(__name__) # Defaults — can be overridden via constructor @@ -139,6 +144,9 @@ _TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times class LoopDetectionMiddleware(AgentMiddleware[AgentState]): """Detects and breaks repetitive tool call loops. + Threshold parameters are validated upstream by :class:`LoopDetectionConfig`; + construct via :meth:`from_config` to ensure values pass Pydantic validation. + Args: warn_threshold: Number of identical tool call sets before injecting a warning message. Default: 3. @@ -154,6 +162,14 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): Default: 30. tool_freq_hard_limit: Number of calls to the same tool type before forcing a stop. Default: 50. + tool_freq_overrides: Per-tool overrides for frequency thresholds, + keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple + that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for + that specific tool. Tools not listed here fall back to the global + thresholds. Useful for raising limits on intentionally + high-frequency tools (e.g. ``bash`` in batch pipelines) without + weakening protection on all other tools. Default: ``None`` + (no overrides). 
""" def __init__( @@ -164,6 +180,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS, tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN, tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT, + tool_freq_overrides: dict[str, tuple[int, int]] | None = None, ): super().__init__() self.warn_threshold = warn_threshold @@ -172,14 +189,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): self.max_tracked_threads = max_tracked_threads self.tool_freq_warn = tool_freq_warn self.tool_freq_hard_limit = tool_freq_hard_limit + self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {} self._lock = threading.Lock() - # Per-thread tracking using OrderedDict for LRU eviction self._history: OrderedDict[str, list[str]] = OrderedDict() self._warned: dict[str, set[str]] = defaultdict(set) - # Per-thread, per-tool-type cumulative call counts self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) self._tool_freq_warned: dict[str, set[str]] = defaultdict(set) + @classmethod + def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware: + """Construct from a Pydantic-validated config, trusting its validation.""" + return cls( + warn_threshold=config.warn_threshold, + hard_limit=config.hard_limit, + window_size=config.window_size, + max_tracked_threads=config.max_tracked_threads, + tool_freq_warn=config.tool_freq_warn, + tool_freq_hard_limit=config.tool_freq_hard_limit, + tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()}, + ) + def _get_thread_id(self, runtime: Runtime) -> str: """Extract thread_id from runtime context for per-thread tracking.""" thread_id = runtime.context.get("thread_id") if runtime.context else None @@ -279,7 +308,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): freq[name] += 1 tc_count = freq[name] - if tc_count >= self.tool_freq_hard_limit: + if name in 
self._tool_freq_overrides: + eff_warn, eff_hard = self._tool_freq_overrides[name] + else: + eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit + + if tc_count >= eff_hard: logger.error( "Tool frequency hard limit reached — forcing stop", extra={ @@ -290,7 +324,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): ) return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True - if tc_count >= self.tool_freq_warn: + if tc_count >= eff_warn: warned = self._tool_freq_warned[thread_id] if name not in warned: warned.add(name) diff --git a/backend/packages/harness/deerflow/config/__init__.py b/backend/packages/harness/deerflow/config/__init__.py index 2e1ee82f8..bf74dc5e8 100644 --- a/backend/packages/harness/deerflow/config/__init__.py +++ b/backend/packages/harness/deerflow/config/__init__.py @@ -1,5 +1,6 @@ from .app_config import get_app_config from .extensions_config import ExtensionsConfig, get_extensions_config +from .loop_detection_config import LoopDetectionConfig from .memory_config import MemoryConfig, get_memory_config from .paths import Paths, get_paths from .skill_evolution_config import SkillEvolutionConfig @@ -20,6 +21,7 @@ __all__ = [ "SkillsConfig", "ExtensionsConfig", "get_extensions_config", + "LoopDetectionConfig", "MemoryConfig", "get_memory_config", "get_tracing_config", diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index 7fe218919..d470d6558 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -15,6 +15,7 @@ from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpo from deerflow.config.database_config import DatabaseConfig from deerflow.config.extensions_config import ExtensionsConfig from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict +from deerflow.config.loop_detection_config 
import LoopDetectionConfig from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict from deerflow.config.model_config import ModelConfig from deerflow.config.run_events_config import RunEventsConfig @@ -100,6 +101,7 @@ class AppConfig(BaseModel): subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration") + loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration") model_config = ConfigDict(extra="allow") database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration") run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration") diff --git a/backend/packages/harness/deerflow/config/loop_detection_config.py b/backend/packages/harness/deerflow/config/loop_detection_config.py new file mode 100644 index 000000000..3eb9a2f37 --- /dev/null +++ b/backend/packages/harness/deerflow/config/loop_detection_config.py @@ -0,0 +1,73 @@ +"""Configuration for loop detection middleware.""" + +from pydantic import BaseModel, Field, model_validator + + +class ToolFreqOverride(BaseModel): + """Per-tool frequency threshold override. + + Can be higher or lower than the global defaults. Commonly used to raise + thresholds for high-frequency tools like bash in batch workflows (e.g. + RNA-seq pipelines) without weakening protection on every other tool. 
+ """ + + warn: int = Field(ge=1) + hard_limit: int = Field(ge=1) + + @model_validator(mode="after") + def _validate(self) -> "ToolFreqOverride": + if self.hard_limit < self.warn: + raise ValueError("hard_limit must be >= warn") + return self + + +class LoopDetectionConfig(BaseModel): + """Configuration for repetitive tool-call loop detection.""" + + enabled: bool = Field( + default=True, + description="Whether to enable repetitive tool-call loop detection", + ) + warn_threshold: int = Field( + default=3, + ge=1, + description="Number of identical tool-call sets before injecting a warning", + ) + hard_limit: int = Field( + default=5, + ge=1, + description="Number of identical tool-call sets before forcing a stop", + ) + window_size: int = Field( + default=20, + ge=1, + description="Number of recent tool-call sets to track per thread", + ) + max_tracked_threads: int = Field( + default=100, + ge=1, + description="Maximum number of thread histories to keep in memory", + ) + tool_freq_warn: int = Field( + default=30, + ge=1, + description="Number of calls to the same tool type before injecting a frequency warning", + ) + tool_freq_hard_limit: int = Field( + default=50, + ge=1, + description="Number of calls to the same tool type before forcing a stop", + ) + tool_freq_overrides: dict[str, ToolFreqOverride] = Field( + default_factory=dict, + description=("Per-tool overrides for tool_freq_warn / tool_freq_hard_limit, keyed by tool name. Values can be higher or lower than the global defaults. 
Commonly used to raise thresholds for high-frequency tools like bash."), + ) + + @model_validator(mode="after") + def validate_thresholds(self) -> "LoopDetectionConfig": + """Ensure hard stop cannot happen before the warning threshold.""" + if self.hard_limit < self.warn_threshold: + raise ValueError("hard_limit must be greater than or equal to warn_threshold") + if self.tool_freq_hard_limit < self.tool_freq_warn: + raise ValueError("tool_freq_hard_limit must be greater than or equal to tool_freq_warn") + return self diff --git a/backend/tests/test_create_deerflow_agent.py b/backend/tests/test_create_deerflow_agent.py index fb403ed7f..d0fc6c9a8 100644 --- a/backend/tests/test_create_deerflow_agent.py +++ b/backend/tests/test_create_deerflow_agent.py @@ -192,6 +192,7 @@ def test_agent_features_defaults(): assert f.vision is False assert f.auto_title is False assert f.guardrail is False + assert f.loop_detection is True # --------------------------------------------------------------------------- @@ -630,6 +631,51 @@ def test_loop_detection_before_clarification(mock_create_agent): assert loop_idx == clar_idx - 1 +# --------------------------------------------------------------------------- +# 30b. loop_detection=False skips LoopDetectionMiddleware +# --------------------------------------------------------------------------- +@patch("deerflow.agents.factory.create_agent") +def test_loop_detection_disabled(mock_create_agent): + mock_create_agent.return_value = MagicMock() + create_deerflow_agent( + _make_mock_model(), + features=RuntimeFeatures(sandbox=False, loop_detection=False), + ) + + call_kwargs = mock_create_agent.call_args[1] + mw_types = [type(m).__name__ for m in call_kwargs["middleware"]] + assert "LoopDetectionMiddleware" not in mw_types + + +# --------------------------------------------------------------------------- +# 30c. 
loop_detection= replaces the default +# --------------------------------------------------------------------------- +@patch("deerflow.agents.factory.create_agent") +def test_loop_detection_custom_middleware(mock_create_agent): + from langchain.agents.middleware import AgentMiddleware as AM + + mock_create_agent.return_value = MagicMock() + + class MyLoopDetection(AM): + pass + + custom = MyLoopDetection() + create_deerflow_agent( + _make_mock_model(), + features=RuntimeFeatures(sandbox=False, loop_detection=custom), + ) + + call_kwargs = mock_create_agent.call_args[1] + middleware = call_kwargs["middleware"] + assert custom in middleware + mw_types = [type(m).__name__ for m in middleware] + # Default LoopDetectionMiddleware must not also appear. + assert "LoopDetectionMiddleware" not in mw_types + # Custom replacement still sits immediately before ClarificationMiddleware. + assert mw_types[-1] == "ClarificationMiddleware" + assert mw_types[-2] == "MyLoopDetection" + + # --------------------------------------------------------------------------- # 31. 
plan_mode=True adds TodoMiddleware # --------------------------------------------------------------------------- diff --git a/backend/tests/test_credential_loader.py b/backend/tests/test_credential_loader.py index 3c2bb1d94..c56ac6af3 100644 --- a/backend/tests/test_credential_loader.py +++ b/backend/tests/test_credential_loader.py @@ -85,6 +85,8 @@ def test_load_claude_code_credential_from_override_path(tmp_path, monkeypatch): def test_load_claude_code_credential_ignores_directory_path(tmp_path, monkeypatch): _clear_claude_code_env(monkeypatch) + # Redirect HOME so the default ~/.claude/.credentials.json doesn't exist + monkeypatch.setenv("HOME", str(tmp_path)) cred_dir = tmp_path / "claude-creds-dir" cred_dir.mkdir() monkeypatch.setenv("CLAUDE_CODE_CREDENTIALS_PATH", str(cred_dir)) diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py index b240116cd..976730d44 100644 --- a/backend/tests/test_lead_agent_model_resolution.py +++ b/backend/tests/test_lead_agent_model_resolution.py @@ -8,17 +8,20 @@ from unittest.mock import MagicMock import pytest from deerflow.agents.lead_agent import agent as lead_agent_module +from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware from deerflow.config.app_config import AppConfig +from deerflow.config.loop_detection_config import LoopDetectionConfig from deerflow.config.memory_config import MemoryConfig from deerflow.config.model_config import ModelConfig from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.summarization_config import SummarizationConfig -def _make_app_config(models: list[ModelConfig]) -> AppConfig: +def _make_app_config(models: list[ModelConfig], loop_detection: LoopDetectionConfig | None = None) -> AppConfig: return AppConfig( models=models, sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider"), + loop_detection=loop_detection or LoopDetectionConfig(), ) @@ -340,6 +343,59 @@ 
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypa assert middlewares[0] == "base-middleware" +def test_build_middlewares_uses_loop_detection_config(monkeypatch): + app_config = _make_app_config( + [_make_model("safe-model", supports_thinking=False)], + loop_detection=LoopDetectionConfig( + warn_threshold=7, + hard_limit=9, + window_size=30, + max_tracked_threads=40, + tool_freq_warn=50, + tool_freq_hard_limit=60, + ), + ) + + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) + monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: []) + monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None) + monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None) + + middlewares = lead_agent_module._build_middlewares( + {"configurable": {"is_plan_mode": False, "subagent_enabled": False}}, + model_name="safe-model", + app_config=app_config, + ) + + loop_detection = next(m for m in middlewares if isinstance(m, LoopDetectionMiddleware)) + assert loop_detection.warn_threshold == 7 + assert loop_detection.hard_limit == 9 + assert loop_detection.window_size == 30 + assert loop_detection.max_tracked_threads == 40 + assert loop_detection.tool_freq_warn == 50 + assert loop_detection.tool_freq_hard_limit == 60 + + +def test_build_middlewares_omits_loop_detection_when_disabled(monkeypatch): + app_config = _make_app_config( + [_make_model("safe-model", supports_thinking=False)], + loop_detection=LoopDetectionConfig(enabled=False), + ) + + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) + monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: []) + monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None) + monkeypatch.setattr(lead_agent_module, 
"_create_todo_list_middleware", lambda is_plan_mode: None) + + middlewares = lead_agent_module._build_middlewares( + {"configurable": {"is_plan_mode": False, "subagent_enabled": False}}, + model_name="safe-model", + app_config=app_config, + ) + + assert not any(isinstance(m, LoopDetectionMiddleware) for m in middlewares) + + def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch): app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)]) app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork") diff --git a/backend/tests/test_loop_detection_config.py b/backend/tests/test_loop_detection_config.py new file mode 100644 index 000000000..93bc6ac3e --- /dev/null +++ b/backend/tests/test_loop_detection_config.py @@ -0,0 +1,72 @@ +"""Tests for loop detection configuration.""" + +import pytest + +from deerflow.config.loop_detection_config import LoopDetectionConfig + + +class TestLoopDetectionConfig: + def test_defaults_match_middleware_defaults(self): + config = LoopDetectionConfig() + + assert config.enabled is True + assert config.warn_threshold == 3 + assert config.hard_limit == 5 + assert config.window_size == 20 + assert config.max_tracked_threads == 100 + assert config.tool_freq_warn == 30 + assert config.tool_freq_hard_limit == 50 + + def test_accepts_custom_values(self): + config = LoopDetectionConfig( + enabled=False, + warn_threshold=10, + hard_limit=20, + window_size=50, + max_tracked_threads=200, + tool_freq_warn=60, + tool_freq_hard_limit=80, + ) + + assert config.enabled is False + assert config.warn_threshold == 10 + assert config.hard_limit == 20 + assert config.window_size == 50 + assert config.max_tracked_threads == 200 + assert config.tool_freq_warn == 60 + assert config.tool_freq_hard_limit == 80 + + def test_rejects_zero_thresholds(self): + with pytest.raises(ValueError): + LoopDetectionConfig(warn_threshold=0) + + with pytest.raises(ValueError): + 
LoopDetectionConfig(hard_limit=0) + + with pytest.raises(ValueError): + LoopDetectionConfig(tool_freq_warn=0) + + with pytest.raises(ValueError): + LoopDetectionConfig(tool_freq_hard_limit=0) + + def test_rejects_hard_limit_below_warn_threshold(self): + with pytest.raises(ValueError, match="hard_limit"): + LoopDetectionConfig(warn_threshold=5, hard_limit=4) + + def test_rejects_tool_freq_hard_limit_below_warn_threshold(self): + with pytest.raises(ValueError, match="tool_freq_hard_limit"): + LoopDetectionConfig(tool_freq_warn=5, tool_freq_hard_limit=4) + + def test_tool_freq_override_valid(self): + config = LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 150, "hard_limit": 300}}) + override = config.tool_freq_overrides["bash"] + assert override.warn == 150 + assert override.hard_limit == 300 + + def test_tool_freq_override_rejects_zero_warn(self): + with pytest.raises(ValueError): + LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 0, "hard_limit": 10}}) + + def test_tool_freq_override_rejects_hard_limit_below_warn(self): + with pytest.raises(ValueError, match="hard_limit"): + LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 100, "hard_limit": 50}}) diff --git a/backend/tests/test_loop_detection_middleware.py b/backend/tests/test_loop_detection_middleware.py index 878a433dd..022afc117 100644 --- a/backend/tests/test_loop_detection_middleware.py +++ b/backend/tests/test_loop_detection_middleware.py @@ -648,6 +648,37 @@ class TestToolFrequencyDetection: assert result is not None assert "read_file" in result["messages"][0].content + def test_override_tool_uses_override_thresholds(self): + """A tool in tool_freq_overrides uses its own thresholds, not the global ones.""" + mw = LoopDetectionMiddleware( + tool_freq_warn=5, + tool_freq_hard_limit=10, + tool_freq_overrides={"bash": (50, 100)}, + ) + runtime = _make_runtime() + + # 10 bash calls — would hit global hard_limit=10, but bash override is 100 + for i in range(10): + result = 
mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime) + assert result is None, f"unexpected trigger on call {i + 1}" + + def test_non_override_tool_falls_back_to_global(self): + """A tool NOT in tool_freq_overrides uses the global warn/hard_limit.""" + mw = LoopDetectionMiddleware( + tool_freq_warn=3, + tool_freq_hard_limit=6, + tool_freq_overrides={"bash": (50, 100)}, + ) + runtime = _make_runtime() + + for i in range(2): + mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) + + # 3rd read_file call hits global warn=3 (read_file has no override) + result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) + assert result is not None + assert "read_file" in result["messages"][0].content + def test_hash_detection_takes_priority(self): """Hash-based hard stop fires before frequency check for identical calls.""" mw = LoopDetectionMiddleware( @@ -668,3 +699,48 @@ class TestToolFrequencyDetection: msg = result["messages"][0] assert isinstance(msg, AIMessage) assert _HARD_STOP_MSG in msg.content + + +class TestFromConfig: + """Tests for LoopDetectionMiddleware.from_config — the sole validated construction path.""" + + @staticmethod + def _config(**kwargs): + from deerflow.config.loop_detection_config import LoopDetectionConfig + + return LoopDetectionConfig(**kwargs) + + def test_scalar_fields_mapped(self): + config = self._config( + warn_threshold=4, + hard_limit=8, + window_size=15, + max_tracked_threads=50, + tool_freq_warn=20, + tool_freq_hard_limit=40, + ) + mw = LoopDetectionMiddleware.from_config(config) + assert mw.warn_threshold == 4 + assert mw.hard_limit == 8 + assert mw.window_size == 15 + assert mw.max_tracked_threads == 50 + assert mw.tool_freq_warn == 20 + assert mw.tool_freq_hard_limit == 40 + + def test_overrides_converted_to_tuples(self): + config = self._config(tool_freq_overrides={"bash": {"warn": 50, "hard_limit": 100}}) + mw = LoopDetectionMiddleware.from_config(config) + assert 
mw._tool_freq_overrides == {"bash": (50, 100)} + + def test_empty_overrides(self): + mw = LoopDetectionMiddleware.from_config(self._config()) + assert mw._tool_freq_overrides == {} + + def test_constructed_middleware_detects_loops(self): + mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4)) + runtime = _make_runtime() + call = [_bash_call("ls")] + mw._apply(_make_state(tool_calls=call), runtime) + result = mw._apply(_make_state(tool_calls=call), runtime) + assert result is not None + assert "LOOP DETECTED" in result["messages"][0].content diff --git a/config.example.yaml b/config.example.yaml index 7e282e46e..6f3fb1483 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -15,7 +15,7 @@ # ============================================================================ # Bump this number when the config schema changes. # Run `make config-upgrade` to merge new fields into your local config.yaml. -config_version: 8 +config_version: 9 # ============================================================================ # Logging @@ -506,6 +506,29 @@ tools: tool_search: enabled: false +# ============================================================================ +# Loop Detection Configuration +# ============================================================================ +# Detect and interrupt repeated identical tool-call loops. +# Frequency thresholds are safety limits for repeated use of the same tool type. + +loop_detection: + enabled: true + warn_threshold: 3 + hard_limit: 5 + window_size: 20 + max_tracked_threads: 100 + tool_freq_warn: 30 + tool_freq_hard_limit: 50 + # Per-tool overrides for tool_freq_warn / tool_freq_hard_limit. Values can be + # higher or lower than the global defaults. Commonly used to raise thresholds + # for high-frequency tools like bash in batch workflows (e.g. RNA-seq pipelines) + # without weakening protection on every other tool. 
+ # tool_freq_overrides: + # bash: + # warn: 150 + # hard_limit: 300 + # ============================================================================ # Sandbox Configuration # ============================================================================
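
Reviewer note: the per-tool threshold resolution added to the frequency check can be illustrated outside the middleware. The snippet below is only a minimal standalone sketch, not code from this patch: `effective_thresholds` and `classify_calls` are hypothetical helper names, and the sketch assumes the same "override if the tool is listed, otherwise global" rule and the same warn-once-per-tool behavior that the diff shows.

```python
from collections import defaultdict


def effective_thresholds(name, global_warn, global_hard, overrides):
    """Return the (warn, hard_limit) pair that applies to one tool name."""
    # A listed tool uses its own pair; every other tool uses the globals.
    if name in overrides:
        return overrides[name]
    return global_warn, global_hard


def classify_calls(tool_names, global_warn=30, global_hard=50, overrides=None):
    """Replay a sequence of tool names and collect warn/stop events.

    Hypothetical helper: it mirrors the order of checks in the patch
    (hard limit first, then a one-time warning per tool).
    """
    overrides = overrides or {}
    freq = defaultdict(int)  # cumulative call count per tool name
    warned = set()           # tools that already received a warning
    events = []
    for name in tool_names:
        freq[name] += 1
        count = freq[name]
        warn, hard = effective_thresholds(name, global_warn, global_hard, overrides)
        if count >= hard:
            events.append(("stop", name, count))
            break  # a hard stop ends the run
        if count >= warn and name not in warned:
            warned.add(name)
            events.append(("warn", name, count))
    return events


if __name__ == "__main__":
    # bash is raised to (50, 100); read_file still uses the global 3 / 6.
    print(classify_calls(["bash"] * 10, 5, 10, {"bash": (50, 100)}))
    print(classify_calls(["read_file"] * 10, 3, 6, {"bash": (50, 100)}))
```

With these assumed inputs the bash run produces no events even though it reaches the global hard limit, while the read_file run warns on its third call and stops on its sixth, which is the same separation the two new override tests check.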