mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
fix: Memory update system has cache corruption, data loss, and thread-safety bugs (#2251)
* fix(memory): cache corruption, thread-safety, and caller mutation bugs
Bug 1 (updater.py): deep-copy current_memory before passing to
_apply_updates() so a subsequent save() failure cannot leave a
partially-mutated object in the storage cache.
Bug 3 (storage.py): add _cache_lock (threading.Lock) to
FileMemoryStorage and acquire it around every read/write of
_memory_cache, fixing concurrent-access races between the background
timer thread and HTTP reload calls.
Bug 4 (storage.py): replace in-place mutation
memory_data["lastUpdated"] = ...
with a shallow copy
memory_data = {**memory_data, "lastUpdated": ...}
so save() no longer silently modifies the caller's dict.
Regression tests added for all three bugs in test_memory_storage.py
and test_memory_updater.py.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* style: format test_memory_updater.py with ruff
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* style: remove stale bug-number labels from code comments and docstrings
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
---------
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
259a6844bf
commit
898f4e8ac2
@ -67,6 +67,8 @@ class FileMemoryStorage(MemoryStorage):
|
||||
# Per-agent memory cache: keyed by agent_name (None = global)
|
||||
# Value: (memory_data, file_mtime)
|
||||
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
|
||||
# Guards all reads and writes to _memory_cache across concurrent callers.
|
||||
self._cache_lock = threading.Lock()
|
||||
|
||||
def _validate_agent_name(self, agent_name: str) -> None:
|
||||
"""Validate that the agent name is safe to use in filesystem paths.
|
||||
@ -115,14 +117,17 @@ class FileMemoryStorage(MemoryStorage):
|
||||
except OSError:
|
||||
current_mtime = None
|
||||
|
||||
cached = self._memory_cache.get(agent_name)
|
||||
with self._cache_lock:
|
||||
cached = self._memory_cache.get(agent_name)
|
||||
if cached is not None and cached[1] == current_mtime:
|
||||
return cached[0]
|
||||
|
||||
if cached is None or cached[1] != current_mtime:
|
||||
memory_data = self._load_memory_from_file(agent_name)
|
||||
memory_data = self._load_memory_from_file(agent_name)
|
||||
|
||||
with self._cache_lock:
|
||||
self._memory_cache[agent_name] = (memory_data, current_mtime)
|
||||
return memory_data
|
||||
|
||||
return cached[0]
|
||||
return memory_data
|
||||
|
||||
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data from file, forcing cache invalidation."""
|
||||
@ -134,7 +139,8 @@ class FileMemoryStorage(MemoryStorage):
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
with self._cache_lock:
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
return memory_data
|
||||
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
@ -143,7 +149,10 @@ class FileMemoryStorage(MemoryStorage):
|
||||
|
||||
try:
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
memory_data["lastUpdated"] = utc_now_iso_z()
|
||||
# Shallow-copy before adding lastUpdated so the caller's dict is not
|
||||
# mutated as a side-effect, and the cache reference is not silently
|
||||
# updated before the file write succeeds.
|
||||
memory_data = {**memory_data, "lastUpdated": utc_now_iso_z()}
|
||||
|
||||
temp_path = file_path.with_suffix(f".{uuid.uuid4().hex}.tmp")
|
||||
with open(temp_path, "w", encoding="utf-8") as f:
|
||||
@ -156,7 +165,8 @@ class FileMemoryStorage(MemoryStorage):
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
with self._cache_lock:
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
logger.info("Memory saved to %s", file_path)
|
||||
return True
|
||||
except OSError as e:
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
@ -380,7 +381,9 @@ class MemoryUpdater:
|
||||
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
||||
|
||||
update_data = json.loads(response_text)
|
||||
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
|
||||
# Deep-copy before in-place mutation so a subsequent save() failure
|
||||
# cannot corrupt the still-cached original object reference.
|
||||
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
|
||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||
return get_memory_storage().save(updated_memory, agent_name)
|
||||
|
||||
|
||||
@ -110,6 +110,93 @@ class TestFileMemoryStorage:
|
||||
assert result is True
|
||||
assert memory_file.exists()
|
||||
|
||||
def test_save_does_not_mutate_caller_dict(self, tmp_path):
|
||||
"""save() must not mutate the caller's dict (lastUpdated side-effect)."""
|
||||
memory_file = tmp_path / "memory.json"
|
||||
|
||||
def mock_get_paths():
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.memory_file = memory_file
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
storage = FileMemoryStorage()
|
||||
original = {"version": "1.0", "facts": []}
|
||||
before_keys = set(original.keys())
|
||||
storage.save(original)
|
||||
assert set(original.keys()) == before_keys, "save() must not add keys to caller's dict"
|
||||
assert "lastUpdated" not in original
|
||||
|
||||
def test_cache_not_corrupted_when_save_fails(self, tmp_path):
|
||||
"""Cache must remain clean when save() raises OSError.
|
||||
|
||||
If save() fails, the cache must NOT be updated with the new data.
|
||||
Together with the deepcopy in updater._finalize_update(), this prevents
|
||||
stale mutations from leaking into the cache when persistence fails.
|
||||
"""
|
||||
memory_file = tmp_path / "memory.json"
|
||||
memory_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
original_data = {"version": "1.0", "facts": [{"content": "original"}]}
|
||||
import json as _json
|
||||
|
||||
memory_file.write_text(_json.dumps(original_data))
|
||||
|
||||
def mock_get_paths():
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.memory_file = memory_file
|
||||
return mock_paths
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
storage = FileMemoryStorage()
|
||||
# Warm the cache
|
||||
cached = storage.load()
|
||||
assert cached["facts"][0]["content"] == "original"
|
||||
|
||||
# Simulate save failure: mkdir succeeds but open() raises
|
||||
modified = {"version": "1.0", "facts": [{"content": "mutated"}]}
|
||||
with patch("builtins.open", side_effect=OSError("disk full")):
|
||||
result = storage.save(modified)
|
||||
assert result is False
|
||||
|
||||
# Cache must still reflect the original data, not the failed write
|
||||
after = storage.load()
|
||||
assert after["facts"][0]["content"] == "original"
|
||||
|
||||
def test_cache_thread_safety(self, tmp_path):
|
||||
"""Concurrent load/reload calls must not race on _memory_cache."""
|
||||
memory_file = tmp_path / "memory.json"
|
||||
memory_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
import json as _json
|
||||
|
||||
memory_file.write_text(_json.dumps({"version": "1.0", "facts": []}))
|
||||
|
||||
def mock_get_paths():
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.memory_file = memory_file
|
||||
return mock_paths
|
||||
|
||||
errors: list[Exception] = []
|
||||
|
||||
def load_many(storage: FileMemoryStorage) -> None:
|
||||
try:
|
||||
for _ in range(50):
|
||||
storage.load()
|
||||
except Exception as exc:
|
||||
errors.append(exc)
|
||||
|
||||
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
storage = FileMemoryStorage()
|
||||
threads = [threading.Thread(target=load_many, args=(storage,)) for _ in range(8)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert not errors, f"Thread-safety errors: {errors}"
|
||||
|
||||
def test_reload_forces_cache_invalidation(self, tmp_path):
|
||||
"""Should force reload from file and invalidate cache."""
|
||||
memory_file = tmp_path / "memory.json"
|
||||
|
||||
@ -881,3 +881,53 @@ class TestReinforcementHint:
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
|
||||
class TestFinalizeCacheIsolation:
|
||||
"""_finalize_update must not mutate the cached memory object."""
|
||||
|
||||
def test_deepcopy_prevents_cache_corruption_on_save_failure(self):
|
||||
"""If save() fails, the in-memory snapshot used by _finalize_update
|
||||
must remain independent of any object the storage layer may still hold in
|
||||
its cache. The deepcopy in _finalize_update achieves this — the object
|
||||
passed to _apply_updates is always a fresh copy, never the cache reference.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
original_memory = _make_memory(facts=[{"id": "fact_orig", "content": "original", "category": "context", "confidence": 0.9, "createdAt": "2024-01-01T00:00:00Z", "source": "t1"}])
|
||||
|
||||
import json as _json
|
||||
|
||||
new_fact_json = _json.dumps(
|
||||
{
|
||||
"user": {},
|
||||
"history": {},
|
||||
"newFacts": [{"content": "new fact", "category": "context", "confidence": 0.9}],
|
||||
"factsToRemove": [],
|
||||
}
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = new_fact_json
|
||||
mock_model = AsyncMock()
|
||||
mock_model.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
saved_objects: list[dict] = []
|
||||
save_mock = MagicMock(side_effect=lambda m, a=None: saved_objects.append(m) or False) # always fails
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=mock_model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True, fact_confidence_threshold=0.7)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=original_memory),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=save_mock)),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "hello"
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "world"
|
||||
ai_msg.tool_calls = []
|
||||
updater.update_memory([msg, ai_msg], thread_id="t1")
|
||||
|
||||
# original_memory must not have been mutated — deepcopy isolates the mutation
|
||||
assert len(original_memory["facts"]) == 1, "original_memory must not be mutated by _apply_updates"
|
||||
assert original_memory["facts"][0]["content"] == "original"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user