mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-26 03:38:06 +00:00
* fix: inject longTermBackground into memory prompt
The format_memory_for_injection function only processed recentMonths and
earlierContext from the history section, silently dropping longTermBackground.
The LLM writes longTermBackground correctly and it persists to memory.json,
but it was never injected into the system prompt — making the user's
long-term background invisible to the AI.
Add the missing field handling and a regression test.
* fix(middleware): handle list-type AIMessage.content in LoopDetectionMiddleware
LangChain AIMessage.content can be str | list. When using providers that
return structured content blocks (e.g. Anthropic thinking mode, certain
OpenAI-compatible gateways), content is a list of dicts like
[{"type": "text", "text": "..."}].
The hard_limit branch in _apply() concatenated content with a string via
(last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}", which raises
TypeError when content is a non-empty list (list + str is invalid).
Add _append_text() static method that:
- Returns the text directly when content is None
- Appends a {"type": "text"} block when content is a list
- Falls back to string concatenation when content is a str
This is consistent with how other modules in the project already handle
list content (client.py._extract_text, memory_middleware, executor.py).
* test(middleware): add unit tests for _append_text and list content hard stop
Add regression tests to verify LoopDetectionMiddleware handles list-type
AIMessage.content correctly during hard stop:
- TestAppendText: unit tests for the new _append_text() static method
covering None, str, list (including empty list) content types
- TestHardStopWithListContent: integration tests verifying hard stop
works correctly with list content (Anthropic thinking mode), None
content, and str content
Requested by reviewer in PR #1823.
* fix(middleware): improve _append_text robustness and test isolation
- Add explicit isinstance(content, str) check with fallback for
unexpected types (coerce to str) to prevent TypeError on edge cases
- Deep-copy list content in _make_state() test helper to prevent
shared mutable references across test iterations
- Add test_unexpected_type_coerced_to_str: verify fallback for
non-str/list/None content types
- Add test_list_content_not_mutated_in_place: verify _append_text
does not modify the original list
* style: fix ruff format whitespace in test file
---------
Co-authored-by: ppyt <14163465+ppyt@users.noreply.github.com>
349 lines · 13 KiB · Python
"""Tests for LoopDetectionMiddleware."""
|
|
|
|
import copy
|
|
from unittest.mock import MagicMock
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
|
|
from deerflow.agents.middlewares.loop_detection_middleware import (
|
|
_HARD_STOP_MSG,
|
|
LoopDetectionMiddleware,
|
|
_hash_tool_calls,
|
|
)
|
|
|
|
|
|
def _make_runtime(thread_id="test-thread"):
|
|
"""Build a minimal Runtime mock with context."""
|
|
runtime = MagicMock()
|
|
runtime.context = {"thread_id": thread_id}
|
|
return runtime
|
|
|
|
|
|
def _make_state(tool_calls=None, content=""):
    """Wrap a single AIMessage in a minimal AgentState-shaped dict.

    List *content* is deep-copied before use. Callers that reuse one list
    across several calls therefore never hand the same mutable object to two
    messages. Non-list content is passed through as-is.
    """
    if isinstance(content, list):
        content = copy.deepcopy(content)
    message = AIMessage(content=content, tool_calls=tool_calls if tool_calls else [])
    return {"messages": [message]}
|
|
|
|
|
|
def _bash_call(cmd="ls"):
|
|
return {"name": "bash", "id": f"call_{cmd}", "args": {"command": cmd}}
|
|
|
|
|
|
class TestHashToolCalls:
    """Behaviour of the ``_hash_tool_calls`` fingerprinting helper."""

    def test_same_calls_same_hash(self):
        first, second = (_hash_tool_calls([_bash_call("ls")]) for _ in range(2))
        assert first == second

    def test_different_calls_different_hash(self):
        ls_digest = _hash_tool_calls([_bash_call("ls")])
        pwd_digest = _hash_tool_calls([_bash_call("pwd")])
        assert ls_digest != pwd_digest

    def test_order_independent(self):
        # Same two calls, supplied in opposite orders, must fingerprint identically.
        forward = _hash_tool_calls([_bash_call("ls"), {"name": "read_file", "args": {"path": "/tmp"}}])
        backward = _hash_tool_calls([{"name": "read_file", "args": {"path": "/tmp"}}, _bash_call("ls")])
        assert forward == backward

    def test_empty_calls(self):
        digest = _hash_tool_calls([])
        assert isinstance(digest, str)
        # A non-empty string is truthy, matching len(digest) > 0.
        assert digest
