From 9af2f3e73c23a4656744a567b90c70aa33ed42d5 Mon Sep 17 00:00:00 2001 From: rayhpeng Date: Sun, 12 Apr 2026 14:59:51 +0800 Subject: [PATCH] feat(memory): capture user_id at enqueue time for async-safe thread isolation Add user_id field to ConversationContext and MemoryUpdateQueue.add() so the user identity is stored explicitly at request time, before threading.Timer fires on a different thread where ContextVar values do not propagate. MemoryMiddleware.after_agent() now calls get_effective_user_id() at enqueue time and passes the value through to updater.update_memory(). Co-Authored-By: Claude Sonnet 4.6 --- .../harness/deerflow/agents/memory/queue.py | 7 ++++ .../agents/middlewares/memory_middleware.py | 6 +++ backend/tests/test_memory_queue.py | 2 + .../tests/test_memory_queue_user_isolation.py | 38 +++++++++++++++++++ 4 files changed, 53 insertions(+) create mode 100644 backend/tests/test_memory_queue_user_isolation.py diff --git a/backend/packages/harness/deerflow/agents/memory/queue.py b/backend/packages/harness/deerflow/agents/memory/queue.py index 1db8c63dc..6de0bdcfc 100644 --- a/backend/packages/harness/deerflow/agents/memory/queue.py +++ b/backend/packages/harness/deerflow/agents/memory/queue.py @@ -20,6 +20,7 @@ class ConversationContext: messages: list[Any] timestamp: datetime = field(default_factory=lambda: datetime.now(UTC)) agent_name: str | None = None + user_id: str | None = None correction_detected: bool = False reinforcement_detected: bool = False @@ -44,6 +45,7 @@ class MemoryUpdateQueue: thread_id: str, messages: list[Any], agent_name: str | None = None, + user_id: str | None = None, correction_detected: bool = False, reinforcement_detected: bool = False, ) -> None: @@ -53,6 +55,9 @@ class MemoryUpdateQueue: thread_id: The thread ID. messages: The conversation messages. agent_name: If provided, memory is stored per-agent. If None, uses global memory. + user_id: The user ID captured at enqueue time. Stored in ConversationContext so it + survives the threading.Timer boundary (ContextVar does not propagate across + raw threads). correction_detected: Whether recent turns include an explicit correction signal. reinforcement_detected: Whether recent turns include a positive reinforcement signal. """ @@ -71,6 +76,7 @@ class MemoryUpdateQueue: thread_id=thread_id, messages=messages, agent_name=agent_name, + user_id=user_id, correction_detected=merged_correction_detected, reinforcement_detected=merged_reinforcement_detected, ) @@ -136,6 +142,7 @@ class MemoryUpdateQueue: agent_name=context.agent_name, correction_detected=context.correction_detected, reinforcement_detected=context.reinforcement_detected, + user_id=context.user_id, ) if success: logger.info("Memory updated successfully for thread %s", context.thread_id) diff --git a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py index 5e8ca6344..7f239a89e 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py @@ -11,6 +11,7 @@ from langgraph.runtime import Runtime from deerflow.agents.memory.queue import get_memory_queue from deerflow.config.memory_config import get_memory_config +from deerflow.runtime.user_context import get_effective_user_id logger = logging.getLogger(__name__) @@ -236,11 +237,16 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): # Queue the filtered conversation for memory update correction_detected = detect_correction(filtered_messages) reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages) + # Capture user_id at enqueue time while the request context is still alive. + # threading.Timer fires on a different thread where ContextVar values are not + # propagated, so we must store user_id explicitly in ConversationContext. + user_id = get_effective_user_id() queue = get_memory_queue() queue.add( thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name, + user_id=user_id, correction_detected=correction_detected, reinforcement_detected=reinforcement_detected, ) diff --git a/backend/tests/test_memory_queue.py b/backend/tests/test_memory_queue.py index 204f9d16e..454cf2bf2 100644 --- a/backend/tests/test_memory_queue.py +++ b/backend/tests/test_memory_queue.py @@ -48,6 +48,7 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None: agent_name="lead_agent", correction_detected=True, reinforcement_detected=False, + user_id=None, ) @@ -88,4 +89,5 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None: agent_name="lead_agent", correction_detected=False, reinforcement_detected=True, + user_id=None, ) diff --git a/backend/tests/test_memory_queue_user_isolation.py b/backend/tests/test_memory_queue_user_isolation.py new file mode 100644 index 000000000..1a209d659 --- /dev/null +++ b/backend/tests/test_memory_queue_user_isolation.py @@ -0,0 +1,38 @@ +"""Tests for user_id propagation through memory queue.""" +from unittest.mock import MagicMock, patch + +from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue + + +def test_conversation_context_has_user_id(): + ctx = ConversationContext(thread_id="t1", messages=[], user_id="alice") + assert ctx.user_id == "alice" + + +def test_conversation_context_user_id_default_none(): + ctx = ConversationContext(thread_id="t1", messages=[]) + assert ctx.user_id is None + + +def test_queue_add_stores_user_id(): + q = MemoryUpdateQueue() + with patch.object(q, "_reset_timer"): + q.add(thread_id="t1", messages=["msg"], user_id="alice") + assert len(q._queue) == 1 + assert q._queue[0].user_id == "alice" + q.clear() + + +def test_queue_process_passes_user_id_to_updater(): + q = MemoryUpdateQueue() + with patch.object(q, "_reset_timer"): + q.add(thread_id="t1", messages=["msg"], user_id="alice") + + mock_updater = MagicMock() + mock_updater.update_memory.return_value = True + with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater): + q._process_queue() + + mock_updater.update_memory.assert_called_once() + call_kwargs = mock_updater.update_memory.call_args.kwargs + assert call_kwargs["user_id"] == "alice"