From f9ff3a698ddc64dc8dbc7404e0a2f7ef886ef0f8 Mon Sep 17 00:00:00 2001 From: Nan Gao <88081804+ggnnggez@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:19:46 +0200 Subject: [PATCH] fix(middleware): avoid rescuing non-skill tool outputs during summarization (#2458) * fix(middelware): narrow skill rescue to skill-related tool outputs * fix(summarization): address skill rescue review feedback * fix: wire summarization skill rescue config * fix: remove dead skill tool helper * fix(lint): fix format --------- Co-authored-by: Willem Jiang --- backend/docs/summarization.md | 28 ++ .../deerflow/agents/lead_agent/agent.py | 19 +- .../middlewares/summarization_middleware.py | 206 ++++++++++- .../deerflow/config/summarization_config.py | 19 + .../tests/test_lead_agent_model_resolution.py | 24 ++ .../tests/test_summarization_middleware.py | 327 +++++++++++++++++- config.example.yaml | 15 +- 7 files changed, 629 insertions(+), 9 deletions(-) diff --git a/backend/docs/summarization.md b/backend/docs/summarization.md index ca1e8dda1..773d27e3d 100644 --- a/backend/docs/summarization.md +++ b/backend/docs/summarization.md @@ -41,6 +41,13 @@ summarization: # Custom summary prompt (optional) summary_prompt: null + + # Tool names treated as skill file reads for skill rescue + skill_file_read_tool_names: + - read_file + - read + - view + - cat ``` ### Configuration Options @@ -125,6 +132,26 @@ keep: - **Default**: `null` (uses LangChain's default prompt) - **Description**: Custom prompt template for generating summaries. The prompt should guide the model to extract the most important context. +#### `preserve_recent_skill_count` +- **Type**: Integer (≥ 0) +- **Default**: `5` +- **Description**: Number of most-recently-loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`, e.g. `/mnt/skills/...`) that are rescued from summarization. 
Prevents the agent from losing skill instructions after compression. Set to `0` to disable skill rescue entirely. + +#### `preserve_recent_skill_tokens` +- **Type**: Integer (≥ 0) +- **Default**: `25000` +- **Description**: Total token budget reserved for rescued skill reads. Once this budget is exhausted, older skill bundles are allowed to be summarized. + +#### `preserve_recent_skill_tokens_per_skill` +- **Type**: Integer (≥ 0) +- **Default**: `5000` +- **Description**: Per-skill token cap. Any individual skill read whose tool result exceeds this size is not rescued (it falls through to the summarizer like ordinary content). + +#### `skill_file_read_tool_names` +- **Type**: List of strings +- **Default**: `["read_file", "read", "view", "cat"]` +- **Description**: Tool names treated as skill file reads during summarization rescue. A tool call is only eligible for skill rescue when its name appears in this list and its target path is under `skills.container_path`. + **Default Prompt Behavior:** The default LangChain prompt instructs the model to: - Extract highest quality/most relevant context @@ -147,6 +174,7 @@ The default LangChain prompt instructs the model to: - A single summary message is added - Recent messages are preserved 6. **AI/Tool Pair Protection**: The system ensures AI messages and their corresponding tool messages stay together +7. **Skill Rescue**: Before the summary is generated, the most recently loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`) are lifted out of the summarization set and prepended to the preserved tail. Selection walks newest-first under three budgets: `preserve_recent_skill_count`, `preserve_recent_skill_tokens`, and `preserve_recent_skill_tokens_per_skill`. Only the skill-related tool calls of the triggering AIMessage that have a matching ToolMessage are rescued, together with those ToolMessages; when the same AIMessage also carries non-skill tool calls it is split, so the non-skill calls and their results stay in the summarization set and tool_call ↔ tool_result pairing stays intact on both sides.
### Token Counting diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py index de3ff6766..f17aab6ce 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py @@ -84,7 +84,24 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None if get_memory_config().enabled: hooks.append(memory_flush_hook) - return DeerFlowSummarizationMiddleware(**kwargs, before_summarization=hooks) + # The logic below relies on two assumptions holding true: this factory is + # the sole entry point for DeerFlowSummarizationMiddleware, and the runtime + # config is not expected to change after startup. + try: + skills_container_path = get_app_config().skills.container_path or "/mnt/skills" + except Exception: + logger.exception("Failed to resolve skills container path; falling back to default") + skills_container_path = "/mnt/skills" + + return DeerFlowSummarizationMiddleware( + **kwargs, + skills_container_path=skills_container_path, + skill_file_read_tool_names=config.skill_file_read_tool_names, + before_summarization=hooks, + preserve_recent_skill_count=config.preserve_recent_skill_count, + preserve_recent_skill_tokens=config.preserve_recent_skill_tokens, + preserve_recent_skill_tokens_per_skill=config.preserve_recent_skill_tokens_per_skill, + ) def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None: diff --git a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py index fba44c215..651b64a72 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py @@ -3,12 +3,13 @@ from __future__ import annotations import logging +from collections.abc import 
Collection from dataclasses import dataclass -from typing import Protocol, runtime_checkable +from typing import Any, Protocol, runtime_checkable from langchain.agents import AgentState from langchain.agents.middleware import SummarizationMiddleware -from langchain_core.messages import AnyMessage, RemoveMessage +from langchain_core.messages import AIMessage, AnyMessage, RemoveMessage, ToolMessage from langgraph.config import get_config from langgraph.graph.message import REMOVE_ALL_MESSAGES from langgraph.runtime import Runtime @@ -58,17 +59,63 @@ def _resolve_agent_name(runtime: Runtime) -> str | None: return agent_name +def _tool_call_path(tool_call: dict[str, Any]) -> str | None: + """Best-effort extraction of a file path argument from a read_file-like tool call.""" + args = tool_call.get("args") or {} + if not isinstance(args, dict): + return None + for key in ("path", "file_path", "filepath"): + value = args.get(key) + if isinstance(value, str) and value: + return value + return None + + +def _clone_ai_message( + message: AIMessage, + tool_calls: list[dict[str, Any]], + *, + content: Any | None = None, +) -> AIMessage: + """Clone an AIMessage while replacing its tool_calls list and optional content.""" + update: dict[str, Any] = {"tool_calls": tool_calls} + if content is not None: + update["content"] = content + return message.model_copy(update=update) + + +@dataclass +class _SkillBundle: + """Skill-related tool calls and tool results associated with one AIMessage.""" + + ai_index: int + skill_tool_indices: tuple[int, ...] 
+ skill_tool_call_ids: frozenset[str] + skill_tool_tokens: int + skill_key: str + + class DeerFlowSummarizationMiddleware(SummarizationMiddleware): - """Summarization middleware with pre-compression hook dispatch.""" + """Summarization middleware with pre-compression hook dispatch and skill rescue.""" def __init__( self, *args, + skills_container_path: str | None = None, + skill_file_read_tool_names: Collection[str] | None = None, before_summarization: list[BeforeSummarizationHook] | None = None, + preserve_recent_skill_count: int = 5, + preserve_recent_skill_tokens: int = 25_000, + preserve_recent_skill_tokens_per_skill: int = 5_000, **kwargs, ) -> None: super().__init__(*args, **kwargs) + self._skills_container_path = skills_container_path or "/mnt/skills" + self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"}) self._before_summarization_hooks = before_summarization or [] + self._preserve_recent_skill_count = max(0, preserve_recent_skill_count) + self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens) + self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill) def before_model(self, state: AgentState, runtime: Runtime) -> dict | None: return self._maybe_summarize(state, runtime) @@ -88,7 +135,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware): if cutoff_index <= 0: return None - messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index) + messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index) self._fire_hooks(messages_to_summarize, preserved_messages, runtime) summary = self._create_summary(messages_to_summarize) new_messages = self._build_new_messages(summary) @@ -113,7 +160,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware): if cutoff_index <= 0: return None - messages_to_summarize, preserved_messages = self._partition_messages(messages, 
cutoff_index) + messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index) self._fire_hooks(messages_to_summarize, preserved_messages, runtime) summary = await self._acreate_summary(messages_to_summarize) new_messages = self._build_new_messages(summary) @@ -126,6 +173,155 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware): ] } + def _partition_with_skill_rescue( + self, + messages: list[AnyMessage], + cutoff_index: int, + ) -> tuple[list[AnyMessage], list[AnyMessage]]: + """Partition like the parent, then rescue recently-loaded skill bundles.""" + to_summarize, preserved = self._partition_messages(messages, cutoff_index) + + if self._preserve_recent_skill_count == 0 or self._preserve_recent_skill_tokens == 0 or not to_summarize: + return to_summarize, preserved + + try: + bundles = self._find_skill_bundles(to_summarize, self._skills_container_path) + except Exception: + logger.exception("Skill-preserving summarization rescue failed; falling back to default partition") + return to_summarize, preserved + + if not bundles: + return to_summarize, preserved + + rescue_bundles = self._select_bundles_to_rescue(bundles) + if not rescue_bundles: + return to_summarize, preserved + + bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles} + rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices} + rescued: list[AnyMessage] = [] + remaining: list[AnyMessage] = [] + for i, msg in enumerate(to_summarize): + bundle = bundles_by_ai_index.get(i) + if bundle is not None and isinstance(msg, AIMessage): + rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids] + remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids] + + if rescued_tool_calls: + rescued.append(_clone_ai_message(msg, rescued_tool_calls, content="")) + if remaining_tool_calls or msg.content: + 
remaining.append(_clone_ai_message(msg, remaining_tool_calls)) + continue + + if i in rescue_tool_indices: + rescued.append(msg) + continue + + remaining.append(msg) + + return remaining, rescued + preserved + + def _find_skill_bundles( + self, + messages: list[AnyMessage], + skills_root: str, + ) -> list[_SkillBundle]: + """Locate AIMessage + paired ToolMessage groups that load skill files.""" + bundles: list[_SkillBundle] = [] + n = len(messages) + i = 0 + while i < n: + msg = messages[i] + if not (isinstance(msg, AIMessage) and msg.tool_calls): + i += 1 + continue + + tool_calls = list(msg.tool_calls) + skill_paths_by_id: dict[str, str] = {} + for tc in tool_calls: + if self._is_skill_tool_call(tc, skills_root): + tc_id = tc.get("id") + path = _tool_call_path(tc) + if tc_id and path: + skill_paths_by_id[tc_id] = path + + if not skill_paths_by_id: + i += 1 + continue + + skill_tool_tokens = 0 + skill_key_parts: list[str] = [] + skill_tool_indices: list[int] = [] + matched_skill_call_ids: set[str] = set() + + j = i + 1 + while j < n and isinstance(messages[j], ToolMessage): + j += 1 + + for k in range(i + 1, j): + tool_msg = messages[k] + if isinstance(tool_msg, ToolMessage) and tool_msg.tool_call_id in skill_paths_by_id: + skill_tool_tokens += self.token_counter([tool_msg]) + skill_key_parts.append(skill_paths_by_id[tool_msg.tool_call_id]) + skill_tool_indices.append(k) + matched_skill_call_ids.add(tool_msg.tool_call_id) + + if not skill_tool_indices: + i = j + continue + + bundles.append( + _SkillBundle( + ai_index=i, + skill_tool_indices=tuple(skill_tool_indices), + skill_tool_call_ids=frozenset(matched_skill_call_ids), + skill_tool_tokens=skill_tool_tokens, + skill_key="|".join(sorted(skill_key_parts)), + ) + ) + i = j + + return bundles + + def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]: + """Pick bundles to keep, walking newest-first under count/token budgets.""" + selected: list[_SkillBundle] = [] + if not bundles: + 
return selected + + seen_skill_keys: set[str] = set() + total_tokens = 0 + kept = 0 + + for bundle in reversed(bundles): + if kept >= self._preserve_recent_skill_count: + break + if bundle.skill_key in seen_skill_keys: + continue + if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill: + continue + if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens: + continue + + selected.append(bundle) + total_tokens += bundle.skill_tool_tokens + kept += 1 + seen_skill_keys.add(bundle.skill_key) + + selected.reverse() + return selected + + def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool: + """Return True when ``tool_call`` reads a file under the configured skills root.""" + name = tool_call.get("name") or "" + if name not in self._skill_file_read_tool_names: + return False + path = _tool_call_path(tool_call) + if not path: + return False + normalized_root = skills_root.rstrip("/") + return path == normalized_root or path.startswith(normalized_root + "/") + def _fire_hooks( self, messages_to_summarize: list[AnyMessage], diff --git a/backend/packages/harness/deerflow/config/summarization_config.py b/backend/packages/harness/deerflow/config/summarization_config.py index f132e58cd..fab268ec5 100644 --- a/backend/packages/harness/deerflow/config/summarization_config.py +++ b/backend/packages/harness/deerflow/config/summarization_config.py @@ -51,6 +51,25 @@ class SummarizationConfig(BaseModel): default=None, description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.", ) + preserve_recent_skill_count: int = Field( + default=5, + ge=0, + description="Number of most-recently-loaded skill files to exclude from summarization. 
Set to 0 to disable skill preservation.", + ) + preserve_recent_skill_tokens: int = Field( + default=25000, + ge=0, + description="Total token budget reserved for recently-loaded skill files that must be preserved across summarization.", + ) + preserve_recent_skill_tokens_per_skill: int = Field( + default=5000, + ge=0, + description="Per-skill token cap when preserving skill files across summarization. Skill reads above this size are not rescued.", + ) + skill_file_read_tool_names: list[str] = Field( + default_factory=lambda: ["read_file", "read", "view", "cat"], + description="Tool names treated as skill file reads when preserving recently-loaded skills across summarization.", + ) # Global configuration instance diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py index 12a4d0143..dc95dc4da 100644 --- a/backend/tests/test_lead_agent_model_resolution.py +++ b/backend/tests/test_lead_agent_model_resolution.py @@ -207,3 +207,27 @@ def test_create_summarization_middleware_registers_memory_flush_hook_when_memory lead_agent_module._create_summarization_middleware() assert captured["before_summarization"] == [lead_agent_module.memory_flush_hook] + + +def test_create_summarization_middleware_passes_skill_read_tool_names(monkeypatch): + app_config = _make_app_config([_make_model("default-model", supports_thinking=False)]) + monkeypatch.setattr( + lead_agent_module, + "get_summarization_config", + lambda: SummarizationConfig(enabled=True, skill_file_read_tool_names=["read_file", "cat"]), + ) + monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False)) + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) + monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object()) + + captured: dict[str, object] = {} + + def _fake_middleware(**kwargs): + captured.update(kwargs) + return kwargs + + monkeypatch.setattr(lead_agent_module, 
"DeerFlowSummarizationMiddleware", _fake_middleware) + + lead_agent_module._create_summarization_middleware() + + assert captured["skill_file_read_tool_names"] == ["read_file", "cat"] diff --git a/backend/tests/test_summarization_middleware.py b/backend/tests/test_summarization_middleware.py index d327c94c4..79ca8b01c 100644 --- a/backend/tests/test_summarization_middleware.py +++ b/backend/tests/test_summarization_middleware.py @@ -4,7 +4,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage +from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage from deerflow.agents.memory.summarization_hook import memory_flush_hook from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent @@ -29,7 +29,16 @@ def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) return SimpleNamespace(context=context) -def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("messages", 2)) -> DeerFlowSummarizationMiddleware: +def _middleware( + *, + before_summarization=None, + trigger=("messages", 4), + keep=("messages", 2), + skill_file_read_tool_names=None, + preserve_recent_skill_count: int = 0, + preserve_recent_skill_tokens: int = 0, + preserve_recent_skill_tokens_per_skill: int = 0, +) -> DeerFlowSummarizationMiddleware: model = MagicMock() model.invoke.return_value = SimpleNamespace(text="compressed summary") return DeerFlowSummarizationMiddleware( @@ -38,9 +47,34 @@ def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("me keep=keep, token_counter=len, before_summarization=before_summarization, + skill_file_read_tool_names=skill_file_read_tool_names, + preserve_recent_skill_count=preserve_recent_skill_count, + preserve_recent_skill_tokens=preserve_recent_skill_tokens, + 
preserve_recent_skill_tokens_per_skill=preserve_recent_skill_tokens_per_skill, ) +def _skill_read_call(tool_id: str, skill: str) -> dict: + return { + "name": "read_file", + "id": tool_id, + "args": {"path": f"/mnt/skills/public/{skill}/SKILL.md"}, + } + + +def _skill_conversation() -> list: + return [ + HumanMessage(content="u1"), + AIMessage(content="", tool_calls=[_skill_read_call("t1", "alpha")]), + ToolMessage(content="alpha skill body", tool_call_id="t1"), + HumanMessage(content="u2"), + AIMessage(content="", tool_calls=[_skill_read_call("t2", "beta")]), + ToolMessage(content="beta skill body", tool_call_id="t2"), + HumanMessage(content="u3"), + AIMessage(content="final"), + ] + + def test_before_summarization_hook_receives_messages_before_compression() -> None: captured: list[SummarizationEvent] = [] middleware = _middleware(before_summarization=[captured.append]) @@ -167,6 +201,295 @@ def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: p assert add_kwargs["reinforcement_detected"] is False +def test_skill_rescue_keeps_recent_skill_reads_out_of_summary() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + preserve_recent_skill_count=5, + preserve_recent_skill_tokens=10_000, + preserve_recent_skill_tokens_per_skill=10_000, + ) + + result = middleware.before_model({"messages": _skill_conversation()}, _runtime()) + + assert len(captured) == 1 + summarized_ids = {id(m) for m in captured[0].messages_to_summarize} + preserved = captured[0].preserved_messages + + # Both skill-read bundles should be rescued into preserved_messages, + # tool_call ↔ tool_result pairs stay intact. 
+ assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved) + assert any(isinstance(m, ToolMessage) and m.content == "beta skill body" for m in preserved) + for m in preserved: + if isinstance(m, ToolMessage) and m.content in {"alpha skill body", "beta skill body"}: + assert id(m) not in summarized_ids + + # Preserved output order: rescued bundles first, then the tail kept by parent cutoff. + contents = [getattr(m, "content", None) for m in preserved] + assert contents[-2:] == ["u3", "final"] + + # The final emitted state should start with RemoveMessage + summary, then preserved messages. + emitted = result["messages"] + assert isinstance(emitted[0], RemoveMessage) + assert emitted[1].content.startswith("Here is a summary") + assert list(emitted[-2:]) == list(preserved[-2:]) + + +def test_skill_rescue_respects_count_budget() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + preserve_recent_skill_count=1, + preserve_recent_skill_tokens=10_000, + preserve_recent_skill_tokens_per_skill=10_000, + ) + + middleware.before_model({"messages": _skill_conversation()}, _runtime()) + + preserved = captured[0].preserved_messages + summarized = captured[0].messages_to_summarize + # Newest skill (beta) rescued; older skill (alpha) falls into summary. 
+ assert any(isinstance(m, ToolMessage) and m.content == "beta skill body" for m in preserved) + assert not any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved) + assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in summarized) + + +def test_skill_rescue_uses_injected_skills_container_path() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + preserve_recent_skill_count=5, + preserve_recent_skill_tokens=10_000, + preserve_recent_skill_tokens_per_skill=10_000, + ) + middleware._skills_container_path = "/custom/skills" + messages = [ + HumanMessage(content="u1"), + AIMessage(content="", tool_calls=[{"name": "read_file", "id": "t1", "args": {"path": "/custom/skills/demo/SKILL.md"}}]), + ToolMessage(content="demo skill body", tool_call_id="t1"), + HumanMessage(content="u2"), + AIMessage(content="final"), + ] + + middleware.before_model({"messages": messages}, _runtime()) + + preserved = captured[0].preserved_messages + assert any(isinstance(m, ToolMessage) and m.content == "demo skill body" for m in preserved) + + +def test_skill_rescue_uses_configured_skill_read_tool_names() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + skill_file_read_tool_names=["custom_read"], + preserve_recent_skill_count=5, + preserve_recent_skill_tokens=10_000, + preserve_recent_skill_tokens_per_skill=10_000, + ) + middleware._skills_container_path = "/custom/skills" + messages = [ + HumanMessage(content="u1"), + AIMessage(content="", tool_calls=[{"name": "custom_read", "id": "t1", "args": {"path": "/custom/skills/demo/SKILL.md"}}]), + ToolMessage(content="demo skill body", tool_call_id="t1"), + HumanMessage(content="u2"), + AIMessage(content="final"), + ] + + 
middleware.before_model({"messages": messages}, _runtime()) + + preserved = captured[0].preserved_messages + assert any(isinstance(m, ToolMessage) and m.content == "demo skill body" for m in preserved) + + +def test_skill_rescue_respects_per_skill_token_cap() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + preserve_recent_skill_count=5, + preserve_recent_skill_tokens=10_000, + # token_counter=len counts one token per message; per-skill cap of 0 rejects every bundle. + preserve_recent_skill_tokens_per_skill=0, + ) + + middleware.before_model({"messages": _skill_conversation()}, _runtime()) + + preserved = captured[0].preserved_messages + assert not any(isinstance(m, ToolMessage) and m.content in {"alpha skill body", "beta skill body"} for m in preserved) + + +def test_skill_rescue_disabled_when_count_zero() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + preserve_recent_skill_count=0, + preserve_recent_skill_tokens=10_000, + preserve_recent_skill_tokens_per_skill=10_000, + ) + + middleware.before_model({"messages": _skill_conversation()}, _runtime()) + + preserved = captured[0].preserved_messages + assert not any(isinstance(m, ToolMessage) for m in preserved) + + +def test_skill_rescue_ignores_non_skill_tool_reads() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + preserve_recent_skill_count=5, + preserve_recent_skill_tokens=10_000, + preserve_recent_skill_tokens_per_skill=10_000, + ) + + messages = [ + HumanMessage(content="u1"), + AIMessage( + content="", + tool_calls=[{"name": "read_file", "id": "t1", "args": {"path": "/mnt/user-data/workspace/notes.md"}}], + ), + ToolMessage(content="user 
notes", tool_call_id="t1"), + HumanMessage(content="u2"), + AIMessage(content="done"), + ] + + middleware.before_model({"messages": messages}, _runtime()) + + preserved = captured[0].preserved_messages + assert not any(isinstance(m, ToolMessage) and m.content == "user notes" for m in preserved) + + +def test_skill_rescue_does_not_preserve_non_skill_outputs_from_mixed_tool_calls() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + preserve_recent_skill_count=5, + preserve_recent_skill_tokens=10_000, + preserve_recent_skill_tokens_per_skill=10_000, + ) + + messages = [ + HumanMessage(content="u1"), + AIMessage( + content="", + tool_calls=[ + _skill_read_call("skill-1", "alpha"), + {"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}}, + ], + ), + ToolMessage(content="alpha skill body", tool_call_id="skill-1"), + ToolMessage(content="user notes", tool_call_id="file-1"), + HumanMessage(content="u2"), + AIMessage(content="done"), + ] + + middleware.before_model({"messages": messages}, _runtime()) + + preserved = captured[0].preserved_messages + summarized = captured[0].messages_to_summarize + + preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls) + summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls) + + assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"] + assert [tc["id"] for tc in summarized_ai.tool_calls] == ["file-1"] + assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved) + assert not any(isinstance(m, ToolMessage) and m.content == "user notes" for m in preserved) + assert any(isinstance(m, ToolMessage) and m.content == "user notes" for m in summarized) + + +def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None: + captured: list[SummarizationEvent] = [] + middleware = 
_middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + preserve_recent_skill_count=5, + preserve_recent_skill_tokens=10_000, + preserve_recent_skill_tokens_per_skill=10_000, + ) + + messages = [ + HumanMessage(content="u1"), + AIMessage( + content="reading skill and notes", + tool_calls=[ + _skill_read_call("skill-1", "alpha"), + {"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}}, + ], + ), + ToolMessage(content="alpha skill body", tool_call_id="skill-1"), + ToolMessage(content="user notes", tool_call_id="file-1"), + HumanMessage(content="u2"), + AIMessage(content="done"), + ] + + middleware.before_model({"messages": messages}, _runtime()) + + preserved = captured[0].preserved_messages + summarized = captured[0].messages_to_summarize + + preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls) + summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls) + + assert preserved_ai.content == "" + assert summarized_ai.content == "reading skill and notes" + + +def test_skill_rescue_only_preserves_skill_calls_with_matched_tool_results() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + preserve_recent_skill_count=5, + preserve_recent_skill_tokens=10_000, + preserve_recent_skill_tokens_per_skill=10_000, + ) + + messages = [ + HumanMessage(content="u1"), + AIMessage( + content="", + tool_calls=[ + _skill_read_call("skill-1", "alpha"), + _skill_read_call("skill-2", "beta"), + ], + ), + ToolMessage(content="alpha skill body", tool_call_id="skill-1"), + HumanMessage(content="u2"), + AIMessage(content="done"), + ] + + middleware.before_model({"messages": messages}, _runtime()) + + preserved = captured[0].preserved_messages + summarized = captured[0].messages_to_summarize + + preserved_ai = next(m for m 
in preserved if isinstance(m, AIMessage) and m.tool_calls) + summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls) + + assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"] + assert [tc["id"] for tc in summarized_ai.tool_calls] == ["skill-2"] + assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved) + assert not any(isinstance(m, ToolMessage) and getattr(m, "tool_call_id", None) == "skill-2" for m in preserved) + + def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None: queue = MagicMock() monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True)) diff --git a/config.example.yaml b/config.example.yaml index 1e649bba9..1c5bf4129 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -12,7 +12,7 @@ # ============================================================================ # Bump this number when the config schema changes. # Run `make config-upgrade` to merge new fields into your local config.yaml. -config_version: 7 +config_version: 8 # ============================================================================ # Logging @@ -726,6 +726,19 @@ summarization: # The prompt should guide the model to extract important context summary_prompt: null + # Recently-loaded skill files are excluded from summarization so the agent + # does not lose skill instructions after a compression pass. Claude Code uses + # a similar strategy (keep the most recent ~5 skills, ~25k total tokens, with + # a ~5k cap per skill). Set preserve_recent_skill_count to 0 to disable. 
+ preserve_recent_skill_count: 5 + preserve_recent_skill_tokens: 25000 + preserve_recent_skill_tokens_per_skill: 5000 + skill_file_read_tool_names: + - read_file + - read + - view + - cat + # ============================================================================ # Memory Configuration # ============================================================================