fix(memory): case-insensitive fact deduplication and positive reinforcement detection (#1804)

* fix(memory): case-insensitive fact deduplication and positive reinforcement detection

Two fixes to the memory system:

1. _fact_content_key() now lowercases content before comparison, preventing
   semantically duplicate facts like "User prefers Python" and "user prefers
   python" from being stored separately.

2. Adds detect_reinforcement() to MemoryMiddleware (closes #1719), mirroring
   detect_correction(). When users signal approval ("yes exactly", "perfect",
   "完全正确", etc.), the memory updater now receives reinforcement_detected=True
   and injects a hint prompting the LLM to record confirmed preferences and
   behaviors with high confidence.

   Changes across the full signal path:
   - memory_middleware.py: _REINFORCEMENT_PATTERNS + detect_reinforcement()
   - queue.py: reinforcement_detected field in ConversationContext and add()
   - updater.py: reinforcement_detected param in update_memory() and
     update_memory_from_conversation(); builds reinforcement_hint alongside
     the existing correction_hint

Tests: 11 new tests covering deduplication, hint injection, and signal
detection (Chinese + English patterns, window boundary, conflict with correction).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(memory): address Copilot review comments on reinforcement detection

- Tighten _REINFORCEMENT_PATTERNS: remove 很好, require punctuation/end-of-string boundaries on remaining patterns, split this-is-good into stricter variants
- Suppress reinforcement_detected when correction_detected is true to avoid mixed-signal noise
- Use casefold() instead of lower() for Unicode-aware fact deduplication
- Add missing test coverage for reinforcement_detected OR merge and forwarding in queue

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
thefoolgy 2026-04-05 16:23:00 +08:00 committed by GitHub
parent 9ca68ffaaa
commit 8049785de6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 326 additions and 3 deletions

View File

@ -21,6 +21,7 @@ class ConversationContext:
timestamp: datetime = field(default_factory=datetime.utcnow)
agent_name: str | None = None
correction_detected: bool = False
reinforcement_detected: bool = False
class MemoryUpdateQueue:
@ -44,6 +45,7 @@ class MemoryUpdateQueue:
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation to the update queue.
@ -52,6 +54,7 @@ class MemoryUpdateQueue:
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
"""
config = get_memory_config()
if not config.enabled:
@ -63,11 +66,13 @@ class MemoryUpdateQueue:
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
# Check if this thread already has a pending update
@ -130,6 +135,7 @@ class MemoryUpdateQueue:
thread_id=context.thread_id,
agent_name=context.agent_name,
correction_detected=context.correction_detected,
reinforcement_detected=context.reinforcement_detected,
)
if success:
logger.info("Memory updated successfully for thread %s", context.thread_id)

View File

@ -246,7 +246,7 @@ def _fact_content_key(content: Any) -> str | None:
stripped = content.strip()
if not stripped:
return None
return stripped
return stripped.casefold()
class MemoryUpdater:
@ -272,6 +272,7 @@ class MemoryUpdater:
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Update memory based on conversation messages.
@ -280,6 +281,7 @@ class MemoryUpdater:
thread_id: Optional thread ID for tracking source.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns:
True if update was successful, False otherwise.
@ -310,6 +312,14 @@ class MemoryUpdater:
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
@ -441,6 +451,7 @@ def update_memory_from_conversation(
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Convenience function to update memory from a conversation.
@ -449,9 +460,10 @@ def update_memory_from_conversation(
thread_id: Optional thread ID.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns:
True if successful, False otherwise.
"""
updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)

View File

@ -29,6 +29,22 @@ _CORRECTION_PATTERNS = (
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
@ -132,6 +148,29 @@ def detect_correction(messages: list[Any]) -> bool:
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns.
Complements detect_correction() by identifying when the user confirms the
agent's approach was correct. This allows the memory system to record what
worked well, not just what went wrong.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale signals from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
@ -196,12 +235,14 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
# Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
queue.add(
thread_id=thread_id,
messages=filtered_messages,
agent_name=self._agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
return None

View File

@ -47,4 +47,45 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
thread_id="thread-1",
agent_name="lead_agent",
correction_detected=True,
reinforcement_detected=False,
)
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
assert len(queue._queue) == 1
assert queue._queue[0].messages == ["second"]
assert queue._queue[0].reinforcement_detected is True
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
queue = MemoryUpdateQueue()
queue._queue = [
ConversationContext(
thread_id="thread-1",
messages=["conversation"],
agent_name="lead_agent",
reinforcement_detected=True,
)
]
mock_updater = MagicMock()
mock_updater.update_memory.return_value = True
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
queue._process_queue()
mock_updater.update_memory.assert_called_once_with(
messages=["conversation"],
thread_id="thread-1",
agent_name="lead_agent",
correction_detected=False,
reinforcement_detected=True,
)

View File

@ -619,3 +619,156 @@ class TestUpdateMemoryStructuredResponse:
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" not in prompt
class TestFactDeduplicationCaseInsensitive:
"""Tests that fact deduplication is case-insensitive."""
def test_duplicate_fact_different_case_not_stored(self):
updater = MemoryUpdater()
current_memory = _make_memory(
facts=[
{
"id": "fact_1",
"content": "User prefers Python",
"category": "preference",
"confidence": 0.9,
"createdAt": "2026-01-01T00:00:00Z",
"source": "thread-a",
},
]
)
# Same fact with different casing should be treated as duplicate
update_data = {
"factsToRemove": [],
"newFacts": [
{"content": "user prefers python", "category": "preference", "confidence": 0.95},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
# Should still have only 1 fact (duplicate rejected)
assert len(result["facts"]) == 1
assert result["facts"][0]["content"] == "User prefers Python"
def test_unique_fact_different_case_and_content_stored(self):
updater = MemoryUpdater()
current_memory = _make_memory(
facts=[
{
"id": "fact_1",
"content": "User prefers Python",
"category": "preference",
"confidence": 0.9,
"createdAt": "2026-01-01T00:00:00Z",
"source": "thread-a",
},
]
)
update_data = {
"factsToRemove": [],
"newFacts": [
{"content": "User prefers Go", "category": "preference", "confidence": 0.85},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
assert len(result["facts"]) == 2
class TestReinforcementHint:
"""Tests that reinforcement_detected injects the correct hint into the prompt."""
@staticmethod
def _make_mock_model(json_response: str):
model = MagicMock()
response = MagicMock()
response.content = f"```json\n{json_response}\n```"
model.invoke.return_value = response
return model
def test_reinforcement_hint_injected_when_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Yes, exactly! That's what I needed."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Great to hear!"
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Positive reinforcement signals were detected" in prompt
def test_reinforcement_hint_absent_when_not_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Tell me more."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Sure."
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Positive reinforcement signals were detected" not in prompt
def test_both_hints_present_when_both_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "No wait, that's wrong. Actually yes, exactly right."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Got it."
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" in prompt
assert "Positive reinforcement signals were detected" in prompt

View File

@ -10,7 +10,7 @@ persisting in long-term memory:
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
# ---------------------------------------------------------------------------
# Helpers
@ -270,3 +270,73 @@ class TestStripUploadMentionsFromMemory:
mem = {"user": {}, "history": {}, "facts": []}
result = _strip_upload_mentions_from_memory(mem)
assert result == {"user": {}, "history": {}, "facts": []}
# ===========================================================================
# detect_reinforcement
# ===========================================================================
class TestDetectReinforcement:
def test_detects_english_reinforcement_signal(self):
msgs = [
_human("Can you summarise it in bullet points?"),
_ai("Here are the key points: ..."),
_human("Yes, exactly! That's what I needed."),
_ai("Glad it helped."),
]
assert detect_reinforcement(msgs) is True
def test_detects_perfect_signal(self):
msgs = [
_human("Write it more concisely."),
_ai("Here is the concise version."),
_human("Perfect."),
_ai("Great!"),
]
assert detect_reinforcement(msgs) is True
def test_detects_chinese_reinforcement_signal(self):
msgs = [
_human("帮我用要点来总结"),
_ai("好的,要点如下:..."),
_human("完全正确,就是这个意思"),
_ai("很高兴能帮到你"),
]
assert detect_reinforcement(msgs) is True
def test_returns_false_without_signal(self):
msgs = [
_human("What does this function do?"),
_ai("It processes the input data."),
_human("Can you show me an example?"),
]
assert detect_reinforcement(msgs) is False
def test_only_checks_recent_messages(self):
# Reinforcement signal buried beyond the -6 window should not trigger
msgs = [
_human("Yes, exactly right."),
_ai("Noted."),
_human("Let's discuss tests."),
_ai("Sure."),
_human("What about linting?"),
_ai("Use ruff."),
_human("And formatting?"),
_ai("Use make format."),
]
assert detect_reinforcement(msgs) is False
def test_does_not_conflict_with_correction(self):
# A message can trigger correction but not reinforcement
msgs = [
_human("That's wrong, try again."),
_ai("Corrected."),
]
assert detect_reinforcement(msgs) is False