feat(memory): thread user_id through memory updater layer

Add `user_id` keyword-only parameter to all public updater functions
(_save_memory_to_file, get_memory_data, reload_memory_data, import_memory_data,
clear_memory_data, create/delete/update_memory_fact) and regular keyword param
to MemoryUpdater.update_memory + update_memory_from_conversation, propagating
it to every storage load/save/reload call.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
rayhpeng 2026-04-12 13:37:08 +08:00
parent 3877aabcfd
commit dfa9fc47b3
3 changed files with 61 additions and 23 deletions

View File

@ -27,27 +27,28 @@ def _create_empty_memory() -> dict[str, Any]:
return create_empty_memory()
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Backward-compatible wrapper around the configured memory storage save path."""
return get_memory_storage().save(memory_data, agent_name)
return get_memory_storage().save(memory_data, agent_name, user_id=user_id)
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Get the current memory data via storage provider."""
return get_memory_storage().load(agent_name)
return get_memory_storage().load(agent_name, user_id=user_id)
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Reload memory data via storage provider."""
return get_memory_storage().reload(agent_name)
return get_memory_storage().reload(agent_name, user_id=user_id)
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Persist imported memory data via storage provider.
Args:
memory_data: Full memory payload to persist.
agent_name: If provided, imports into per-agent memory.
user_id: If provided, scopes memory to a specific user.
Returns:
The saved memory data after storage normalization.
@ -56,15 +57,15 @@ def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = Non
OSError: If persisting the imported memory fails.
"""
storage = get_memory_storage()
if not storage.save(memory_data, agent_name):
if not storage.save(memory_data, agent_name, user_id=user_id):
raise OSError("Failed to save imported memory data")
return storage.load(agent_name)
return storage.load(agent_name, user_id=user_id)
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Clear all stored memory data and persist an empty structure."""
cleared_memory = create_empty_memory()
if not _save_memory_to_file(cleared_memory, agent_name):
if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id):
raise OSError("Failed to save cleared memory data")
return cleared_memory
@ -81,6 +82,8 @@ def create_memory_fact(
category: str = "context",
confidence: float = 0.5,
agent_name: str | None = None,
*,
user_id: str | None = None,
) -> dict[str, Any]:
"""Create a new fact and persist the updated memory data."""
normalized_content = content.strip()
@ -90,7 +93,7 @@ def create_memory_fact(
normalized_category = category.strip() or "context"
validated_confidence = _validate_confidence(confidence)
now = utc_now_iso_z()
memory_data = get_memory_data(agent_name)
memory_data = get_memory_data(agent_name, user_id=user_id)
updated_memory = dict(memory_data)
facts = list(memory_data.get("facts", []))
facts.append(
@ -105,15 +108,15 @@ def create_memory_fact(
)
updated_memory["facts"] = facts
if not _save_memory_to_file(updated_memory, agent_name):
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
raise OSError("Failed to save memory data after creating fact")
return updated_memory
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Delete a fact by its id and persist the updated memory data."""
memory_data = get_memory_data(agent_name)
memory_data = get_memory_data(agent_name, user_id=user_id)
facts = memory_data.get("facts", [])
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
if len(updated_facts) == len(facts):
@ -122,7 +125,7 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str,
updated_memory = dict(memory_data)
updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name):
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
return updated_memory
@ -134,9 +137,11 @@ def update_memory_fact(
category: str | None = None,
confidence: float | None = None,
agent_name: str | None = None,
*,
user_id: str | None = None,
) -> dict[str, Any]:
"""Update an existing fact and persist the updated memory data."""
memory_data = get_memory_data(agent_name)
memory_data = get_memory_data(agent_name, user_id=user_id)
updated_memory = dict(memory_data)
updated_facts: list[dict[str, Any]] = []
found = False
@ -163,7 +168,7 @@ def update_memory_fact(
updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name):
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
return updated_memory
@ -276,6 +281,7 @@ class MemoryUpdater:
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Update memory based on conversation messages.
@ -285,6 +291,7 @@ class MemoryUpdater:
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
user_id: If provided, scopes memory to a specific user.
Returns:
True if update was successful, False otherwise.
@ -298,7 +305,7 @@ class MemoryUpdater:
try:
# Get current memory
current_memory = get_memory_data(agent_name)
current_memory = get_memory_data(agent_name, user_id=user_id)
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
@ -353,7 +360,7 @@ class MemoryUpdater:
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save
return get_memory_storage().save(updated_memory, agent_name)
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
@ -455,6 +462,7 @@ def update_memory_from_conversation(
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Convenience function to update memory from a conversation.
@ -464,9 +472,10 @@ def update_memory_from_conversation(
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
user_id: If provided, scopes memory to a specific user.
Returns:
True if successful, False otherwise.
"""
updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id)

View File

@ -301,8 +301,8 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
result = import_memory_data(imported_memory)
mock_storage.save.assert_called_once_with(imported_memory, None)
mock_storage.load.assert_called_once_with(None)
mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
mock_storage.load.assert_called_once_with(None, user_id=None)
assert result == imported_memory

View File

@ -0,0 +1,29 @@
"""Tests for user_id propagation in memory updater."""
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file
def test_get_memory_data_passes_user_id():
mock_storage = MagicMock()
mock_storage.load.return_value = {"version": "1.0"}
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
get_memory_data(user_id="alice")
mock_storage.load.assert_called_once_with(None, user_id="alice")
def test_save_memory_passes_user_id():
mock_storage = MagicMock()
mock_storage.save.return_value = True
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
_save_memory_to_file({"version": "1.0"}, user_id="bob")
mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
def test_clear_memory_data_passes_user_id():
mock_storage = MagicMock()
mock_storage.save.return_value = True
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
clear_memory_data(user_id="charlie")
# Verify save was called with user_id
assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"