fix: Memory update system has cache corruption, data loss, and thread-safety bugs (#2251)

* fix(memory): cache corruption, thread-safety, and caller mutation bugs

Bug 1 (updater.py): deep-copy current_memory before passing to
_apply_updates() so a subsequent save() failure cannot leave a
partially-mutated object in the storage cache.

Bug 3 (storage.py): add _cache_lock (threading.Lock) to
FileMemoryStorage and acquire it around every read/write of
_memory_cache, fixing concurrent-access races between the background
timer thread and HTTP reload calls.

Bug 4 (storage.py): replace in-place mutation
  memory_data["lastUpdated"] = ...
with a shallow copy
  memory_data = {**memory_data, "lastUpdated": ...}
so save() no longer silently modifies the caller's dict.

Regression tests added for all three bugs in test_memory_storage.py
and test_memory_updater.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: format test_memory_updater.py with ruff

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: remove stale bug-number labels from code comments and docstrings

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
DanielWalnut 2026-04-17 12:00:31 +08:00 committed by GitHub
parent 259a6844bf
commit 898f4e8ac2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 159 additions and 9 deletions

View File

@ -67,6 +67,8 @@ class FileMemoryStorage(MemoryStorage):
# Per-agent memory cache: keyed by agent_name (None = global) # Per-agent memory cache: keyed by agent_name (None = global)
# Value: (memory_data, file_mtime) # Value: (memory_data, file_mtime)
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {} self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
# Guards all reads and writes to _memory_cache across concurrent callers.
self._cache_lock = threading.Lock()
def _validate_agent_name(self, agent_name: str) -> None: def _validate_agent_name(self, agent_name: str) -> None:
"""Validate that the agent name is safe to use in filesystem paths. """Validate that the agent name is safe to use in filesystem paths.
@ -115,14 +117,17 @@ class FileMemoryStorage(MemoryStorage):
except OSError: except OSError:
current_mtime = None current_mtime = None
cached = self._memory_cache.get(agent_name) with self._cache_lock:
cached = self._memory_cache.get(agent_name)
if cached is not None and cached[1] == current_mtime:
return cached[0]
if cached is None or cached[1] != current_mtime: memory_data = self._load_memory_from_file(agent_name)
memory_data = self._load_memory_from_file(agent_name)
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, current_mtime) self._memory_cache[agent_name] = (memory_data, current_mtime)
return memory_data
return cached[0] return memory_data
def reload(self, agent_name: str | None = None) -> dict[str, Any]: def reload(self, agent_name: str | None = None) -> dict[str, Any]:
"""Reload memory data from file, forcing cache invalidation.""" """Reload memory data from file, forcing cache invalidation."""
@ -134,7 +139,8 @@ class FileMemoryStorage(MemoryStorage):
except OSError: except OSError:
mtime = None mtime = None
self._memory_cache[agent_name] = (memory_data, mtime) with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
return memory_data return memory_data
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool: def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
@ -143,7 +149,10 @@ class FileMemoryStorage(MemoryStorage):
try: try:
file_path.parent.mkdir(parents=True, exist_ok=True) file_path.parent.mkdir(parents=True, exist_ok=True)
memory_data["lastUpdated"] = utc_now_iso_z() # Shallow-copy before adding lastUpdated so the caller's dict is not
# mutated as a side-effect, and the cache reference is not silently
# updated before the file write succeeds.
memory_data = {**memory_data, "lastUpdated": utc_now_iso_z()}
temp_path = file_path.with_suffix(f".{uuid.uuid4().hex}.tmp") temp_path = file_path.with_suffix(f".{uuid.uuid4().hex}.tmp")
with open(temp_path, "w", encoding="utf-8") as f: with open(temp_path, "w", encoding="utf-8") as f:
@ -156,7 +165,8 @@ class FileMemoryStorage(MemoryStorage):
except OSError: except OSError:
mtime = None mtime = None
self._memory_cache[agent_name] = (memory_data, mtime) with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
logger.info("Memory saved to %s", file_path) logger.info("Memory saved to %s", file_path)
return True return True
except OSError as e: except OSError as e:

View File

@ -3,6 +3,7 @@
import asyncio import asyncio
import atexit import atexit
import concurrent.futures import concurrent.futures
import copy
import json import json
import logging import logging
import math import math
@ -380,7 +381,9 @@ class MemoryUpdater:
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:]) response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text) update_data = json.loads(response_text)
updated_memory = self._apply_updates(current_memory, update_data, thread_id) # Deep-copy before in-place mutation so a subsequent save() failure
# cannot corrupt the still-cached original object reference.
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
updated_memory = _strip_upload_mentions_from_memory(updated_memory) updated_memory = _strip_upload_mentions_from_memory(updated_memory)
return get_memory_storage().save(updated_memory, agent_name) return get_memory_storage().save(updated_memory, agent_name)

View File

@ -110,6 +110,93 @@ class TestFileMemoryStorage:
assert result is True assert result is True
assert memory_file.exists() assert memory_file.exists()
def test_save_does_not_mutate_caller_dict(self, tmp_path):
    """Saving must leave the caller-supplied dict untouched (no lastUpdated key)."""
    memory_file = tmp_path / "memory.json"

    def mock_get_paths():
        paths = MagicMock()
        paths.memory_file = memory_file
        return paths

    with (
        patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths),
        patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
    ):
        store = FileMemoryStorage()
        payload = {"version": "1.0", "facts": []}
        keys_before = set(payload)

        store.save(payload)

        # save() copies before stamping lastUpdated, so the caller's
        # object keeps exactly the keys it started with.
        assert set(payload.keys()) == keys_before, "save() must not add keys to caller's dict"
        assert "lastUpdated" not in payload
