fix(middleware): avoid rescuing non-skill tool outputs during summarization (#2458)

* fix(middleware): narrow skill rescue to skill-related tool outputs

* fix(summarization): address skill rescue review feedback

* fix: wire summarization skill rescue config

* fix: remove dead skill tool helper

* fix(lint): fix format

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Nan Gao 2026-04-24 15:19:46 +02:00 committed by GitHub
parent c2332bb790
commit f9ff3a698d
7 changed files with 629 additions and 9 deletions


@@ -41,6 +41,13 @@ summarization:
# Custom summary prompt (optional)
summary_prompt: null
# Tool names treated as skill file reads for skill rescue
skill_file_read_tool_names:
- read_file
- read
- view
- cat
```
### Configuration Options
@@ -125,6 +132,26 @@ keep:
- **Default**: `null` (uses LangChain's default prompt)
- **Description**: Custom prompt template for generating summaries. The prompt should guide the model to extract the most important context.
#### `preserve_recent_skill_count`
- **Type**: Integer (≥ 0)
- **Default**: `5`
- **Description**: Number of most-recently-loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`, e.g. `/mnt/skills/...`) that are rescued from summarization. Prevents the agent from losing skill instructions after compression. Set to `0` to disable skill rescue entirely.
#### `preserve_recent_skill_tokens`
- **Type**: Integer (≥ 0)
- **Default**: `25000`
- **Description**: Total token budget reserved for rescued skill reads. Once this budget is exhausted, older skill bundles are allowed to be summarized.
#### `preserve_recent_skill_tokens_per_skill`
- **Type**: Integer (≥ 0)
- **Default**: `5000`
- **Description**: Per-skill token cap. Any individual skill read whose tool result exceeds this size is not rescued (it falls through to the summarizer like ordinary content).
#### `skill_file_read_tool_names`
- **Type**: List of strings
- **Default**: `["read_file", "read", "view", "cat"]`
- **Description**: Tool names treated as skill file reads during summarization rescue. A tool call is only eligible for skill rescue when its name appears in this list and its target path is under `skills.container_path`.
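The eligibility rule described above (tool name must be configured *and* the path argument must resolve under the skills container path) can be sketched as below. This is an illustrative helper, not the shipped implementation; the function name `is_skill_read` is hypothetical.

```python
# Hypothetical sketch of the skill-read eligibility rule: a tool call counts
# as a skill file read only when its name is in the configured set AND its
# path argument lies under the skills container path.
SKILL_READ_TOOLS = frozenset({"read_file", "read", "view", "cat"})

def is_skill_read(tool_call: dict, skills_root: str = "/mnt/skills") -> bool:
    if (tool_call.get("name") or "") not in SKILL_READ_TOOLS:
        return False
    args = tool_call.get("args") or {}
    path = args.get("path") if isinstance(args, dict) else None
    if not isinstance(path, str) or not path:
        return False
    root = skills_root.rstrip("/")
    # Exact match or a true path-prefix match ("/mnt/skillsx/..." must not pass).
    return path == root or path.startswith(root + "/")
```

Note the `root + "/"` suffix in the prefix check: a plain `startswith(root)` would wrongly accept sibling directories such as `/mnt/skillsx`.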
**Default Prompt Behavior:**
The default LangChain prompt instructs the model to:
- Extract highest quality/most relevant context
@@ -147,6 +174,7 @@ The default LangChain prompt instructs the model to:
- A single summary message is added
- Recent messages are preserved
6. **AI/Tool Pair Protection**: The system ensures AI messages and their corresponding tool messages stay together
7. **Skill Rescue**: Before the summary is generated, the most recently loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`) are lifted out of the summarization set and prepended to the preserved tail. Selection walks newest-first under three budgets: `preserve_recent_skill_count`, `preserve_recent_skill_tokens`, and `preserve_recent_skill_tokens_per_skill`. The triggering AIMessage and all of its paired ToolMessages move together so tool_call ↔ tool_result pairing stays intact.
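The newest-first budget walk in step 7 can be sketched in isolation. This is a simplified illustration (function name and tuple shape are assumptions, not the shipped code): each bundle is an `(skill_key, token_cost)` pair in conversation order, and the three budgets map to `max_count`, `max_total_tokens`, and `max_per_skill`.

```python
# Minimal sketch of the newest-first selection under three budgets.
# Bundles arrive oldest-first; selection walks them newest-first.
def select_bundles(bundles, max_count=5, max_total_tokens=25_000, max_per_skill=5_000):
    selected, seen, total = [], set(), 0
    for key, cost in reversed(bundles):      # newest first
        if len(selected) >= max_count:
            break
        if key in seen:                      # keep only the newest read of each skill
            continue
        if cost > max_per_skill:             # oversized reads fall through to the summarizer
            continue
        if total + cost > max_total_tokens:  # total budget exhausted for this bundle
            continue
        selected.append((key, cost))
        total += cost
        seen.add(key)
    selected.reverse()                       # restore conversation order
    return selected
```

With a count budget of 1, only the newest bundle survives; a bundle over the per-skill cap is skipped rather than truncated, matching the `preserve_recent_skill_tokens_per_skill` description above.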
### Token Counting


@@ -84,7 +84,24 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
    if get_memory_config().enabled:
        hooks.append(memory_flush_hook)
    # The logic below relies on two assumptions holding true: this factory is
    # the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
    # config is not expected to change after startup.
    try:
        skills_container_path = get_app_config().skills.container_path or "/mnt/skills"
    except Exception:
        logger.exception("Failed to resolve skills container path; falling back to default")
        skills_container_path = "/mnt/skills"
    return DeerFlowSummarizationMiddleware(
        **kwargs,
        skills_container_path=skills_container_path,
        skill_file_read_tool_names=config.skill_file_read_tool_names,
        before_summarization=hooks,
        preserve_recent_skill_count=config.preserve_recent_skill_count,
        preserve_recent_skill_tokens=config.preserve_recent_skill_tokens,
        preserve_recent_skill_tokens_per_skill=config.preserve_recent_skill_tokens_per_skill,
    )


def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:


@@ -3,12 +3,13 @@
from __future__ import annotations

import logging
from collections.abc import Collection
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable

from langchain.agents import AgentState
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.messages import AIMessage, AnyMessage, RemoveMessage, ToolMessage
from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
@@ -58,17 +59,63 @@ def _resolve_agent_name(runtime: Runtime) -> str | None:
    return agent_name


def _tool_call_path(tool_call: dict[str, Any]) -> str | None:
    """Best-effort extraction of a file path argument from a read_file-like tool call."""
    args = tool_call.get("args") or {}
    if not isinstance(args, dict):
        return None
    for key in ("path", "file_path", "filepath"):
        value = args.get(key)
        if isinstance(value, str) and value:
            return value
    return None


def _clone_ai_message(
    message: AIMessage,
    tool_calls: list[dict[str, Any]],
    *,
    content: Any | None = None,
) -> AIMessage:
    """Clone an AIMessage while replacing its tool_calls list and optional content."""
    update: dict[str, Any] = {"tool_calls": tool_calls}
    if content is not None:
        update["content"] = content
    return message.model_copy(update=update)


@dataclass
class _SkillBundle:
    """Skill-related tool calls and tool results associated with one AIMessage."""

    ai_index: int
    skill_tool_indices: tuple[int, ...]
    skill_tool_call_ids: frozenset[str]
    skill_tool_tokens: int
    skill_key: str


class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
    """Summarization middleware with pre-compression hook dispatch and skill rescue."""

    def __init__(
        self,
        *args,
        skills_container_path: str | None = None,
        skill_file_read_tool_names: Collection[str] | None = None,
        before_summarization: list[BeforeSummarizationHook] | None = None,
        preserve_recent_skill_count: int = 5,
        preserve_recent_skill_tokens: int = 25_000,
        preserve_recent_skill_tokens_per_skill: int = 5_000,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self._skills_container_path = skills_container_path or "/mnt/skills"
        self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"})
        self._before_summarization_hooks = before_summarization or []
        self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
        self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
        self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)

    def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._maybe_summarize(state, runtime)
@@ -88,7 +135,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
        if cutoff_index <= 0:
            return None

        messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
        self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
        summary = self._create_summary(messages_to_summarize)
        new_messages = self._build_new_messages(summary)
@@ -113,7 +160,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
        if cutoff_index <= 0:
            return None

        messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
        self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
        summary = await self._acreate_summary(messages_to_summarize)
        new_messages = self._build_new_messages(summary)
@@ -126,6 +173,155 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
            ]
        }

    def _partition_with_skill_rescue(
        self,
        messages: list[AnyMessage],
        cutoff_index: int,
    ) -> tuple[list[AnyMessage], list[AnyMessage]]:
        """Partition like the parent, then rescue recently-loaded skill bundles."""
        to_summarize, preserved = self._partition_messages(messages, cutoff_index)
        if self._preserve_recent_skill_count == 0 or self._preserve_recent_skill_tokens == 0 or not to_summarize:
            return to_summarize, preserved
        try:
            bundles = self._find_skill_bundles(to_summarize, self._skills_container_path)
        except Exception:
            logger.exception("Skill-preserving summarization rescue failed; falling back to default partition")
            return to_summarize, preserved
        if not bundles:
            return to_summarize, preserved
        rescue_bundles = self._select_bundles_to_rescue(bundles)
        if not rescue_bundles:
            return to_summarize, preserved

        bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles}
        rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices}
        rescued: list[AnyMessage] = []
        remaining: list[AnyMessage] = []
        for i, msg in enumerate(to_summarize):
            bundle = bundles_by_ai_index.get(i)
            if bundle is not None and isinstance(msg, AIMessage):
                rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids]
                remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids]
                if rescued_tool_calls:
                    rescued.append(_clone_ai_message(msg, rescued_tool_calls, content=""))
                if remaining_tool_calls or msg.content:
                    remaining.append(_clone_ai_message(msg, remaining_tool_calls))
                continue
            if i in rescue_tool_indices:
                rescued.append(msg)
                continue
            remaining.append(msg)
        return remaining, rescued + preserved

    def _find_skill_bundles(
        self,
        messages: list[AnyMessage],
        skills_root: str,
    ) -> list[_SkillBundle]:
        """Locate AIMessage + paired ToolMessage groups that load skill files."""
        bundles: list[_SkillBundle] = []
        n = len(messages)
        i = 0
        while i < n:
            msg = messages[i]
            if not (isinstance(msg, AIMessage) and msg.tool_calls):
                i += 1
                continue
            tool_calls = list(msg.tool_calls)
            skill_paths_by_id: dict[str, str] = {}
            for tc in tool_calls:
                if self._is_skill_tool_call(tc, skills_root):
                    tc_id = tc.get("id")
                    path = _tool_call_path(tc)
                    if tc_id and path:
                        skill_paths_by_id[tc_id] = path
            if not skill_paths_by_id:
                i += 1
                continue

            skill_tool_tokens = 0
            skill_key_parts: list[str] = []
            skill_tool_indices: list[int] = []
            matched_skill_call_ids: set[str] = set()
            j = i + 1
            while j < n and isinstance(messages[j], ToolMessage):
                j += 1
            for k in range(i + 1, j):
                tool_msg = messages[k]
                if isinstance(tool_msg, ToolMessage) and tool_msg.tool_call_id in skill_paths_by_id:
                    skill_tool_tokens += self.token_counter([tool_msg])
                    skill_key_parts.append(skill_paths_by_id[tool_msg.tool_call_id])
                    skill_tool_indices.append(k)
                    matched_skill_call_ids.add(tool_msg.tool_call_id)
            if not skill_tool_indices:
                i = j
                continue
            bundles.append(
                _SkillBundle(
                    ai_index=i,
                    skill_tool_indices=tuple(skill_tool_indices),
                    skill_tool_call_ids=frozenset(matched_skill_call_ids),
                    skill_tool_tokens=skill_tool_tokens,
                    skill_key="|".join(sorted(skill_key_parts)),
                )
            )
            i = j
        return bundles

    def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]:
        """Pick bundles to keep, walking newest-first under count/token budgets."""
        selected: list[_SkillBundle] = []
        if not bundles:
            return selected
        seen_skill_keys: set[str] = set()
        total_tokens = 0
        kept = 0
        for bundle in reversed(bundles):
            if kept >= self._preserve_recent_skill_count:
                break
            if bundle.skill_key in seen_skill_keys:
                continue
            if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill:
                continue
            if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens:
                continue
            selected.append(bundle)
            total_tokens += bundle.skill_tool_tokens
            kept += 1
            seen_skill_keys.add(bundle.skill_key)
        selected.reverse()
        return selected

    def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool:
        """Return True when ``tool_call`` reads a file under the configured skills root."""
        name = tool_call.get("name") or ""
        if name not in self._skill_file_read_tool_names:
            return False
        path = _tool_call_path(tool_call)
        if not path:
            return False
        normalized_root = skills_root.rstrip("/")
        return path == normalized_root or path.startswith(normalized_root + "/")

    def _fire_hooks(
        self,
        messages_to_summarize: list[AnyMessage],


@@ -51,6 +51,25 @@ class SummarizationConfig(BaseModel):
        default=None,
        description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.",
    )
    preserve_recent_skill_count: int = Field(
        default=5,
        ge=0,
        description="Number of most-recently-loaded skill files to exclude from summarization. Set to 0 to disable skill preservation.",
    )
    preserve_recent_skill_tokens: int = Field(
        default=25000,
        ge=0,
        description="Total token budget reserved for recently-loaded skill files that must be preserved across summarization.",
    )
    preserve_recent_skill_tokens_per_skill: int = Field(
        default=5000,
        ge=0,
        description="Per-skill token cap when preserving skill files across summarization. Skill reads above this size are not rescued.",
    )
    skill_file_read_tool_names: list[str] = Field(
        default_factory=lambda: ["read_file", "read", "view", "cat"],
        description="Tool names treated as skill file reads when preserving recently-loaded skills across summarization.",
    )


# Global configuration instance


@@ -207,3 +207,27 @@ def test_create_summarization_middleware_registers_memory_flush_hook_when_memory
    lead_agent_module._create_summarization_middleware()

    assert captured["before_summarization"] == [lead_agent_module.memory_flush_hook]


def test_create_summarization_middleware_passes_skill_read_tool_names(monkeypatch):
    app_config = _make_app_config([_make_model("default-model", supports_thinking=False)])
    monkeypatch.setattr(
        lead_agent_module,
        "get_summarization_config",
        lambda: SummarizationConfig(enabled=True, skill_file_read_tool_names=["read_file", "cat"]),
    )
    monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
    monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
    monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object())

    captured: dict[str, object] = {}

    def _fake_middleware(**kwargs):
        captured.update(kwargs)
        return kwargs

    monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _fake_middleware)

    lead_agent_module._create_summarization_middleware()

    assert captured["skill_file_read_tool_names"] == ["read_file", "cat"]


@@ -4,7 +4,7 @@
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage

from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
@@ -29,7 +29,16 @@ def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None)
    return SimpleNamespace(context=context)


def _middleware(
    *,
    before_summarization=None,
    trigger=("messages", 4),
    keep=("messages", 2),
    skill_file_read_tool_names=None,
    preserve_recent_skill_count: int = 0,
    preserve_recent_skill_tokens: int = 0,
    preserve_recent_skill_tokens_per_skill: int = 0,
) -> DeerFlowSummarizationMiddleware:
    model = MagicMock()
    model.invoke.return_value = SimpleNamespace(text="compressed summary")
    return DeerFlowSummarizationMiddleware(
@@ -38,9 +47,34 @@ def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("me
        keep=keep,
        token_counter=len,
        before_summarization=before_summarization,
        skill_file_read_tool_names=skill_file_read_tool_names,
        preserve_recent_skill_count=preserve_recent_skill_count,
        preserve_recent_skill_tokens=preserve_recent_skill_tokens,
        preserve_recent_skill_tokens_per_skill=preserve_recent_skill_tokens_per_skill,
    )

def _skill_read_call(tool_id: str, skill: str) -> dict:
    return {
        "name": "read_file",
        "id": tool_id,
        "args": {"path": f"/mnt/skills/public/{skill}/SKILL.md"},
    }


def _skill_conversation() -> list:
    return [
        HumanMessage(content="u1"),
        AIMessage(content="", tool_calls=[_skill_read_call("t1", "alpha")]),
        ToolMessage(content="alpha skill body", tool_call_id="t1"),
        HumanMessage(content="u2"),
        AIMessage(content="", tool_calls=[_skill_read_call("t2", "beta")]),
        ToolMessage(content="beta skill body", tool_call_id="t2"),
        HumanMessage(content="u3"),
        AIMessage(content="final"),
    ]
def test_before_summarization_hook_receives_messages_before_compression() -> None:
    captured: list[SummarizationEvent] = []
    middleware = _middleware(before_summarization=[captured.append])
@@ -167,6 +201,295 @@ def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: p
    assert add_kwargs["reinforcement_detected"] is False

def test_skill_rescue_keeps_recent_skill_reads_out_of_summary() -> None:
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )

    result = middleware.before_model({"messages": _skill_conversation()}, _runtime())

    assert len(captured) == 1
    summarized_ids = {id(m) for m in captured[0].messages_to_summarize}
    preserved = captured[0].preserved_messages
    # Both skill-read bundles should be rescued into preserved_messages;
    # tool_call ↔ tool_result pairs stay intact.
    assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
    assert any(isinstance(m, ToolMessage) and m.content == "beta skill body" for m in preserved)
    for m in preserved:
        if isinstance(m, ToolMessage) and m.content in {"alpha skill body", "beta skill body"}:
            assert id(m) not in summarized_ids
    # Preserved output order: rescued bundles first, then the tail kept by parent cutoff.
    contents = [getattr(m, "content", None) for m in preserved]
    assert contents[-2:] == ["u3", "final"]
    # The final emitted state should start with RemoveMessage + summary, then preserved messages.
    emitted = result["messages"]
    assert isinstance(emitted[0], RemoveMessage)
    assert emitted[1].content.startswith("Here is a summary")
    assert list(emitted[-2:]) == list(preserved[-2:])


def test_skill_rescue_respects_count_budget() -> None:
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=1,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )

    middleware.before_model({"messages": _skill_conversation()}, _runtime())

    preserved = captured[0].preserved_messages
    summarized = captured[0].messages_to_summarize
    # Newest skill (beta) rescued; older skill (alpha) falls into summary.
    assert any(isinstance(m, ToolMessage) and m.content == "beta skill body" for m in preserved)
    assert not any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
    assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in summarized)


def test_skill_rescue_uses_injected_skills_container_path() -> None:
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )
    middleware._skills_container_path = "/custom/skills"
    messages = [
        HumanMessage(content="u1"),
        AIMessage(content="", tool_calls=[{"name": "read_file", "id": "t1", "args": {"path": "/custom/skills/demo/SKILL.md"}}]),
        ToolMessage(content="demo skill body", tool_call_id="t1"),
        HumanMessage(content="u2"),
        AIMessage(content="final"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    preserved = captured[0].preserved_messages
    assert any(isinstance(m, ToolMessage) and m.content == "demo skill body" for m in preserved)


def test_skill_rescue_uses_configured_skill_read_tool_names() -> None:
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        skill_file_read_tool_names=["custom_read"],
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )
    middleware._skills_container_path = "/custom/skills"
    messages = [
        HumanMessage(content="u1"),
        AIMessage(content="", tool_calls=[{"name": "custom_read", "id": "t1", "args": {"path": "/custom/skills/demo/SKILL.md"}}]),
        ToolMessage(content="demo skill body", tool_call_id="t1"),
        HumanMessage(content="u2"),
        AIMessage(content="final"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    preserved = captured[0].preserved_messages
    assert any(isinstance(m, ToolMessage) and m.content == "demo skill body" for m in preserved)


def test_skill_rescue_respects_per_skill_token_cap() -> None:
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        # token_counter=len counts one token per message; per-skill cap of 0 rejects every bundle.
        preserve_recent_skill_tokens_per_skill=0,
    )

    middleware.before_model({"messages": _skill_conversation()}, _runtime())

    preserved = captured[0].preserved_messages
    assert not any(isinstance(m, ToolMessage) and m.content in {"alpha skill body", "beta skill body"} for m in preserved)


def test_skill_rescue_disabled_when_count_zero() -> None:
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=0,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )

    middleware.before_model({"messages": _skill_conversation()}, _runtime())

    preserved = captured[0].preserved_messages
    assert not any(isinstance(m, ToolMessage) for m in preserved)


def test_skill_rescue_ignores_non_skill_tool_reads() -> None:
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )
    messages = [
        HumanMessage(content="u1"),
        AIMessage(
            content="",
            tool_calls=[{"name": "read_file", "id": "t1", "args": {"path": "/mnt/user-data/workspace/notes.md"}}],
        ),
        ToolMessage(content="user notes", tool_call_id="t1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    preserved = captured[0].preserved_messages
    assert not any(isinstance(m, ToolMessage) and m.content == "user notes" for m in preserved)


def test_skill_rescue_does_not_preserve_non_skill_outputs_from_mixed_tool_calls() -> None:
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )
    messages = [
        HumanMessage(content="u1"),
        AIMessage(
            content="",
            tool_calls=[
                _skill_read_call("skill-1", "alpha"),
                {"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
            ],
        ),
        ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
        ToolMessage(content="user notes", tool_call_id="file-1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    preserved = captured[0].preserved_messages
    summarized = captured[0].messages_to_summarize
    preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
    summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
    assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"]
    assert [tc["id"] for tc in summarized_ai.tool_calls] == ["file-1"]
    assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
    assert not any(isinstance(m, ToolMessage) and m.content == "user notes" for m in preserved)
    assert any(isinstance(m, ToolMessage) and m.content == "user notes" for m in summarized)


def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )
    messages = [
        HumanMessage(content="u1"),
        AIMessage(
            content="reading skill and notes",
            tool_calls=[
                _skill_read_call("skill-1", "alpha"),
                {"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
            ],
        ),
        ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
        ToolMessage(content="user notes", tool_call_id="file-1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    preserved = captured[0].preserved_messages
    summarized = captured[0].messages_to_summarize
    preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
    summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
    assert preserved_ai.content == ""
    assert summarized_ai.content == "reading skill and notes"


def test_skill_rescue_only_preserves_skill_calls_with_matched_tool_results() -> None:
    captured: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[captured.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )
    messages = [
        HumanMessage(content="u1"),
        AIMessage(
            content="",
            tool_calls=[
                _skill_read_call("skill-1", "alpha"),
                _skill_read_call("skill-2", "beta"),
            ],
        ),
        ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    preserved = captured[0].preserved_messages
    summarized = captured[0].messages_to_summarize
    preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
    summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
    assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"]
    assert [tc["id"] for tc in summarized_ai.tool_calls] == ["skill-2"]
    assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
    assert not any(isinstance(m, ToolMessage) and getattr(m, "tool_call_id", None) == "skill-2" for m in preserved)

def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None:
    queue = MagicMock()
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))


@@ -12,7 +12,7 @@
# ============================================================================
# Bump this number when the config schema changes.
# Run `make config-upgrade` to merge new fields into your local config.yaml.
config_version: 8

# ============================================================================
# Logging
@@ -726,6 +726,19 @@ summarization:
# The prompt should guide the model to extract important context
summary_prompt: null
# Recently-loaded skill files are excluded from summarization so the agent
# does not lose skill instructions after a compression pass. Claude Code uses
# a similar strategy (keep the most recent ~5 skills, ~25k total tokens, with
# a ~5k cap per skill). Set preserve_recent_skill_count to 0 to disable.
preserve_recent_skill_count: 5
preserve_recent_skill_tokens: 25000
preserve_recent_skill_tokens_per_skill: 5000
skill_file_read_tool_names:
- read_file
- read
- view
- cat
# ============================================================================
# Memory Configuration
# ============================================================================