From 8049785de666ddeb143693df21e15d76582acba8 Mon Sep 17 00:00:00 2001 From: thefoolgy <99605054+thefoolgy@users.noreply.github.com> Date: Sun, 5 Apr 2026 16:23:00 +0800 Subject: [PATCH] fix(memory): case-insensitive fact deduplication and positive reinforcement detection (#1804) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(memory): case-insensitive fact deduplication and positive reinforcement detection Two fixes to the memory system: 1. _fact_content_key() now lowercases content before comparison, preventing semantically duplicate facts like "User prefers Python" and "user prefers python" from being stored separately. 2. Adds detect_reinforcement() to MemoryMiddleware (closes #1719), mirroring detect_correction(). When users signal approval ("yes exactly", "perfect", "完全正确", etc.), the memory updater now receives reinforcement_detected=True and injects a hint prompting the LLM to record confirmed preferences and behaviors with high confidence. Changes across the full signal path: - memory_middleware.py: _REINFORCEMENT_PATTERNS + detect_reinforcement() - queue.py: reinforcement_detected field in ConversationContext and add() - updater.py: reinforcement_detected param in update_memory() and update_memory_from_conversation(); builds reinforcement_hint alongside the existing correction_hint Tests: 11 new tests covering deduplication, hint injection, and signal detection (Chinese + English patterns, window boundary, conflict with correction). Co-Authored-By: Claude Sonnet 4.6 * fix(memory): address Copilot review comments on reinforcement detection - Tighten _REINFORCEMENT_PATTERNS: remove 很好, require punctuation/end-of-string boundaries on remaining patterns, split this-is-good into stricter variants - Suppress reinforcement_detected when correction_detected is true to avoid mixed-signal noise - Use casefold() instead of lower() for Unicode-aware fact deduplication - Add missing test coverage for reinforcement_detected OR merge and forwarding in queue --------- Co-authored-by: Claude Sonnet 4.6 --- .../harness/deerflow/agents/memory/queue.py | 6 + .../harness/deerflow/agents/memory/updater.py | 16 +- .../agents/middlewares/memory_middleware.py | 41 +++++ backend/tests/test_memory_queue.py | 41 +++++ backend/tests/test_memory_updater.py | 153 ++++++++++++++++++ backend/tests/test_memory_upload_filtering.py | 72 ++++++++- 6 files changed, 326 insertions(+), 3 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/memory/queue.py b/backend/packages/harness/deerflow/agents/memory/queue.py index 6d777a67e..d78c643f8 100644 --- a/backend/packages/harness/deerflow/agents/memory/queue.py +++ b/backend/packages/harness/deerflow/agents/memory/queue.py @@ -21,6 +21,7 @@ class ConversationContext: timestamp: datetime = field(default_factory=datetime.utcnow) agent_name: str | None = None correction_detected: bool = False + reinforcement_detected: bool = False class MemoryUpdateQueue: @@ -44,6 +45,7 @@ class MemoryUpdateQueue: messages: list[Any], agent_name: str | None = None, correction_detected: bool = False, + reinforcement_detected: bool = False, ) -> None: """Add a conversation to the update queue. @@ -52,6 +54,7 @@ class MemoryUpdateQueue: messages: The conversation messages. agent_name: If provided, memory is stored per-agent. If None, uses global memory. correction_detected: Whether recent turns include an explicit correction signal. + reinforcement_detected: Whether recent turns include a positive reinforcement signal. """ config = get_memory_config() if not config.enabled: @@ -63,11 +66,13 @@ class MemoryUpdateQueue: None, ) merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False) + merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False) context = ConversationContext( thread_id=thread_id, messages=messages, agent_name=agent_name, correction_detected=merged_correction_detected, + reinforcement_detected=merged_reinforcement_detected, ) # Check if this thread already has a pending update @@ -130,6 +135,7 @@ class MemoryUpdateQueue: thread_id=context.thread_id, agent_name=context.agent_name, correction_detected=context.correction_detected, + reinforcement_detected=context.reinforcement_detected, ) if success: logger.info("Memory updated successfully for thread %s", context.thread_id) diff --git a/backend/packages/harness/deerflow/agents/memory/updater.py b/backend/packages/harness/deerflow/agents/memory/updater.py index c59749d7b..5f459b47a 100644 --- a/backend/packages/harness/deerflow/agents/memory/updater.py +++ b/backend/packages/harness/deerflow/agents/memory/updater.py @@ -246,7 +246,7 @@ def _fact_content_key(content: Any) -> str | None: stripped = content.strip() if not stripped: return None - return stripped + return stripped.casefold() class MemoryUpdater: @@ -272,6 +272,7 @@ class MemoryUpdater: thread_id: str | None = None, agent_name: str | None = None, correction_detected: bool = False, + reinforcement_detected: bool = False, ) -> bool: """Update memory based on conversation messages. @@ -280,6 +281,7 @@ class MemoryUpdater: thread_id: Optional thread ID for tracking source. agent_name: If provided, updates per-agent memory. If None, updates global memory. correction_detected: Whether recent turns include an explicit correction signal. + reinforcement_detected: Whether recent turns include a positive reinforcement signal. Returns: True if update was successful, False otherwise. @@ -310,6 +312,14 @@ class MemoryUpdater: "and record the correct approach as a fact with category " '"correction" and confidence >= 0.95 when appropriate.' ) + if reinforcement_detected: + reinforcement_hint = ( + "IMPORTANT: Positive reinforcement signals were detected in this conversation. " + "The user explicitly confirmed the agent's approach was correct or helpful. " + "Record the confirmed approach, style, or preference as a fact with category " + '"preference" or "behavior" and confidence >= 0.9 when appropriate.' + ) + correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint prompt = MEMORY_UPDATE_PROMPT.format( current_memory=json.dumps(current_memory, indent=2), @@ -441,6 +451,7 @@ def update_memory_from_conversation( thread_id: str | None = None, agent_name: str | None = None, correction_detected: bool = False, + reinforcement_detected: bool = False, ) -> bool: """Convenience function to update memory from a conversation. @@ -449,9 +460,10 @@ def update_memory_from_conversation( thread_id: Optional thread ID. agent_name: If provided, updates per-agent memory. If None, updates global memory. correction_detected: Whether recent turns include an explicit correction signal. + reinforcement_detected: Whether recent turns include a positive reinforcement signal. Returns: True if successful, False otherwise. """ updater = MemoryUpdater() - return updater.update_memory(messages, thread_id, agent_name, correction_detected) + return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected) diff --git a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py index 6215a2957..5e8ca6344 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py @@ -29,6 +29,22 @@ _CORRECTION_PATTERNS = ( re.compile(r"改用"), ) +_REINFORCEMENT_PATTERNS = ( + re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE), + re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE), + re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE), + re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE), + re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE), + re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE), + re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE), + re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE), + re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"), + re.compile(r"完全正确(?:[。!?!?.]|$)"), + re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"), + re.compile(r"正是我想要的(?:[。!?!?.]|$)"), + re.compile(r"继续保持(?:[。!?!?.]|$)"), +) + class MemoryMiddlewareState(AgentState): """Compatible with the `ThreadState` schema.""" @@ -132,6 +148,29 @@ def detect_correction(messages: list[Any]) -> bool: return False +def detect_reinforcement(messages: list[Any]) -> bool: + """Detect explicit positive reinforcement signals in recent conversation turns. + + Complements detect_correction() by identifying when the user confirms the + agent's approach was correct. This allows the memory system to record what + worked well, not just what went wrong. + + The queue keeps only one pending context per thread, so callers pass the + latest filtered message list. Checking only recent user turns keeps signal + detection conservative while avoiding stale signals from long histories. + """ + recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"] + + for msg in recent_user_msgs: + content = _extract_message_text(msg).strip() + if not content: + continue + if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS): + return True + + return False + + class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): """Middleware that queues conversation for memory update after agent execution. @@ -196,12 +235,14 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): # Queue the filtered conversation for memory update correction_detected = detect_correction(filtered_messages) + reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages) queue = get_memory_queue() queue.add( thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name, correction_detected=correction_detected, + reinforcement_detected=reinforcement_detected, ) return None diff --git a/backend/tests/test_memory_queue.py b/backend/tests/test_memory_queue.py index 6ef91a142..204f9d16e 100644 --- a/backend/tests/test_memory_queue.py +++ b/backend/tests/test_memory_queue.py @@ -47,4 +47,45 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None: thread_id="thread-1", agent_name="lead_agent", correction_detected=True, + reinforcement_detected=False, + ) + + +def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True) + queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False) + + assert len(queue._queue) == 1 + assert queue._queue[0].messages == ["second"] + assert queue._queue[0].reinforcement_detected is True + + +def test_process_queue_forwards_reinforcement_flag_to_updater() -> None: + queue = MemoryUpdateQueue() + queue._queue = [ + ConversationContext( + thread_id="thread-1", + messages=["conversation"], + agent_name="lead_agent", + reinforcement_detected=True, + ) + ] + mock_updater = MagicMock() + mock_updater.update_memory.return_value = True + + with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater): + queue._process_queue() + + mock_updater.update_memory.assert_called_once_with( + messages=["conversation"], + thread_id="thread-1", + agent_name="lead_agent", + correction_detected=False, + reinforcement_detected=True, ) diff --git a/backend/tests/test_memory_updater.py b/backend/tests/test_memory_updater.py index 6309cf9f6..48fdfd89e 100644 --- a/backend/tests/test_memory_updater.py +++ b/backend/tests/test_memory_updater.py @@ -619,3 +619,156 @@ class TestUpdateMemoryStructuredResponse: assert result is True prompt = model.invoke.call_args[0][0] assert "Explicit correction signals were detected" not in prompt + + +class TestFactDeduplicationCaseInsensitive: + """Tests that fact deduplication is case-insensitive.""" + + def test_duplicate_fact_different_case_not_stored(self): + updater = MemoryUpdater() + current_memory = _make_memory( + facts=[ + { + "id": "fact_1", + "content": "User prefers Python", + "category": "preference", + "confidence": 0.9, + "createdAt": "2026-01-01T00:00:00Z", + "source": "thread-a", + }, + ] + ) + # Same fact with different casing should be treated as duplicate + update_data = { + "factsToRemove": [], + "newFacts": [ + {"content": "user prefers python", "category": "preference", "confidence": 0.95}, + ], + } + + with patch( + "deerflow.agents.memory.updater.get_memory_config", + return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7), + ): + result = updater._apply_updates(current_memory, update_data, thread_id="thread-b") + + # Should still have only 1 fact (duplicate rejected) + assert len(result["facts"]) == 1 + assert result["facts"][0]["content"] == "User prefers Python" + + def test_unique_fact_different_case_and_content_stored(self): + updater = MemoryUpdater() + current_memory = _make_memory( + facts=[ + { + "id": "fact_1", + "content": "User prefers Python", + "category": "preference", + "confidence": 0.9, + "createdAt": "2026-01-01T00:00:00Z", + "source": "thread-a", + }, + ] + ) + update_data = { + "factsToRemove": [], + "newFacts": [ + {"content": "User prefers Go", "category": "preference", "confidence": 0.85}, + ], + } + + with patch( + "deerflow.agents.memory.updater.get_memory_config", + return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7), + ): + result = updater._apply_updates(current_memory, update_data, thread_id="thread-b") + + assert len(result["facts"]) == 2 + + +class TestReinforcementHint: + """Tests that reinforcement_detected injects the correct hint into the prompt.""" + + @staticmethod + def _make_mock_model(json_response: str): + model = MagicMock() + response = MagicMock() + response.content = f"```json\n{json_response}\n```" + model.invoke.return_value = response + return model + + def test_reinforcement_hint_injected_when_detected(self): + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) + + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "Yes, exactly! That's what I needed." + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Great to hear!" + ai_msg.tool_calls = [] + + result = updater.update_memory([msg, ai_msg], reinforcement_detected=True) + + assert result is True + prompt = model.invoke.call_args[0][0] + assert "Positive reinforcement signals were detected" in prompt + + def test_reinforcement_hint_absent_when_not_detected(self): + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) + + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "Tell me more." + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Sure." + ai_msg.tool_calls = [] + + result = updater.update_memory([msg, ai_msg], reinforcement_detected=False) + + assert result is True + prompt = model.invoke.call_args[0][0] + assert "Positive reinforcement signals were detected" not in prompt + + def test_both_hints_present_when_both_detected(self): + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) + + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "No wait, that's wrong. Actually yes, exactly right." + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Got it." + ai_msg.tool_calls = [] + + result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True) + + assert result is True + prompt = model.invoke.call_args[0][0] + assert "Explicit correction signals were detected" in prompt + assert "Positive reinforcement signals were detected" in prompt diff --git a/backend/tests/test_memory_upload_filtering.py b/backend/tests/test_memory_upload_filtering.py index 1ff0aa3b6..2e2308b61 100644 --- a/backend/tests/test_memory_upload_filtering.py +++ b/backend/tests/test_memory_upload_filtering.py @@ -10,7 +10,7 @@ persisting in long-term memory: from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory -from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction +from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement # --------------------------------------------------------------------------- # Helpers @@ -270,3 +270,73 @@ class TestStripUploadMentionsFromMemory: mem = {"user": {}, "history": {}, "facts": []} result = _strip_upload_mentions_from_memory(mem) assert result == {"user": {}, "history": {}, "facts": []} + + +# =========================================================================== +# detect_reinforcement +# =========================================================================== + + +class TestDetectReinforcement: + def test_detects_english_reinforcement_signal(self): + msgs = [ + _human("Can you summarise it in bullet points?"), + _ai("Here are the key points: ..."), + _human("Yes, exactly! That's what I needed."), + _ai("Glad it helped."), + ] + + assert detect_reinforcement(msgs) is True + + def test_detects_perfect_signal(self): + msgs = [ + _human("Write it more concisely."), + _ai("Here is the concise version."), + _human("Perfect."), + _ai("Great!"), + ] + + assert detect_reinforcement(msgs) is True + + def test_detects_chinese_reinforcement_signal(self): + msgs = [ + _human("帮我用要点来总结"), + _ai("好的,要点如下:..."), + _human("完全正确,就是这个意思"), + _ai("很高兴能帮到你"), + ] + + assert detect_reinforcement(msgs) is True + + def test_returns_false_without_signal(self): + msgs = [ + _human("What does this function do?"), + _ai("It processes the input data."), + _human("Can you show me an example?"), + ] + + assert detect_reinforcement(msgs) is False + + def test_only_checks_recent_messages(self): + # Reinforcement signal buried beyond the -6 window should not trigger + msgs = [ + _human("Yes, exactly right."), + _ai("Noted."), + _human("Let's discuss tests."), + _ai("Sure."), + _human("What about linting?"), + _ai("Use ruff."), + _human("And formatting?"), + _ai("Use make format."), + ] + + assert detect_reinforcement(msgs) is False + + def test_does_not_conflict_with_correction(self): + # A message can trigger correction but not reinforcement + msgs = [ + _human("That's wrong, try again."), + _ai("Corrected."), + ] + + assert detect_reinforcement(msgs) is False