mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-15 21:23:41 +00:00
fix(memory): isolate queued memory updates by agent (#2941)
* fix(memory): isolate queued memory updates by agent * fix(memory): include user in queue identity * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * Fix the lint error --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
ba864112a3
commit
722c690f4f
@ -40,6 +40,15 @@ class MemoryUpdateQueue:
|
||||
self._timer: threading.Timer | None = None
|
||||
self._processing = False
|
||||
|
||||
@staticmethod
|
||||
def _queue_key(
|
||||
thread_id: str,
|
||||
user_id: str | None,
|
||||
agent_name: str | None,
|
||||
) -> tuple[str, str | None, str | None]:
|
||||
"""Return the debounce identity for a memory update target."""
|
||||
return (thread_id, user_id, agent_name)
|
||||
|
||||
def add(
|
||||
self,
|
||||
thread_id: str,
|
||||
@ -115,8 +124,9 @@ class MemoryUpdateQueue:
|
||||
correction_detected: bool,
|
||||
reinforcement_detected: bool,
|
||||
) -> None:
|
||||
queue_key = self._queue_key(thread_id, user_id, agent_name)
|
||||
existing_context = next(
|
||||
(context for context in self._queue if context.thread_id == thread_id),
|
||||
(context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key),
|
||||
None,
|
||||
)
|
||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||
@ -130,7 +140,7 @@ class MemoryUpdateQueue:
|
||||
reinforcement_detected=merged_reinforcement_detected,
|
||||
)
|
||||
|
||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||
self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key]
|
||||
self._queue.append(context)
|
||||
|
||||
def _reset_timer(self) -> None:
|
||||
|
||||
@ -6,6 +6,7 @@ from deerflow.agents.memory.message_processing import detect_correction, detect_
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.runtime.user_context import resolve_runtime_user_id
|
||||
|
||||
|
||||
def memory_flush_hook(event: SummarizationEvent) -> None:
|
||||
@ -21,11 +22,13 @@ def memory_flush_hook(event: SummarizationEvent) -> None:
|
||||
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||
user_id = resolve_runtime_user_id(event.runtime)
|
||||
queue = get_memory_queue()
|
||||
queue.add_nowait(
|
||||
thread_id=event.thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=event.agent_name,
|
||||
user_id=user_id,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
@ -164,3 +164,85 @@ def test_flush_nowait_is_non_blocking() -> None:
|
||||
assert elapsed < 0.1
|
||||
assert finished.is_set() is False
|
||||
assert finished.wait(1.0) is True
|
||||
|
||||
|
||||
def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(queue, "_reset_timer"),
|
||||
):
|
||||
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
|
||||
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
|
||||
|
||||
assert queue.pending_count == 2
|
||||
assert [context.agent_name for context in queue._queue] == ["agent-a", "agent-b"]
|
||||
|
||||
|
||||
def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(queue, "_reset_timer"),
|
||||
):
|
||||
queue.add(
|
||||
thread_id="thread-1",
|
||||
messages=["first"],
|
||||
agent_name="agent-a",
|
||||
correction_detected=True,
|
||||
)
|
||||
queue.add(
|
||||
thread_id="thread-1",
|
||||
messages=["second"],
|
||||
agent_name="agent-a",
|
||||
correction_detected=False,
|
||||
)
|
||||
|
||||
assert queue.pending_count == 1
|
||||
assert queue._queue[0].agent_name == "agent-a"
|
||||
assert queue._queue[0].messages == ["second"]
|
||||
assert queue._queue[0].correction_detected is True
|
||||
|
||||
|
||||
def test_process_queue_updates_different_agents_in_same_thread_separately() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(queue, "_reset_timer"),
|
||||
):
|
||||
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
|
||||
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
|
||||
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.update_memory.return_value = True
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater),
|
||||
patch("deerflow.agents.memory.queue.time.sleep"),
|
||||
):
|
||||
queue.flush()
|
||||
|
||||
assert mock_updater.update_memory.call_count == 2
|
||||
mock_updater.update_memory.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
messages=["agent-a"],
|
||||
thread_id="thread-1",
|
||||
agent_name="agent-a",
|
||||
correction_detected=False,
|
||||
reinforcement_detected=False,
|
||||
user_id=None,
|
||||
),
|
||||
call(
|
||||
messages=["agent-b"],
|
||||
thread_id="thread-1",
|
||||
agent_name="agent-b",
|
||||
correction_detected=False,
|
||||
reinforcement_detected=False,
|
||||
user_id=None,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@ -38,3 +38,42 @@ def test_queue_process_passes_user_id_to_updater():
|
||||
mock_updater.update_memory.assert_called_once()
|
||||
call_kwargs = mock_updater.update_memory.call_args.kwargs
|
||||
assert call_kwargs["user_id"] == "alice"
|
||||
|
||||
|
||||
def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent():
|
||||
q = MemoryUpdateQueue()
|
||||
|
||||
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
|
||||
q.add(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
|
||||
|
||||
assert q.pending_count == 2
|
||||
assert [context.user_id for context in q._queue] == ["alice", "bob"]
|
||||
assert [context.messages for context in q._queue] == [["alice update"], ["bob update"]]
|
||||
|
||||
|
||||
def test_queue_still_coalesces_updates_for_same_user_thread_and_agent():
|
||||
q = MemoryUpdateQueue()
|
||||
|
||||
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="main", messages=["first"], agent_name="researcher", user_id="alice")
|
||||
q.add(thread_id="main", messages=["second"], agent_name="researcher", user_id="alice")
|
||||
|
||||
assert q.pending_count == 1
|
||||
assert q._queue[0].messages == ["second"]
|
||||
assert q._queue[0].user_id == "alice"
|
||||
assert q._queue[0].agent_name == "researcher"
|
||||
|
||||
|
||||
def test_add_nowait_keeps_different_users_separate():
|
||||
q = MemoryUpdateQueue()
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)),
|
||||
patch.object(q, "_schedule_timer"),
|
||||
):
|
||||
q.add_nowait(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
|
||||
q.add_nowait(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
|
||||
|
||||
assert q.pending_count == 2
|
||||
assert [context.user_id for context in q._queue] == ["alice", "bob"]
|
||||
|
||||
@ -30,12 +30,18 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
|
||||
)
|
||||
|
||||
|
||||
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
|
||||
def _runtime(
|
||||
thread_id: str | None = "thread-1",
|
||||
agent_name: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> SimpleNamespace:
|
||||
context = {}
|
||||
if thread_id is not None:
|
||||
context["thread_id"] = thread_id
|
||||
if agent_name is not None:
|
||||
context["agent_name"] = agent_name
|
||||
if user_id is not None:
|
||||
context["user_id"] = user_id
|
||||
return SimpleNamespace(context=context)
|
||||
|
||||
|
||||
@ -634,3 +640,22 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon
|
||||
|
||||
queue.add_nowait.assert_called_once()
|
||||
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
|
||||
|
||||
|
||||
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
queue = MagicMock()
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
|
||||
|
||||
memory_flush_hook(
|
||||
SummarizationEvent(
|
||||
messages_to_summarize=tuple(_messages()[:2]),
|
||||
preserved_messages=(),
|
||||
thread_id="main",
|
||||
agent_name="researcher",
|
||||
runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"),
|
||||
)
|
||||
)
|
||||
|
||||
queue.add_nowait.assert_called_once()
|
||||
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user