deer-flow/backend/tests/test_summarization_middleware.py
Ryker_Feng d133b1119a
fix(summarization): tag summary LLM calls nostream to stop phantom stream messages (#2503) (#3378)
* fix(summarization): tag summary LLM calls nostream to stop phantom stream messages (#2503)

The SummarizationMiddleware runs its summary LLM call inside a before_model
hook. Without a nostream tag the summary tokens were captured by LangGraph's
messages-tuple stream callback and broadcast to the frontend as a phantom AI
message.

Generate a dedicated summary model copy tagged with "nostream" (merged on top
of any existing tags such as "middleware:summarize" so RunJournal attribution
is preserved) and override _create_summary / _acreate_summary to invoke it
directly. This avoids temporarily swapping the shared self.model, which would
otherwise leak the RunnableBinding across concurrent runs and break parent
logic that inspects the raw model (profile / _get_ls_params).

Add regression tests covering nostream tagging, concurrent-run isolation, raw
model preservation, and existing-tag merge.

* fix(summarization): address nostream review feedback
2026-06-07 17:55:04 +08:00

823 lines
32 KiB
Python

from __future__ import annotations
from types import SimpleNamespace
from unittest import mock
from unittest.mock import MagicMock
import pytest
from langchain.agents import create_agent
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.constants import TAG_NOSTREAM
from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
from deerflow.config.memory_config import MemoryConfig
def _messages() -> list:
return [
HumanMessage(content="user-1"),
AIMessage(content="assistant-1"),
HumanMessage(content="user-2"),
AIMessage(content="assistant-2"),
]
class _StaticChatModel(BaseChatModel):
text: str = "ok"
@property
def _llm_type(self) -> str:
return "static-test-chat-model"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=self.text))])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
return HumanMessage(
content="<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
id=msg_id,
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
)
def _runtime(
thread_id: str | None = "thread-1",
agent_name: str | None = None,
user_id: str | None = None,
) -> SimpleNamespace:
context = {}
if thread_id is not None:
context["thread_id"] = thread_id
if agent_name is not None:
context["agent_name"] = agent_name
if user_id is not None:
context["user_id"] = user_id
return SimpleNamespace(context=context)
def _middleware(
*,
before_summarization=None,
trigger=("messages", 4),
keep=("messages", 2),
skill_file_read_tool_names=None,
preserve_recent_skill_count: int = 0,
preserve_recent_skill_tokens: int = 0,
preserve_recent_skill_tokens_per_skill: int = 0,
) -> DeerFlowSummarizationMiddleware:
model = MagicMock()
model.invoke.return_value = SimpleNamespace(text="compressed summary")
return DeerFlowSummarizationMiddleware(
model=model,
trigger=trigger,
keep=keep,
token_counter=len,
before_summarization=before_summarization,
skill_file_read_tool_names=skill_file_read_tool_names,
preserve_recent_skill_count=preserve_recent_skill_count,
preserve_recent_skill_tokens=preserve_recent_skill_tokens,
preserve_recent_skill_tokens_per_skill=preserve_recent_skill_tokens_per_skill,
)
def _skill_read_call(tool_id: str, skill: str) -> dict:
return {
"name": "read_file",
"id": tool_id,
"args": {"path": f"/mnt/skills/public/{skill}/SKILL.md"},
}
def _skill_conversation() -> list:
return [
HumanMessage(content="u1"),
AIMessage(content="", tool_calls=[_skill_read_call("t1", "alpha")]),
ToolMessage(content="alpha skill body", tool_call_id="t1"),
HumanMessage(content="u2"),
AIMessage(content="", tool_calls=[_skill_read_call("t2", "beta")]),
ToolMessage(content="beta skill body", tool_call_id="t2"),
HumanMessage(content="u3"),
AIMessage(content="final"),
]
def _raw_tool_call(tool_id: str, name: str = "read_file") -> dict:
return {
"id": tool_id,
"type": "function",
"function": {"name": name, "arguments": "{}"},
}
def test_before_summarization_hook_receives_messages_before_compression() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])
result = middleware.before_model({"messages": _messages()}, _runtime())
assert len(captured) == 1
assert [message.content for message in captured[0].messages_to_summarize] == ["user-1", "assistant-1"]
assert [message.content for message in captured[0].preserved_messages] == ["user-2", "assistant-2"]
assert captured[0].thread_id == "thread-1"
assert captured[0].agent_name is None
assert isinstance(result["messages"][0], RemoveMessage)
assert result["messages"][1].content.startswith("Here is a summary")
def test_summarization_middleware_emits_frontend_update_key_in_agent_stream() -> None:
middleware = DeerFlowSummarizationMiddleware(
model=_StaticChatModel(text="compressed summary"),
trigger=("messages", 4),
keep=("messages", 2),
token_counter=len,
)
agent = create_agent(
model=_StaticChatModel(text="done"),
tools=[],
middleware=[middleware],
)
chunks = list(agent.stream({"messages": _messages()}, stream_mode="updates"))
update = next(
(chunk["DeerFlowSummarizationMiddleware.before_model"] for chunk in chunks if "DeerFlowSummarizationMiddleware.before_model" in chunk),
None,
)
assert update is not None
emitted = update["messages"]
assert isinstance(emitted[0], RemoveMessage)
assert emitted[1].name == "summary"
assert emitted[1].content == ("Here is a summary of the conversation to date:\n\ncompressed summary")
def test_summary_model_is_tagged_nostream_to_avoid_stream_pollution() -> None:
tags_during_summary: list[list[str]] = []
class _RecordingChatModel(_StaticChatModel):
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
tags_during_summary.append(list(run_manager.tags) if run_manager else [])
return super()._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
model = _RecordingChatModel(text="compressed summary")
middleware = DeerFlowSummarizationMiddleware(
model=model,
trigger=("messages", 4),
keep=("messages", 2),
token_counter=len,
)
# The dedicated summary model must carry TAG_NOSTREAM so LangGraph's
# messages-tuple stream handler skips its tokens, while the raw model used by
# the parent for profile / token inspection stays untagged.
assert TAG_NOSTREAM in (middleware._summary_model.config.get("tags") or [])
assert TAG_NOSTREAM not in (getattr(middleware.model, "config", {}).get("tags") or [])
result = middleware.before_model({"messages": _messages()}, _runtime())
# The summary LLM call must actually run with the nostream tag (this is what the
# stream handler inspects), and the shared self.model must remain the raw,
# untagged model so parent logic (profile / _get_ls_params) keeps working.
assert tags_during_summary == [[TAG_NOSTREAM]]
assert middleware.model is model
assert result["messages"][1].content.startswith("Here is a summary")
def test_summarization_does_not_mutate_shared_model_across_concurrent_runs() -> None:
"""Concurrent runs must not observe a swapped-out self.model during summarization.
The agent/middleware instance is cached and reused, so summarization must never
temporarily replace the shared self.model: doing so would leak the nostream
RunnableBinding to other coroutines mid-flight and break parent logic that
inspects the raw model (profile / _get_ls_params).
"""
import asyncio
observed_models: list[object] = []
started = asyncio.Event()
release = asyncio.Event()
class _BlockingChatModel(_StaticChatModel):
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
# Hold the summary call open so a concurrent run can inspect self.model.
started.set()
await release.wait()
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
model = _BlockingChatModel(text="compressed summary")
middleware = DeerFlowSummarizationMiddleware(
model=model,
trigger=("messages", 4),
keep=("messages", 2),
token_counter=len,
)
async def _run() -> None:
summarizing = asyncio.create_task(middleware.abefore_model({"messages": _messages()}, _runtime()))
# Wait until the summary task reaches the blocked LLM call.
await started.wait()
# A concurrent run reads the shared model while summarization is in flight.
observed_models.append(middleware.model)
release.set()
await summarizing
asyncio.run(_run())
assert observed_models == [model]
def test_raw_model_is_preserved_for_parent_profile_inspection() -> None:
"""self.model must stay the original model so attribute access does not drift."""
model = _StaticChatModel(text="compressed summary")
middleware = DeerFlowSummarizationMiddleware(
model=model,
trigger=("messages", 4),
keep=("messages", 2),
token_counter=len,
)
middleware.before_model({"messages": _messages()}, _runtime())
# The shared field is never reassigned to the RunnableBinding.
assert middleware.model is model
assert middleware._summary_model is not model
def test_summary_model_preserves_existing_tags_when_adding_nostream() -> None:
"""Adding TAG_NOSTREAM must not clobber tags already bound on the model.
lead_agent/agent.py binds "middleware:summarize" for RunJournal attribution. Because
RunnableBinding.with_config shallow-merges config, the summary model must explicitly
preserve existing tags instead of overwriting them with just [TAG_NOSTREAM].
"""
tagged_model = _StaticChatModel(text="compressed summary").with_config(tags=["middleware:summarize"])
middleware = DeerFlowSummarizationMiddleware(
model=tagged_model,
trigger=("messages", 4),
keep=("messages", 2),
token_counter=len,
)
summary_tags = middleware._summary_model.config.get("tags") or []
assert "middleware:summarize" in summary_tags
assert TAG_NOSTREAM in summary_tags
# No duplicate TAG_NOSTREAM even if invoked when one was already present.
assert summary_tags.count(TAG_NOSTREAM) == 1
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])
reminder = _dynamic_context_reminder()
result = middleware.before_model(
{
"messages": [
reminder,
HumanMessage(content="user-1"),
AIMessage(content="assistant-1"),
HumanMessage(content="user-2"),
]
},
_runtime(),
)
assert len(captured) == 1
assert [message.content for message in captured[0].messages_to_summarize] == ["user-1"]
assert captured[0].preserved_messages[0] is reminder
emitted = result["messages"]
assert isinstance(emitted[0], RemoveMessage)
assert emitted[1].name == "summary"
assert emitted[2] is reminder
followup_state = {"messages": [*emitted[1:], HumanMessage(content="Follow-up", id="msg-2")]}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
assert DynamicContextMiddleware().before_agent(followup_state, _runtime()) is None
def test_before_summarization_hook_not_called_when_threshold_not_met() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append], trigger=("messages", 10))
result = middleware.before_model({"messages": _messages()}, _runtime())
assert captured == []
assert result is None
def test_before_summarization_hook_exception_does_not_block_compression(caplog: pytest.LogCaptureFixture) -> None:
def _broken_hook(_: SummarizationEvent) -> None:
raise RuntimeError("hook failure")
middleware = _middleware(before_summarization=[_broken_hook])
with caplog.at_level("ERROR"):
result = middleware.before_model({"messages": _messages()}, _runtime())
assert "before_summarization hook _broken_hook failed" in caplog.text
assert isinstance(result["messages"][0], RemoveMessage)
def test_multiple_before_summarization_hooks_run_in_registration_order() -> None:
call_order: list[str] = []
def _hook(name: str):
return lambda _: call_order.append(name)
middleware = _middleware(before_summarization=[_hook("first"), _hook("second"), _hook("third")])
middleware.before_model({"messages": _messages()}, _runtime())
assert call_order == ["first", "second", "third"]
@pytest.mark.anyio
async def test_abefore_model_calls_hooks_same_as_sync() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])
await middleware.abefore_model({"messages": _messages()}, _runtime())
assert len(captured) == 1
assert [message.content for message in captured[0].messages_to_summarize] == ["user-1", "assistant-1"]
def test_memory_flush_hook_skips_when_memory_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=False))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(_messages()[:2]),
preserved_messages=(),
thread_id="thread-1",
agent_name=None,
runtime=_runtime(),
)
)
queue.add_nowait.assert_not_called()
def test_memory_flush_hook_skips_when_thread_id_missing(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(_messages()[:2]),
preserved_messages=(),
thread_id=None,
agent_name=None,
runtime=_runtime(None),
)
)
queue.add_nowait.assert_not_called()
def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
messages = [
HumanMessage(content="Question"),
AIMessage(content="Calling tool", tool_calls=[{"name": "search", "id": "tool-1", "args": {}}]),
AIMessage(content="Final answer"),
]
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(messages),
preserved_messages=(),
thread_id="thread-1",
agent_name=None,
runtime=_runtime(),
)
)
queue.add_nowait.assert_called_once()
add_kwargs = queue.add_nowait.call_args.kwargs
assert add_kwargs["thread_id"] == "thread-1"
assert [message.content for message in add_kwargs["messages"]] == ["Question", "Final answer"]
assert add_kwargs["correction_detected"] is False
assert add_kwargs["reinforcement_detected"] is False
def test_skill_rescue_keeps_recent_skill_reads_out_of_summary() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
result = middleware.before_model({"messages": _skill_conversation()}, _runtime())
assert len(captured) == 1
summarized_ids = {id(m) for m in captured[0].messages_to_summarize}
preserved = captured[0].preserved_messages
# Both skill-read bundles should be rescued into preserved_messages,
# tool_call ↔ tool_result pairs stay intact.
assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
assert any(isinstance(m, ToolMessage) and m.content == "beta skill body" for m in preserved)
for m in preserved:
if isinstance(m, ToolMessage) and m.content in {"alpha skill body", "beta skill body"}:
assert id(m) not in summarized_ids
# Preserved output order: rescued bundles first, then the tail kept by parent cutoff.
contents = [getattr(m, "content", None) for m in preserved]
assert contents[-2:] == ["u3", "final"]
# The final emitted state should start with RemoveMessage + summary, then preserved messages.
emitted = result["messages"]
assert isinstance(emitted[0], RemoveMessage)
assert emitted[1].content.startswith("Here is a summary")
assert list(emitted[-2:]) == list(preserved[-2:])
def test_skill_rescue_respects_count_budget() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=1,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
middleware.before_model({"messages": _skill_conversation()}, _runtime())
preserved = captured[0].preserved_messages
summarized = captured[0].messages_to_summarize
# Newest skill (beta) rescued; older skill (alpha) falls into summary.
assert any(isinstance(m, ToolMessage) and m.content == "beta skill body" for m in preserved)
assert not any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in summarized)
def test_skill_rescue_uses_injected_skills_container_path() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
middleware._skills_container_path = "/custom/skills"
messages = [
HumanMessage(content="u1"),
AIMessage(content="", tool_calls=[{"name": "read_file", "id": "t1", "args": {"path": "/custom/skills/demo/SKILL.md"}}]),
ToolMessage(content="demo skill body", tool_call_id="t1"),
HumanMessage(content="u2"),
AIMessage(content="final"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
assert any(isinstance(m, ToolMessage) and m.content == "demo skill body" for m in preserved)
def test_skill_rescue_uses_configured_skill_read_tool_names() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
skill_file_read_tool_names=["custom_read"],
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
middleware._skills_container_path = "/custom/skills"
messages = [
HumanMessage(content="u1"),
AIMessage(content="", tool_calls=[{"name": "custom_read", "id": "t1", "args": {"path": "/custom/skills/demo/SKILL.md"}}]),
ToolMessage(content="demo skill body", tool_call_id="t1"),
HumanMessage(content="u2"),
AIMessage(content="final"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
assert any(isinstance(m, ToolMessage) and m.content == "demo skill body" for m in preserved)
def test_skill_rescue_respects_per_skill_token_cap() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
# token_counter=len counts one token per message; per-skill cap of 0 rejects every bundle.
preserve_recent_skill_tokens_per_skill=0,
)
middleware.before_model({"messages": _skill_conversation()}, _runtime())
preserved = captured[0].preserved_messages
assert not any(isinstance(m, ToolMessage) and m.content in {"alpha skill body", "beta skill body"} for m in preserved)
def test_skill_rescue_disabled_when_count_zero() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=0,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
middleware.before_model({"messages": _skill_conversation()}, _runtime())
preserved = captured[0].preserved_messages
assert not any(isinstance(m, ToolMessage) for m in preserved)
def test_skill_rescue_ignores_non_skill_tool_reads() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="",
tool_calls=[{"name": "read_file", "id": "t1", "args": {"path": "/mnt/user-data/workspace/notes.md"}}],
),
ToolMessage(content="user notes", tool_call_id="t1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
assert not any(isinstance(m, ToolMessage) and m.content == "user notes" for m in preserved)
def test_skill_rescue_does_not_preserve_non_skill_outputs_from_mixed_tool_calls() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="",
tool_calls=[
_skill_read_call("skill-1", "alpha"),
{"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
],
),
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
ToolMessage(content="user notes", tool_call_id="file-1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
summarized = captured[0].messages_to_summarize
preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"]
assert [tc["id"] for tc in summarized_ai.tool_calls] == ["file-1"]
assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
assert not any(isinstance(m, ToolMessage) and m.content == "user notes" for m in preserved)
assert any(isinstance(m, ToolMessage) and m.content == "user notes" for m in summarized)
def test_skill_rescue_syncs_raw_provider_tool_calls_on_split_ai_messages() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="reading skill and notes",
tool_calls=[
_skill_read_call("skill-1", "alpha"),
{"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
],
additional_kwargs={"tool_calls": [_raw_tool_call("skill-1"), _raw_tool_call("file-1")]},
),
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
ToolMessage(content="user notes", tool_call_id="file-1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
summarized = captured[0].messages_to_summarize
preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"]
assert [tc["id"] for tc in preserved_ai.additional_kwargs["tool_calls"]] == ["skill-1"]
assert [tc["id"] for tc in summarized_ai.tool_calls] == ["file-1"]
assert [tc["id"] for tc in summarized_ai.additional_kwargs["tool_calls"]] == ["file-1"]
def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="reading skill and notes",
tool_calls=[
_skill_read_call("skill-1", "alpha"),
{"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
],
),
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
ToolMessage(content="user notes", tool_call_id="file-1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
summarized = captured[0].messages_to_summarize
preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
assert preserved_ai.content == ""
assert summarized_ai.content == "reading skill and notes"
def test_skill_rescue_removes_raw_provider_tool_calls_from_content_only_summary_clone() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="reading skill",
tool_calls=[_skill_read_call("skill-1", "alpha")],
additional_kwargs={"tool_calls": [_raw_tool_call("skill-1")], "function_call": {"name": "read_file"}},
response_metadata={"finish_reason": "tool_calls"},
),
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
summarized = captured[0].messages_to_summarize
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage))
assert summarized_ai.content == "reading skill"
assert summarized_ai.tool_calls == []
assert "tool_calls" not in summarized_ai.additional_kwargs
assert "function_call" not in summarized_ai.additional_kwargs
assert summarized_ai.response_metadata["finish_reason"] == "stop"
def test_skill_rescue_only_preserves_skill_calls_with_matched_tool_results() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="",
tool_calls=[
_skill_read_call("skill-1", "alpha"),
_skill_read_call("skill-2", "beta"),
],
),
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
summarized = captured[0].messages_to_summarize
preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"]
assert [tc["id"] for tc in summarized_ai.tool_calls] == ["skill-2"]
assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
assert not any(isinstance(m, ToolMessage) and getattr(m, "tool_call_id", None) == "skill-2" for m in preserved)
def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(_messages()[:2]),
preserved_messages=(),
thread_id="thread-1",
agent_name="research-agent",
runtime=_runtime(agent_name="research-agent"),
)
)
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(_messages()[:2]),
preserved_messages=(),
thread_id="main",
agent_name="researcher",
runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"),
)
)
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"