|
|
|
|
|
|
class TestLoopDetection:
    """Behaviour of LoopDetectionMiddleware._apply: warnings, hard stop, per-thread tracking."""

    def test_no_tool_calls_returns_none(self):
        mw = LoopDetectionMiddleware()
        runtime = _make_runtime()
        state = {"messages": [AIMessage(content="hello")]}
        result = mw._apply(state, runtime)
        assert result is None

    def test_below_threshold_returns_none(self):
        mw = LoopDetectionMiddleware(warn_threshold=3)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        # First two identical calls — no warning
        for _ in range(2):
            result = mw._apply(_make_state(tool_calls=call), runtime)
            assert result is None

    def test_warn_at_threshold(self):
        mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=5)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        for _ in range(2):
            mw._apply(_make_state(tool_calls=call), runtime)

        # Third identical call triggers warning
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is not None
        msgs = result["messages"]
        assert len(msgs) == 1
        assert isinstance(msgs[0], HumanMessage)
        assert "LOOP DETECTED" in msgs[0].content

    def test_warn_only_injected_once(self):
        """Warning for the same hash should only be injected once per thread."""
        mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        # First two — no warning
        for _ in range(2):
            mw._apply(_make_state(tool_calls=call), runtime)

        # Third — warning injected
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is not None
        assert "LOOP DETECTED" in result["messages"][0].content

        # Fourth — warning already injected, should return None
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is None

    def test_hard_stop_at_limit(self):
        mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        for _ in range(3):
            mw._apply(_make_state(tool_calls=call), runtime)

        # Fourth call triggers hard stop
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is not None
        msgs = result["messages"]
        assert len(msgs) == 1
        # Hard stop strips tool_calls
        assert isinstance(msgs[0], AIMessage)
        assert msgs[0].tool_calls == []
        assert _HARD_STOP_MSG in msgs[0].content

    def test_different_calls_dont_trigger(self):
        mw = LoopDetectionMiddleware(warn_threshold=2)
        runtime = _make_runtime()

        # Each call is different
        for i in range(10):
            result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
            assert result is None

    def test_window_sliding(self):
        mw = LoopDetectionMiddleware(warn_threshold=3, window_size=5)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        # Fill with 2 identical calls
        mw._apply(_make_state(tool_calls=call), runtime)
        mw._apply(_make_state(tool_calls=call), runtime)

        # Push them out of the window with different calls
        for i in range(5):
            mw._apply(_make_state(tool_calls=[_bash_call(f"other_{i}")]), runtime)

        # Now the original call should be fresh again — no warning
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is None

    def test_reset_clears_state(self):
        mw = LoopDetectionMiddleware(warn_threshold=2)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        mw._apply(_make_state(tool_calls=call), runtime)
        mw._apply(_make_state(tool_calls=call), runtime)

        # Would trigger warning, but reset first
        mw.reset()
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is None

    def test_non_ai_message_ignored(self):
        mw = LoopDetectionMiddleware()
        runtime = _make_runtime()
        state = {"messages": [SystemMessage(content="hello")]}
        result = mw._apply(state, runtime)
        assert result is None

    def test_empty_messages_ignored(self):
        mw = LoopDetectionMiddleware()
        runtime = _make_runtime()
        result = mw._apply({"messages": []}, runtime)
        assert result is None

    def test_thread_id_from_runtime_context(self):
        """Thread ID should come from runtime.context, not state."""
        mw = LoopDetectionMiddleware(warn_threshold=2)
        runtime_a = _make_runtime("thread-A")
        runtime_b = _make_runtime("thread-B")
        call = [_bash_call("ls")]

        # One call on thread A
        mw._apply(_make_state(tool_calls=call), runtime_a)
        # One call on thread B
        mw._apply(_make_state(tool_calls=call), runtime_b)

        # Second call on thread A — triggers warning (2 >= warn_threshold)
        result = mw._apply(_make_state(tool_calls=call), runtime_a)
        assert result is not None
        assert "LOOP DETECTED" in result["messages"][0].content

        # Second call on thread B — also triggers (independent tracking)
        result = mw._apply(_make_state(tool_calls=call), runtime_b)
        assert result is not None
        assert "LOOP DETECTED" in result["messages"][0].content

    def test_lru_eviction(self):
        """Old threads should be evicted when max_tracked_threads is exceeded."""
        mw = LoopDetectionMiddleware(warn_threshold=2, max_tracked_threads=3)
        call = [_bash_call("ls")]

        # Fill up 3 threads
        for i in range(3):
            runtime = _make_runtime(f"thread-{i}")
            mw._apply(_make_state(tool_calls=call), runtime)

        # Add a 4th thread — should evict thread-0
        runtime_new = _make_runtime("thread-new")
        mw._apply(_make_state(tool_calls=call), runtime_new)

        assert "thread-0" not in mw._history
        assert "thread-new" in mw._history
        assert len(mw._history) == 3

    def test_thread_safe_mutations(self):
        """The middleware owns a usable lock and does not leave it held after _apply."""
        mw = LoopDetectionMiddleware()
        # The middleware should have a lock attribute
        assert hasattr(mw, "_lock")
        lock = mw._lock
        # ``isinstance(lock, type(lock))`` is always true, so exercise the lock
        # protocol directly instead.
        assert callable(getattr(lock, "acquire", None))
        assert callable(getattr(lock, "release", None))

        # After a mutating call the lock must be free again; a non-blocking
        # acquire would fail if _apply leaked a held lock.
        mw._apply(_make_state(tool_calls=[_bash_call("ls")]), _make_runtime())
        assert lock.acquire(blocking=False)
        lock.release()

    def test_fallback_thread_id_when_missing(self):
        """When runtime context has no thread_id, should use 'default'."""
        mw = LoopDetectionMiddleware(warn_threshold=2)
        runtime = MagicMock()
        runtime.context = {}
        call = [_bash_call("ls")]

        mw._apply(_make_state(tool_calls=call), runtime)
        assert "default" in mw._history
