deer-flow/backend/tests/test_memory_queue.py
rayhpeng 9af2f3e73c feat(memory): capture user_id at enqueue time for async-safe thread isolation
Add user_id field to ConversationContext and MemoryUpdateQueue.add() so the
user identity is stored explicitly at request time, before threading.Timer
fires on a different thread where ContextVar values do not propagate.
MemoryMiddleware.after_agent() now calls get_effective_user_id() at enqueue
time and passes the value through to updater.update_memory().

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-12 14:59:51 +08:00

94 lines
3.1 KiB
Python

from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
def _memory_config(**overrides: object) -> MemoryConfig:
config = MemoryConfig()
for key, value in overrides.items():
setattr(config, key, value)
return config
def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["first"], correction_detected=True)
queue.add(thread_id="thread-1", messages=["second"], correction_detected=False)
assert len(queue._queue) == 1
assert queue._queue[0].messages == ["second"]
assert queue._queue[0].correction_detected is True
def test_process_queue_forwards_correction_flag_to_updater() -> None:
queue = MemoryUpdateQueue()
queue._queue = [
ConversationContext(
thread_id="thread-1",
messages=["conversation"],
agent_name="lead_agent",
correction_detected=True,
)
]
mock_updater = MagicMock()
mock_updater.update_memory.return_value = True
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
queue._process_queue()
mock_updater.update_memory.assert_called_once_with(
messages=["conversation"],
thread_id="thread-1",
agent_name="lead_agent",
correction_detected=True,
reinforcement_detected=False,
user_id=None,
)
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
assert len(queue._queue) == 1
assert queue._queue[0].messages == ["second"]
assert queue._queue[0].reinforcement_detected is True
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
queue = MemoryUpdateQueue()
queue._queue = [
ConversationContext(
thread_id="thread-1",
messages=["conversation"],
agent_name="lead_agent",
reinforcement_detected=True,
)
]
mock_updater = MagicMock()
mock_updater.update_memory.return_value = True
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
queue._process_queue()
mock_updater.update_memory.assert_called_once_with(
messages=["conversation"],
thread_id="thread-1",
agent_name="lead_agent",
correction_detected=False,
reinforcement_detected=True,
user_id=None,
)