From dfa9fc47b3f762b4309c3c286077ba1c73be747b Mon Sep 17 00:00:00 2001 From: rayhpeng Date: Sun, 12 Apr 2026 13:37:08 +0800 Subject: [PATCH] feat(memory): thread user_id through memory updater layer Add `user_id` keyword-only parameter to all public updater functions (_save_memory_to_file, get_memory_data, reload_memory_data, import_memory_data, clear_memory_data, create/delete/update_memory_fact) and regular keyword param to MemoryUpdater.update_memory + update_memory_from_conversation, propagating it to every storage load/save/reload call. Co-Authored-By: Claude Sonnet 4.6 --- .../harness/deerflow/agents/memory/updater.py | 51 +++++++++++-------- backend/tests/test_memory_updater.py | 4 +- .../test_memory_updater_user_isolation.py | 29 +++++++++++ 3 files changed, 61 insertions(+), 23 deletions(-) create mode 100644 backend/tests/test_memory_updater_user_isolation.py diff --git a/backend/packages/harness/deerflow/agents/memory/updater.py b/backend/packages/harness/deerflow/agents/memory/updater.py index d1f124d4c..178c6bf62 100644 --- a/backend/packages/harness/deerflow/agents/memory/updater.py +++ b/backend/packages/harness/deerflow/agents/memory/updater.py @@ -27,27 +27,28 @@ def _create_empty_memory() -> dict[str, Any]: return create_empty_memory() -def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool: +def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool: """Backward-compatible wrapper around the configured memory storage save path.""" - return get_memory_storage().save(memory_data, agent_name) + return get_memory_storage().save(memory_data, agent_name, user_id=user_id) -def get_memory_data(agent_name: str | None = None) -> dict[str, Any]: +def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: """Get the current memory data via storage provider.""" - return get_memory_storage().load(agent_name) + return get_memory_storage().load(agent_name, user_id=user_id) -def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]: +def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: """Reload memory data via storage provider.""" - return get_memory_storage().reload(agent_name) + return get_memory_storage().reload(agent_name, user_id=user_id) -def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]: +def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: """Persist imported memory data via storage provider. Args: memory_data: Full memory payload to persist. agent_name: If provided, imports into per-agent memory. + user_id: If provided, scopes memory to a specific user. Returns: The saved memory data after storage normalization. @@ -56,15 +57,15 @@ def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = Non OSError: If persisting the imported memory fails. """ storage = get_memory_storage() - if not storage.save(memory_data, agent_name): + if not storage.save(memory_data, agent_name, user_id=user_id): raise OSError("Failed to save imported memory data") - return storage.load(agent_name) + return storage.load(agent_name, user_id=user_id) -def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]: +def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: """Clear all stored memory data and persist an empty structure.""" cleared_memory = create_empty_memory() - if not _save_memory_to_file(cleared_memory, agent_name): + if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id): raise OSError("Failed to save cleared memory data") return cleared_memory @@ -81,6 +82,8 @@ def create_memory_fact( category: str = "context", confidence: float = 0.5, agent_name: str | None = None, + *, + user_id: str | None = None, ) -> dict[str, Any]: """Create a new fact and persist the updated memory data.""" normalized_content = content.strip() @@ -90,7 +93,7 @@ def create_memory_fact( normalized_category = category.strip() or "context" validated_confidence = _validate_confidence(confidence) now = utc_now_iso_z() - memory_data = get_memory_data(agent_name) + memory_data = get_memory_data(agent_name, user_id=user_id) updated_memory = dict(memory_data) facts = list(memory_data.get("facts", [])) facts.append( @@ -105,15 +108,15 @@ def create_memory_fact( ) updated_memory["facts"] = facts - if not _save_memory_to_file(updated_memory, agent_name): + if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): raise OSError("Failed to save memory data after creating fact") return updated_memory -def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]: +def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: """Delete a fact by its id and persist the updated memory data.""" - memory_data = get_memory_data(agent_name) + memory_data = get_memory_data(agent_name, user_id=user_id) facts = memory_data.get("facts", []) updated_facts = [fact for fact in facts if fact.get("id") != fact_id] if len(updated_facts) == len(facts): @@ -122,7 +125,7 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, updated_memory = dict(memory_data) updated_memory["facts"] = updated_facts - if not _save_memory_to_file(updated_memory, agent_name): + if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'") return updated_memory @@ -134,9 +137,11 @@ def update_memory_fact( category: str | None = None, confidence: float | None = None, agent_name: str | None = None, + *, + user_id: str | None = None, ) -> dict[str, Any]: """Update an existing fact and persist the updated memory data.""" - memory_data = get_memory_data(agent_name) + memory_data = get_memory_data(agent_name, user_id=user_id) updated_memory = dict(memory_data) updated_facts: list[dict[str, Any]] = [] found = False @@ -163,7 +168,7 @@ def update_memory_fact( updated_memory["facts"] = updated_facts - if not _save_memory_to_file(updated_memory, agent_name): + if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): raise OSError(f"Failed to save memory data after updating fact '{fact_id}'") return updated_memory @@ -276,6 +281,7 @@ class MemoryUpdater: agent_name: str | None = None, correction_detected: bool = False, reinforcement_detected: bool = False, + user_id: str | None = None, ) -> bool: """Update memory based on conversation messages. @@ -285,6 +291,7 @@ class MemoryUpdater: agent_name: If provided, updates per-agent memory. If None, updates global memory. correction_detected: Whether recent turns include an explicit correction signal. reinforcement_detected: Whether recent turns include a positive reinforcement signal. + user_id: If provided, scopes memory to a specific user. Returns: True if update was successful, False otherwise. @@ -298,7 +305,7 @@ class MemoryUpdater: try: # Get current memory - current_memory = get_memory_data(agent_name) + current_memory = get_memory_data(agent_name, user_id=user_id) # Format conversation for prompt conversation_text = format_conversation_for_update(messages) @@ -353,7 +360,7 @@ class MemoryUpdater: updated_memory = _strip_upload_mentions_from_memory(updated_memory) # Save - return get_memory_storage().save(updated_memory, agent_name) + return get_memory_storage().save(updated_memory, agent_name, user_id=user_id) except json.JSONDecodeError as e: logger.warning("Failed to parse LLM response for memory update: %s", e) @@ -455,6 +462,7 @@ def update_memory_from_conversation( agent_name: str | None = None, correction_detected: bool = False, reinforcement_detected: bool = False, + user_id: str | None = None, ) -> bool: """Convenience function to update memory from a conversation. @@ -464,9 +472,10 @@ def update_memory_from_conversation( agent_name: If provided, updates per-agent memory. If None, updates global memory. correction_detected: Whether recent turns include an explicit correction signal. reinforcement_detected: Whether recent turns include a positive reinforcement signal. + user_id: If provided, scopes memory to a specific user. Returns: True if successful, False otherwise. """ updater = MemoryUpdater() - return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected) + return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id) diff --git a/backend/tests/test_memory_updater.py b/backend/tests/test_memory_updater.py index 48fdfd89e..995b652ec 100644 --- a/backend/tests/test_memory_updater.py +++ b/backend/tests/test_memory_updater.py @@ -301,8 +301,8 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None: with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage): result = import_memory_data(imported_memory) - mock_storage.save.assert_called_once_with(imported_memory, None) - mock_storage.load.assert_called_once_with(None) + mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None) + mock_storage.load.assert_called_once_with(None, user_id=None) assert result == imported_memory diff --git a/backend/tests/test_memory_updater_user_isolation.py b/backend/tests/test_memory_updater_user_isolation.py new file mode 100644 index 000000000..d38f3fc90 --- /dev/null +++ b/backend/tests/test_memory_updater_user_isolation.py @@ -0,0 +1,29 @@ +"""Tests for user_id propagation in memory updater.""" +from unittest.mock import MagicMock, patch + +from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file + + +def test_get_memory_data_passes_user_id(): + mock_storage = MagicMock() + mock_storage.load.return_value = {"version": "1.0"} + with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage): + get_memory_data(user_id="alice") + mock_storage.load.assert_called_once_with(None, user_id="alice") + + +def test_save_memory_passes_user_id(): + mock_storage = MagicMock() + mock_storage.save.return_value = True + with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage): + _save_memory_to_file({"version": "1.0"}, user_id="bob") + mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob") + + +def test_clear_memory_data_passes_user_id(): + mock_storage = MagicMock() + mock_storage.save.return_value = True + with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage): + clear_memory_data(user_id="charlie") + # Verify save was called with user_id + assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"