mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
* feat: flush memory before summarization
* fix: keep agent-scoped memory on summarization flush
* fix: harden summarization hook plumbing
* fix: address summarization review feedback
* style: format memory middleware
187 lines
6.9 KiB
Python
187 lines
6.9 KiB
Python
from __future__ import annotations
|
|
|
|
from types import SimpleNamespace
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
|
|
|
|
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
|
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
|
|
from deerflow.config.memory_config import MemoryConfig
|
|
|
|
|
|
def _messages() -> list:
    """Return a fresh four-message conversation alternating human and AI turns.

    With ``_middleware``'s default trigger/keep settings the first two messages
    are the ones handed to summarization and the last two are preserved.
    """
    turns = (
        (HumanMessage, "user-1"),
        (AIMessage, "assistant-1"),
        (HumanMessage, "user-2"),
        (AIMessage, "assistant-2"),
    )
    return [message_cls(content=text) for message_cls, text in turns]
|
|
|
|
|
|
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
    """Build a minimal stand-in for the runtime object passed to the middleware.

    Only a ``context`` dict is populated. An argument that is ``None`` is left
    out of the dict entirely (not stored as ``None``), so tests can simulate a
    runtime without a thread id or agent name.
    """
    candidates = (("thread_id", thread_id), ("agent_name", agent_name))
    return SimpleNamespace(context={key: value for key, value in candidates if value is not None})
|
|
|
|
|
|
def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("messages", 2)) -> DeerFlowSummarizationMiddleware:
    """Build a ``DeerFlowSummarizationMiddleware`` backed by a stubbed chat model.

    Args:
        before_summarization: Hook callables forwarded unchanged to the
            middleware; each is called with a ``SummarizationEvent``.
        trigger: Forwarded as the middleware's ``trigger``. The default
            ``("messages", 4)`` fires on the four-message ``_messages()`` fixture.
        keep: Forwarded as the middleware's ``keep``. The default retains the
            last two messages.

    Returns:
        A middleware whose model answers both sync and async calls with an
        object whose ``text`` is ``"compressed summary"``.
    """
    from unittest.mock import AsyncMock

    summary = SimpleNamespace(text="compressed summary")
    model = MagicMock()
    model.invoke.return_value = summary
    # A plain MagicMock attribute is not awaitable, so an async summarization
    # path (presumably ``await model.ainvoke(...)``) would not see the stubbed
    # summary. Stub it too so sync and async runs match.
    # NOTE(review): confirm which model method abefore_model uses.
    model.ainvoke = AsyncMock(return_value=summary)
    return DeerFlowSummarizationMiddleware(
        model=model,
        trigger=trigger,
        keep=keep,
        # ``len`` as the token counter: presumably applied to the message list,
        # so thresholds effectively count messages.
        token_counter=len,
        before_summarization=before_summarization,
    )
|
|
|
|
|
|
def test_before_summarization_hook_receives_messages_before_compression() -> None:
    """A registered hook sees the split conversation before it is compressed."""
    seen: list[SummarizationEvent] = []
    mw = _middleware(before_summarization=[seen.append])

    update = mw.before_model({"messages": _messages()}, _runtime())

    assert len(seen) == 1
    event = seen[0]
    summarized = [msg.content for msg in event.messages_to_summarize]
    preserved = [msg.content for msg in event.preserved_messages]
    assert summarized == ["user-1", "assistant-1"]
    assert preserved == ["user-2", "assistant-2"]
    assert event.thread_id == "thread-1"
    assert event.agent_name is None

    new_messages = update["messages"]
    assert isinstance(new_messages[0], RemoveMessage)
    assert new_messages[1].content.startswith("Here is a summary")
|
|
|
|
|
|
def test_before_summarization_hook_not_called_when_threshold_not_met() -> None:
    """Below the trigger threshold the middleware returns None and hooks never fire."""
    events: list[SummarizationEvent] = []
    mw = _middleware(trigger=("messages", 10), before_summarization=[events.append])

    outcome = mw.before_model({"messages": _messages()}, _runtime())

    assert not events
    assert outcome is None
|
|
|
|
|
|
def test_before_summarization_hook_exception_does_not_block_compression(caplog: pytest.LogCaptureFixture) -> None:
    """A raising hook is logged at ERROR level, yet compression still happens."""

    # The function name appears in the expected log line, so it must stay ``_broken_hook``.
    def _broken_hook(_: SummarizationEvent) -> None:
        raise RuntimeError("hook failure")

    mw = _middleware(before_summarization=[_broken_hook])
    state = {"messages": _messages()}

    with caplog.at_level("ERROR"):
        outcome = mw.before_model(state, _runtime())

    assert "before_summarization hook _broken_hook failed" in caplog.text
    first_update = outcome["messages"][0]
    assert isinstance(first_update, RemoveMessage)