|
|
|
|
|
|
class TestAppendText:
    """Unit tests for LoopDetectionMiddleware._append_text."""

    def test_none_content_returns_text(self):
        result = LoopDetectionMiddleware._append_text(None, "hello")
        assert result == "hello"

    def test_str_content_concatenates(self):
        result = LoopDetectionMiddleware._append_text("existing", "appended")
        assert result == "existing\n\nappended"

    def test_empty_str_content_concatenates(self):
        result = LoopDetectionMiddleware._append_text("", "appended")
        assert result == "\n\nappended"

    def test_list_content_appends_text_block(self):
        """List content (e.g. Anthropic thinking mode) should get a new text block."""
        content = [
            {"type": "thinking", "text": "Let me think..."},
            {"type": "text", "text": "Here is my answer"},
        ]
        result = LoopDetectionMiddleware._append_text(content, "stop msg")
        assert isinstance(result, list)
        assert len(result) == 3
        assert result[0] == content[0]
        assert result[1] == content[1]
        assert result[2] == {"type": "text", "text": "\n\nstop msg"}

    def test_empty_list_content_appends_text_block(self):
        result = LoopDetectionMiddleware._append_text([], "stop msg")
        assert isinstance(result, list)
        assert len(result) == 1
        assert result[0] == {"type": "text", "text": "\n\nstop msg"}

    def test_unexpected_type_coerced_to_str(self):
        """Unexpected content types should be coerced to str as a fallback."""
        result = LoopDetectionMiddleware._append_text(42, "stop msg")
        assert isinstance(result, str)
        assert result == "42\n\nstop msg"

    def test_list_content_not_mutated_in_place(self):
        """_append_text must not modify the original list or the blocks inside it."""
        original = [{"type": "text", "text": "hello"}]
        result = LoopDetectionMiddleware._append_text(original, "appended")
        # Compare full contents, not just length: an in-place edit of the
        # existing block's text would keep len(original) == 1.
        assert original == [{"type": "text", "text": "hello"}]
        assert result is not original
        assert result == [
            {"type": "text", "text": "hello"},
            {"type": "text", "text": "\n\nappended"},
        ]
|
|
|
|
|
|
class TestHardStopWithListContent:
    """Regression tests: hard stop must not crash when AIMessage.content is a list."""

    def test_hard_stop_with_list_content(self):
        """Hard stop on list content should not raise TypeError (regression)."""
        mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        # Build state with list content (e.g. Anthropic thinking mode)
        list_content = [
            {"type": "thinking", "text": "Let me think..."},
            {"type": "text", "text": "I'll run ls"},
        ]

        for _ in range(3):
            mw._apply(_make_state(tool_calls=call, content=list_content), runtime)

        # Fourth call triggers hard stop — must not raise TypeError
        result = mw._apply(_make_state(tool_calls=call, content=list_content), runtime)
        assert result is not None
        msg = result["messages"][0]
        assert isinstance(msg, AIMessage)
        assert msg.tool_calls == []
        # Content should remain a list with the stop message appended
        assert isinstance(msg.content, list)
        assert len(msg.content) == 3
        # The original blocks must survive unchanged and in their original order.
        assert msg.content[:2] == list_content
        assert msg.content[2]["type"] == "text"
        assert _HARD_STOP_MSG in msg.content[2]["text"]

    def test_hard_stop_with_none_content(self):
        """Hard stop on falsy (empty-string) content should produce a plain string.

        ``_make_state`` defaults ``content`` to ``""``, so this exercises the
        falsy-content path through ``_apply``. A literal ``None`` passed to
        ``_append_text`` is covered by ``TestAppendText.test_none_content_returns_text``.
        """
        mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        for _ in range(3):
            mw._apply(_make_state(tool_calls=call), runtime)

        # Fourth call with default empty-string content
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is not None
        msg = result["messages"][0]
        assert msg.tool_calls == []
        assert isinstance(msg.content, str)
        assert _HARD_STOP_MSG in msg.content

    def test_hard_stop_with_str_content(self):
        """Hard stop on str content should concatenate the stop message."""
        mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        for _ in range(3):
            mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)

        result = mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)
        assert result is not None
        msg = result["messages"][0]
        assert msg.tool_calls == []
        assert isinstance(msg.content, str)
        assert msg.content.startswith("thinking...")
        assert _HARD_STOP_MSG in msg.content
|