def test_cache_not_corrupted_when_save_fails(self, tmp_path):
    """A save() that raises OSError must not poison the cache.

    When persistence fails, the cache has to keep serving the previously
    loaded data; combined with the deepcopy in updater._finalize_update(),
    this keeps failed writes from leaking mutated state into readers.
    """
    import json as _json

    memory_file = tmp_path / "memory.json"
    memory_file.parent.mkdir(parents=True, exist_ok=True)
    seed = {"version": "1.0", "facts": [{"content": "original"}]}
    memory_file.write_text(_json.dumps(seed))

    def mock_get_paths():
        paths = MagicMock()
        paths.memory_file = memory_file
        return paths

    with (
        patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths),
        patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
    ):
        store = FileMemoryStorage()

        # Warm the cache with the on-disk contents.
        warmed = store.load()
        assert warmed["facts"][0]["content"] == "original"

        # Simulate failure: mkdir succeeds but every open() raises.
        replacement = {"version": "1.0", "facts": [{"content": "mutated"}]}
        with patch("builtins.open", side_effect=OSError("disk full")):
            result = store.save(replacement)
        assert result is False

        # The failed write must not have displaced the cached original.
        after = store.load()
        assert after["facts"][0]["content"] == "original"
def test_cache_thread_safety(self, tmp_path):
    """Concurrent load/reload calls must not race on _memory_cache."""
    import json as _json

    memory_file = tmp_path / "memory.json"
    memory_file.parent.mkdir(parents=True, exist_ok=True)
    memory_file.write_text(_json.dumps({"version": "1.0", "facts": []}))

    def mock_get_paths():
        paths = MagicMock()
        paths.memory_file = memory_file
        return paths

    errors: list[Exception] = []

    def load_many(storage: FileMemoryStorage) -> None:
        # Hammer the cache; any race surfaces as an exception here.
        try:
            for _ in range(50):
                storage.load()
        except Exception as exc:
            errors.append(exc)

    with (
        patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths),
        patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
    ):
        shared_storage = FileMemoryStorage()
        workers = [
            threading.Thread(target=load_many, args=(shared_storage,))
            for _ in range(8)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

    assert not errors, f"Thread-safety errors: {errors}"
def test_reload_forces_cache_invalidation(self, tmp_path): def test_reload_forces_cache_invalidation(self, tmp_path):
"""Should force reload from file and invalidate cache.""" """Should force reload from file and invalidate cache."""
memory_file = tmp_path / "memory.json" memory_file = tmp_path / "memory.json"

View File

@ -881,3 +881,53 @@ class TestReinforcementHint:
prompt = model.ainvoke.await_args.args[0] prompt = model.ainvoke.await_args.args[0]
assert "Explicit correction signals were detected" in prompt assert "Explicit correction signals were detected" in prompt
assert "Positive reinforcement signals were detected" in prompt assert "Positive reinforcement signals were detected" in prompt
class TestFinalizeCacheIsolation:
    """_finalize_update must not mutate the cached memory object."""

    def test_deepcopy_prevents_cache_corruption_on_save_failure(self):
        """A failed save() must leave the cached memory snapshot intact.

        _finalize_update deep-copies the current memory before handing it
        to _apply_updates, so the object that gets mutated is always a
        fresh copy — never the reference the storage layer may still hold
        in its cache.
        """
        import json as _json

        updater = MemoryUpdater()
        original_memory = _make_memory(
            facts=[
                {
                    "id": "fact_orig",
                    "content": "original",
                    "category": "context",
                    "confidence": 0.9,
                    "createdAt": "2024-01-01T00:00:00Z",
                    "source": "t1",
                }
            ]
        )
        llm_payload = _json.dumps(
            {
                "user": {},
                "history": {},
                "newFacts": [{"content": "new fact", "category": "context", "confidence": 0.9}],
                "factsToRemove": [],
            }
        )

        response = MagicMock()
        response.content = llm_payload
        model = AsyncMock()
        model.ainvoke = AsyncMock(return_value=response)

        saved_objects: list[dict] = []

        def failing_save(m, a=None):
            # Record what was handed to save(), then report failure.
            saved_objects.append(m)
            return False

        save_mock = MagicMock(side_effect=failing_save)

        with (
            patch.object(updater, "_get_model", return_value=model),
            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True, fact_confidence_threshold=0.7)),
            patch("deerflow.agents.memory.updater.get_memory_data", return_value=original_memory),
            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=save_mock)),
        ):
            human = MagicMock()
            human.type = "human"
            human.content = "hello"
            assistant = MagicMock()
            assistant.type = "ai"
            assistant.content = "world"
            assistant.tool_calls = []

            updater.update_memory([human, assistant], thread_id="t1")

            # The deepcopy isolated the mutation: the original object is intact.
            assert len(original_memory["facts"]) == 1, "original_memory must not be mutated by _apply_updates"
            assert original_memory["facts"][0]["content"] == "original"