|
|
|
|
|
|
def test_multiple_before_summarization_hooks_run_in_registration_order() -> None:
    """Hooks fire in exactly the order they were registered."""
    order: list[str] = []
    labels = ("first", "second", "third")

    def _recorder(label: str):
        # Each call gets its own ``label`` binding, so no late-binding surprise.
        def _record(_event: SummarizationEvent) -> None:
            order.append(label)

        return _record

    mw = _middleware(before_summarization=[_recorder(label) for label in labels])
    mw.before_model({"messages": _messages()}, _runtime())

    assert order == list(labels)
|
|
|
|
|
|
@pytest.mark.anyio
async def test_abefore_model_calls_hooks_same_as_sync() -> None:
    """The async entry point fires hooks with the same event as ``before_model``.

    These checks mirror the synchronous hook test, so any divergence between
    the two code paths is caught: a different split, a lost thread id or agent
    name, or a missing removal in the returned state update.
    """
    captured: list[SummarizationEvent] = []
    middleware = _middleware(before_summarization=[captured.append])

    result = await middleware.abefore_model({"messages": _messages()}, _runtime())

    assert len(captured) == 1
    event = captured[0]
    assert [message.content for message in event.messages_to_summarize] == ["user-1", "assistant-1"]
    assert [message.content for message in event.preserved_messages] == ["user-2", "assistant-2"]
    assert event.thread_id == "thread-1"
    assert event.agent_name is None
    # The returned update should start by clearing the old history, as in the sync path.
    assert isinstance(result["messages"][0], RemoveMessage)
|
|
|
|
|
|
def test_memory_flush_hook_skips_when_memory_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """When the memory config is disabled, the flush hook never enqueues anything."""
    fake_queue = MagicMock()
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=False))
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: fake_queue)

    event = SummarizationEvent(
        messages_to_summarize=tuple(_messages()[:2]),
        preserved_messages=(),
        thread_id="thread-1",
        agent_name=None,
        runtime=_runtime(),
    )
    memory_flush_hook(event)

    fake_queue.add_nowait.assert_not_called()
|
|
|
|
|
|
def test_memory_flush_hook_skips_when_thread_id_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """An event without a thread id is not enqueued, even with memory enabled."""
    fake_queue = MagicMock()
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: fake_queue)

    threadless_event = SummarizationEvent(
        messages_to_summarize=tuple(_messages()[:2]),
        preserved_messages=(),
        thread_id=None,
        agent_name=None,
        runtime=_runtime(None),
    )
    memory_flush_hook(threadless_event)

    fake_queue.add_nowait.assert_not_called()
|
|
|
|
|
|
def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: pytest.MonkeyPatch) -> None:
    """With memory enabled, the summarized messages are enqueued once.

    The AI message carrying tool calls is filtered out, and both detection flags
    are passed as ``False``.
    """
    fake_queue = MagicMock()
    conversation = (
        HumanMessage(content="Question"),
        AIMessage(content="Calling tool", tool_calls=[{"name": "search", "id": "tool-1", "args": {}}]),
        AIMessage(content="Final answer"),
    )
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: fake_queue)

    memory_flush_hook(
        SummarizationEvent(
            messages_to_summarize=conversation,
            preserved_messages=(),
            thread_id="thread-1",
            agent_name=None,
            runtime=_runtime(),
        )
    )

    fake_queue.add_nowait.assert_called_once()
    enqueued = fake_queue.add_nowait.call_args.kwargs
    enqueued_texts = [msg.content for msg in enqueued["messages"]]
    assert enqueued["thread_id"] == "thread-1"
    assert enqueued_texts == ["Question", "Final answer"]
    # Identity checks on purpose: the flags must be real booleans, not merely falsy.
    assert enqueued["correction_detected"] is False
    assert enqueued["reinforcement_detected"] is False
|
|
|
|
|
|
def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None:
    """The event's agent name is forwarded to the memory queue as ``agent_name``."""
    fake_queue = MagicMock()
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: fake_queue)

    scoped_event = SummarizationEvent(
        messages_to_summarize=tuple(_messages()[:2]),
        preserved_messages=(),
        thread_id="thread-1",
        agent_name="research-agent",
        runtime=_runtime(agent_name="research-agent"),
    )
    memory_flush_hook(scoped_event)

    fake_queue.add_nowait.assert_called_once()
    forwarded = fake_queue.add_nowait.call_args.kwargs
    assert forwarded["agent_name"] == "research-agent"
|