DanielWalnut 898f4e8ac2
fix: Memory update system has cache corruption, data loss, and thread-safety bugs (#2251)
* fix(memory): cache corruption, thread-safety, and caller mutation bugs

Bug 1 (updater.py): deep-copy current_memory before passing to
_apply_updates() so a subsequent save() failure cannot leave a
partially-mutated object in the storage cache.

Bug 3 (storage.py): add _cache_lock (threading.Lock) to
FileMemoryStorage and acquire it around every read/write of
_memory_cache, fixing concurrent-access races between the background
timer thread and HTTP reload calls.

Bug 4 (storage.py): replace in-place mutation
  memory_data["lastUpdated"] = ...
with a shallow copy
  memory_data = {**memory_data, "lastUpdated": ...}
so save() no longer silently modifies the caller's dict.

Regression tests added for all three bugs in test_memory_storage.py
and test_memory_updater.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: format test_memory_updater.py with ruff

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: remove stale bug-number labels from code comments and docstrings

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-17 12:00:31 +08:00

217 lines
7.8 KiB
Python

"""Memory storage providers."""
import abc
import json
import logging
import threading
import uuid
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.memory_config import get_memory_config
from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
def utc_now_iso_z() -> str:
"""Current UTC time as ISO-8601 with ``Z`` suffix (matches prior naive-UTC output)."""
return datetime.now(UTC).isoformat().removesuffix("+00:00") + "Z"
def create_empty_memory() -> dict[str, Any]:
    """Build a fresh, empty memory document.

    Returns:
        A new dict with a ``version``, a ``lastUpdated`` timestamp, the
        ``user`` and ``history`` groups (each section holding an empty
        ``summary`` and ``updatedAt``), and an empty ``facts`` list. Every
        nested container is a distinct object, so callers may mutate the
        result freely.
    """

    def blank_section() -> dict[str, str]:
        # A new dict on every call so that no two sections share state.
        return {"summary": "", "updatedAt": ""}

    user_sections = ("workContext", "personalContext", "topOfMind")
    history_sections = ("recentMonths", "earlierContext", "longTermBackground")

    return {
        "version": "1.0",
        "lastUpdated": utc_now_iso_z(),
        "user": {section: blank_section() for section in user_sections},
        "history": {section: blank_section() for section in history_sections},
        "facts": [],
    }
class MemoryStorage(abc.ABC):
"""Abstract base class for memory storage providers."""
@abc.abstractmethod
def load(self, agent_name: str | None = None) -> dict[str, Any]:
"""Load memory data for the given agent."""
pass
@abc.abstractmethod
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
"""Force reload memory data for the given agent."""
pass
@abc.abstractmethod
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
"""Save memory data for the given agent."""
pass
class FileMemoryStorage(MemoryStorage):
    """File-based memory storage provider.

    Memory is stored as JSON in the file resolved by
    :meth:`_get_memory_file_path`. Parsed documents are kept in an in-memory
    cache keyed by agent name and are revalidated against the file's
    modification time on every :meth:`load`.
    """

    def __init__(self):
        """Initialize the file memory storage with an empty cache."""
        # Per-agent memory cache: keyed by agent_name (None = global)
        # Value: (memory_data, file_mtime)
        self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
        # Guards all reads and writes to _memory_cache across concurrent callers.
        self._cache_lock = threading.Lock()

    def _validate_agent_name(self, agent_name: str) -> None:
        """Validate that the agent name is safe to use in filesystem paths.

        Uses the repository's established AGENT_NAME_PATTERN to ensure consistency
        across the codebase and prevent path traversal or other problematic characters.

        Raises:
            ValueError: If ``agent_name`` is empty or does not match
                ``AGENT_NAME_PATTERN``.
        """
        if not agent_name:
            raise ValueError("Agent name must be a non-empty string.")
        if not AGENT_NAME_PATTERN.match(agent_name):
            raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")

    def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
        """Get the path to the memory file.

        A non-``None`` ``agent_name`` is validated and mapped to that agent's
        memory file. Otherwise the configured ``storage_path`` is used
        (relative paths are resolved against the base directory), falling
        back to the default memory file when no path is configured.

        Raises:
            ValueError: If ``agent_name`` is given but invalid.
        """
        if agent_name is not None:
            self._validate_agent_name(agent_name)
            return get_paths().agent_memory_file(agent_name)
        config = get_memory_config()
        if config.storage_path:
            p = Path(config.storage_path)
            return p if p.is_absolute() else get_paths().base_dir / p
        return get_paths().memory_file

    @staticmethod
    def _get_file_mtime(file_path: Path) -> float | None:
        """Return the modification time of ``file_path``, or ``None`` if it cannot be stat'ed."""
        try:
            return file_path.stat().st_mtime
        except OSError:
            # Covers a missing file as well as any other stat failure.
            return None

    def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
        """Load memory data from file.

        Returns:
            The parsed JSON object, or a fresh empty memory structure when the
            file is missing, unreadable, not valid UTF-8/JSON, or does not
            contain a JSON object at the top level.
        """
        file_path = self._get_memory_file_path(agent_name)
        if not file_path.exists():
            return create_empty_memory()
        try:
            with open(file_path, encoding="utf-8") as f:
                data = json.load(f)
        except (ValueError, OSError) as e:
            # ValueError covers both json.JSONDecodeError and the
            # UnicodeDecodeError raised for a file that is not valid UTF-8.
            logger.warning("Failed to load memory file: %s", e)
            return create_empty_memory()
        if not isinstance(data, dict):
            # Valid JSON such as ``null`` or ``[]`` is not a usable memory
            # document; callers rely on getting a dict back.
            logger.warning(
                "Failed to load memory file: expected a JSON object at the top level, got %s",
                type(data).__name__,
            )
            return create_empty_memory()
        return data

    def load(self, agent_name: str | None = None) -> dict[str, Any]:
        """Load memory data (cached with file modification time check).

        The cached object itself is returned when the file's modification
        time matches the cached one; otherwise the file is re-read and the
        cache entry replaced.
        """
        file_path = self._get_memory_file_path(agent_name)
        # Stat before reading: if the file changes while it is being read, the
        # cache entry carries the older mtime and is refreshed on the next call.
        current_mtime = self._get_file_mtime(file_path)
        with self._cache_lock:
            cached = self._memory_cache.get(agent_name)
            if cached is not None and cached[1] == current_mtime:
                return cached[0]
        # File I/O happens outside the lock so slow reads do not block other callers.
        memory_data = self._load_memory_from_file(agent_name)
        with self._cache_lock:
            self._memory_cache[agent_name] = (memory_data, current_mtime)
        return memory_data

    def reload(self, agent_name: str | None = None) -> dict[str, Any]:
        """Reload memory data from file, forcing cache invalidation."""
        file_path = self._get_memory_file_path(agent_name)
        # Capture the mtime BEFORE reading, as load() does. Stat'ing after the
        # read would pair old data with a newer mtime if a write landed in
        # between, and load() would then keep serving the stale entry.
        mtime = self._get_file_mtime(file_path)
        memory_data = self._load_memory_from_file(agent_name)
        with self._cache_lock:
            self._memory_cache[agent_name] = (memory_data, mtime)
        return memory_data

    def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
        """Save memory data to file and update cache.

        The data is written to a uniquely named temporary file next to the
        target and then moved over it, so readers never observe a partially
        written document. The caller's dict is not modified.

        Returns:
            ``True`` on success, ``False`` if an ``OSError`` occurred. Other
            exceptions (for example ``TypeError`` for values that cannot be
            serialized to JSON) propagate to the caller.
        """
        file_path = self._get_memory_file_path(agent_name)
        try:
            file_path.parent.mkdir(parents=True, exist_ok=True)
            # Shallow-copy before adding lastUpdated so the caller's dict is not
            # mutated as a side-effect, and the cache reference is not silently
            # updated before the file write succeeds.
            memory_data = {**memory_data, "lastUpdated": utc_now_iso_z()}
            temp_path = file_path.with_suffix(f".{uuid.uuid4().hex}.tmp")
            try:
                with open(temp_path, "w", encoding="utf-8") as f:
                    json.dump(memory_data, f, indent=2, ensure_ascii=False)
                temp_path.replace(file_path)
            except BaseException:
                # Do not leave a partial temp file behind when the write or
                # the rename fails. Cleanup is best-effort; the original
                # error is the one that matters and is re-raised unchanged.
                try:
                    temp_path.unlink(missing_ok=True)
                except OSError:
                    pass
                raise
            mtime = self._get_file_mtime(file_path)
            with self._cache_lock:
                self._memory_cache[agent_name] = (memory_data, mtime)
            logger.info("Memory saved to %s", file_path)
            return True
        except OSError as e:
            logger.error("Failed to save memory file: %s", e)
            return False
# Module-level storage singleton; stays None until get_memory_storage() creates it.
_storage_instance: MemoryStorage | None = None
# Held by get_memory_storage() while it creates the singleton.
_storage_lock = threading.Lock()
def get_memory_storage() -> MemoryStorage:
    """Return the shared memory storage instance, creating it on first use.

    The class is looked up from the ``storage_class`` dotted path in the
    memory configuration. If that path cannot be imported, does not name a
    class, names a class that is not a ``MemoryStorage`` subclass, or the
    class fails to instantiate, the error is logged and a
    ``FileMemoryStorage`` is used instead.
    """
    global _storage_instance

    if _storage_instance is None:
        with _storage_lock:
            # Re-check under the lock: another thread may have created the
            # instance while this one was waiting.
            if _storage_instance is None:
                import importlib

                dotted_path = get_memory_config().storage_class
                try:
                    module_name, attr_name = dotted_path.rsplit(".", 1)
                    candidate = getattr(importlib.import_module(module_name), attr_name)
                    # Validate that the configured storage is a MemoryStorage implementation
                    if not isinstance(candidate, type):
                        raise TypeError(f"Configured memory storage '{dotted_path}' is not a class: {candidate!r}")
                    if not issubclass(candidate, MemoryStorage):
                        raise TypeError(f"Configured memory storage '{dotted_path}' is not a subclass of MemoryStorage")
                    created = candidate()
                except Exception as exc:
                    logger.error(
                        "Failed to load memory storage %s, falling back to FileMemoryStorage: %s",
                        dotted_path,
                        exc,
                    )
                    created = FileMemoryStorage()
                _storage_instance = created

    return _storage_instance