"""Summarization middleware extensions for DeerFlow.""" from __future__ import annotations import logging from collections.abc import Collection from dataclasses import dataclass from typing import Any, Protocol, runtime_checkable from langchain.agents import AgentState from langchain.agents.middleware import SummarizationMiddleware from langchain_core.messages import AIMessage, AnyMessage, RemoveMessage, ToolMessage from langgraph.config import get_config from langgraph.graph.message import REMOVE_ALL_MESSAGES from langgraph.runtime import Runtime logger = logging.getLogger(__name__) @dataclass(frozen=True) class SummarizationEvent: """Context emitted before conversation history is summarized away.""" messages_to_summarize: tuple[AnyMessage, ...] preserved_messages: tuple[AnyMessage, ...] thread_id: str | None agent_name: str | None runtime: Runtime @runtime_checkable class BeforeSummarizationHook(Protocol): """Hook invoked before summarization removes messages from state.""" def __call__(self, event: SummarizationEvent) -> None: ... 
def _lookup_runtime_value(runtime: Runtime, key: str) -> Any:
    """Look up ``key`` in the runtime context, falling back to the LangGraph config.

    The runtime context is read with ``.get`` when it exposes one (dict-like
    contexts) and with attribute access otherwise (e.g. dataclass/pydantic
    contexts), so an unexpected context type yields ``None`` instead of raising.
    If the context provides no non-``None`` value, ``configurable[key]`` from the
    active LangGraph config is used.

    Returns:
        The resolved value, or ``None`` when neither source provides one or when
        ``get_config`` raises ``RuntimeError`` (it does so outside a runnable
        context).
    """
    context = getattr(runtime, "context", None)
    value = None
    if context:
        getter = getattr(context, "get", None)
        value = getter(key) if callable(getter) else getattr(context, key, None)
    if value is None:
        try:
            config_data = get_config()
        except RuntimeError:
            return None
        # ``configurable`` may be missing or explicitly None.
        configurable = config_data.get("configurable") or {}
        value = configurable.get(key)
    return value


def _resolve_thread_id(runtime: Runtime) -> str | None:
    """Resolve the current thread ID from runtime context or LangGraph config."""
    return _lookup_runtime_value(runtime, "thread_id")


def _resolve_agent_name(runtime: Runtime) -> str | None:
    """Resolve the current agent name from runtime context or LangGraph config."""
    return _lookup_runtime_value(runtime, "agent_name")


def _tool_call_path(tool_call: dict[str, Any]) -> str | None:
    """Best-effort extraction of a file path argument from a read_file-like tool call.

    Checks the ``path``, ``file_path`` and ``filepath`` arguments in that order
    and returns the first non-empty string. Returns ``None`` when ``args`` is
    not a dict or no such argument exists.
    """
    args = tool_call.get("args") or {}
    if not isinstance(args, dict):
        return None
    for key in ("path", "file_path", "filepath"):
        value = args.get(key)
        if isinstance(value, str) and value:
            return value
    return None


def _clone_ai_message(
    message: AIMessage,
    tool_calls: list[dict[str, Any]],
    *,
    content: Any | None = None,
) -> AIMessage:
    """Clone an AIMessage while replacing its tool_calls list and optional content.

    Args:
        message: The message to copy; it is not modified.
        tool_calls: Replacement value for ``tool_calls``.
        content: Replacement content; ``None`` keeps the original content
            (an empty string is a valid replacement).

    Returns:
        A ``model_copy`` of ``message``. All other fields, including ``id``,
        are copied unchanged.
    """
    # NOTE(review): only ``tool_calls``/``content`` are replaced; any
    # provider-specific copies of tool calls (e.g. in additional_kwargs) are
    # carried over as-is — confirm the target providers prefer ``tool_calls``.
    update: dict[str, Any] = {"tool_calls": tool_calls}
    if content is not None:
        update["content"] = content
    return message.model_copy(update=update)


@dataclass
class _SkillBundle:
    """Skill-related tool calls and tool results associated with one AIMessage."""

    # Index of the AIMessage within the scanned message list.
    ai_index: int
    # Indices of the ToolMessages answering this AIMessage's skill-file reads.
    skill_tool_indices: tuple[int, ...]
    # Tool-call IDs of the skill reads that actually have a ToolMessage result.
    skill_tool_call_ids: frozenset[str]
    # Token count of the skill ToolMessages, as measured by ``token_counter``.
    skill_tool_tokens: int
    # Sorted, de-duplicated "|"-joined skill paths; used to dedupe bundles.
    skill_key: str


class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
    """Summarization middleware with pre-compression hook dispatch and skill rescue.

    Behaves like ``SummarizationMiddleware`` but:

    * runs ``before_summarization`` hooks with a ``SummarizationEvent`` before
      history is replaced; and
    * rescues the most recent skill-file reads (an AIMessage's skill tool calls
      plus their ToolMessage results) from the summarized region, so they stay
      verbatim in the state after summarization.
    """

    def __init__(
        self,
        *args,
        skills_container_path: str | None = None,
        skill_file_read_tool_names: Collection[str] | None = None,
        before_summarization: list[BeforeSummarizationHook] | None = None,
        preserve_recent_skill_count: int = 5,
        preserve_recent_skill_tokens: int = 25_000,
        preserve_recent_skill_tokens_per_skill: int = 5_000,
        **kwargs,
    ) -> None:
        """Create the middleware.

        Args:
            *args: Forwarded to ``SummarizationMiddleware``.
            skills_container_path: Root directory of skill files; falls back to
                ``"/mnt/skills"`` when empty or ``None``.
            skill_file_read_tool_names: Tool names treated as file reads; falls back
                to ``{"read_file", "read", "view", "cat"}`` when empty or ``None``.
            before_summarization: Hooks invoked before summarization, in order.
            preserve_recent_skill_count: Maximum number of skill bundles to rescue
                (negative values are clamped to 0, which disables rescue).
            preserve_recent_skill_tokens: Total token budget for rescued bundles
                (clamped to >= 0; 0 disables rescue).
            preserve_recent_skill_tokens_per_skill: Bundles larger than this are
                never rescued (clamped to >= 0).
            **kwargs: Forwarded to ``SummarizationMiddleware``.
        """
        super().__init__(*args, **kwargs)
        self._skills_container_path = skills_container_path or "/mnt/skills"
        self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"})
        self._before_summarization_hooks = before_summarization or []
        self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
        self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
        self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)

    def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Sync middleware entry point; see ``_maybe_summarize``."""
        return self._maybe_summarize(state, runtime)

    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async middleware entry point; see ``_amaybe_summarize``."""
        return await self._amaybe_summarize(state, runtime)

    def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Summarize older history synchronously when the parent's trigger fires.

        Returns:
            ``None`` when no summarization is needed, otherwise a state update
            that clears all messages and re-adds the summary and preserved messages.
        """
        plan = self._plan_summarization(state, runtime)
        if plan is None:
            return None
        messages_to_summarize, preserved_messages = plan
        summary = self._create_summary(messages_to_summarize)
        return self._summary_state_update(self._build_new_messages(summary), preserved_messages)

    async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async counterpart of ``_maybe_summarize`` using ``_acreate_summary``."""
        plan = self._plan_summarization(state, runtime)
        if plan is None:
            return None
        messages_to_summarize, preserved_messages = plan
        summary = await self._acreate_summary(messages_to_summarize)
        return self._summary_state_update(self._build_new_messages(summary), preserved_messages)

    def _plan_summarization(
        self,
        state: AgentState,
        runtime: Runtime,
    ) -> tuple[list[AnyMessage], list[AnyMessage]] | None:
        """Decide whether to summarize and, if so, partition messages and fire hooks.

        Shared by the sync and async paths. Returns ``None`` when the parent's
        trigger does not fire or the cutoff index is not positive. Otherwise
        returns ``(messages_to_summarize, preserved_messages)`` after the
        skill-rescue partition and after running the before-summarization hooks.
        """
        messages = state["messages"]
        self._ensure_message_ids(messages)

        total_tokens = self.token_counter(messages)
        if not self._should_summarize(messages, total_tokens):
            return None

        cutoff_index = self._determine_cutoff_index(messages)
        if cutoff_index <= 0:
            return None

        messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
        self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
        return messages_to_summarize, preserved_messages

    @staticmethod
    def _summary_state_update(
        new_messages: list[AnyMessage],
        preserved_messages: list[AnyMessage],
    ) -> dict:
        """Build the state update that replaces history with summary + preserved messages."""
        # REMOVE_ALL_MESSAGES clears the existing list before the new messages are appended.
        return {
            "messages": [
                RemoveMessage(id=REMOVE_ALL_MESSAGES),
                *new_messages,
                *preserved_messages,
            ]
        }

    def _partition_with_skill_rescue(
        self,
        messages: list[AnyMessage],
        cutoff_index: int,
    ) -> tuple[list[AnyMessage], list[AnyMessage]]:
        """Partition like the parent, then rescue recently-loaded skill bundles.

        Rescued AIMessages are split into two parts:

        * a clone carrying only the rescued skill tool calls (with empty content),
          which moves to the preserved side together with its skill ToolMessages;
        * a clone carrying the remaining tool calls and the original content,
          which stays on the summarized side. This clone is dropped if nothing
          would be left in it.

        Any failure while locating bundles is logged. In that case the parent's
        partition is returned unchanged.

        Returns:
            ``(to_summarize, preserved)``; rescued messages are placed in their
            original relative order ahead of the parent's preserved messages.
        """
        to_summarize, preserved = self._partition_messages(messages, cutoff_index)
        if self._preserve_recent_skill_count == 0 or self._preserve_recent_skill_tokens == 0 or not to_summarize:
            return to_summarize, preserved

        try:
            bundles = self._find_skill_bundles(to_summarize, self._skills_container_path)
        except Exception:
            logger.exception("Skill-preserving summarization rescue failed; falling back to default partition")
            return to_summarize, preserved

        rescue_bundles = self._select_bundles_to_rescue(bundles) if bundles else []
        if not rescue_bundles:
            return to_summarize, preserved

        bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles}
        rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices}

        rescued: list[AnyMessage] = []
        remaining: list[AnyMessage] = []
        for i, msg in enumerate(to_summarize):
            bundle = bundles_by_ai_index.get(i)
            if bundle is not None and isinstance(msg, AIMessage):
                rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids]
                remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids]
                if rescued_tool_calls:
                    # Keep only the skill calls; the text content goes to the summarizer.
                    rescued.append(_clone_ai_message(msg, rescued_tool_calls, content=""))
                if remaining_tool_calls or msg.content:
                    remaining.append(_clone_ai_message(msg, remaining_tool_calls))
                continue
            if i in rescue_tool_indices:
                rescued.append(msg)
                continue
            remaining.append(msg)

        return remaining, rescued + preserved

    def _find_skill_bundles(
        self,
        messages: list[AnyMessage],
        skills_root: str,
    ) -> list[_SkillBundle]:
        """Locate AIMessage + paired ToolMessage groups that load skill files.

        For every AIMessage with tool calls that read under ``skills_root``, the
        contiguous run of ToolMessages immediately following it is inspected.
        Results answering one of its skill calls form a ``_SkillBundle``. An
        AIMessage whose skill calls have no result in that run yields no bundle.

        Returns:
            Bundles in message order (oldest first).
        """
        bundles: list[_SkillBundle] = []
        n = len(messages)
        i = 0
        while i < n:
            msg = messages[i]
            if not (isinstance(msg, AIMessage) and msg.tool_calls):
                i += 1
                continue

            skill_paths_by_id: dict[str, str] = {}
            for tc in msg.tool_calls:
                if not self._is_skill_tool_call(tc, skills_root):
                    continue
                tc_id = tc.get("id")
                path = _tool_call_path(tc)
                if tc_id and path:
                    skill_paths_by_id[tc_id] = path
            if not skill_paths_by_id:
                i += 1
                continue

            # The results for this AIMessage are the ToolMessages directly after it.
            j = i + 1
            while j < n and isinstance(messages[j], ToolMessage):
                j += 1

            skill_tool_tokens = 0
            skill_paths: set[str] = set()
            skill_tool_indices: list[int] = []
            matched_skill_call_ids: set[str] = set()
            for k in range(i + 1, j):
                tool_msg = messages[k]
                path = skill_paths_by_id.get(tool_msg.tool_call_id)
                if path is None:
                    continue
                skill_tool_tokens += self.token_counter([tool_msg])
                skill_paths.add(path)
                skill_tool_indices.append(k)
                matched_skill_call_ids.add(tool_msg.tool_call_id)

            if skill_tool_indices:
                bundles.append(
                    _SkillBundle(
                        ai_index=i,
                        skill_tool_indices=tuple(skill_tool_indices),
                        skill_tool_call_ids=frozenset(matched_skill_call_ids),
                        skill_tool_tokens=skill_tool_tokens,
                        # De-duplicated so reading the same file twice keys like reading it once.
                        skill_key="|".join(sorted(skill_paths)),
                    )
                )
            i = j
        return bundles

    def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]:
        """Pick bundles to keep, walking newest-first under count/token budgets.

        A bundle is skipped in any of these cases:

        * a newer bundle with the same ``skill_key`` was already selected;
        * it exceeds the per-skill token limit;
        * it would push the running total over the overall token budget.

        Selection stops once ``preserve_recent_skill_count`` bundles are chosen.

        Returns:
            The selected bundles in original (oldest-first) order.
        """
        selected: list[_SkillBundle] = []
        seen_skill_keys: set[str] = set()
        total_tokens = 0
        for bundle in reversed(bundles):
            if len(selected) >= self._preserve_recent_skill_count:
                break
            if bundle.skill_key in seen_skill_keys:
                continue
            if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill:
                continue
            if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens:
                continue
            selected.append(bundle)
            total_tokens += bundle.skill_tool_tokens
            seen_skill_keys.add(bundle.skill_key)
        selected.reverse()
        return selected

    def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool:
        """Return True when ``tool_call`` reads a file under ``skills_root``.

        The call must use one of the configured file-read tool names and carry a
        path argument (see ``_tool_call_path``). Both the path and the root are
        normalized with ``posixpath.normpath`` before the prefix check. As a
        result, ``..`` segments cannot make a path outside the root look like a
        skill file (e.g. ``/mnt/skills/../etc/passwd``), and redundant slashes do
        not defeat the match.
        """
        import posixpath  # local import keeps this block self-contained

        name = tool_call.get("name") or ""
        if name not in self._skill_file_read_tool_names:
            return False
        path = _tool_call_path(tool_call)
        if not path:
            return False
        candidate = posixpath.normpath(path)
        # A root that is empty or consists only of slashes accepts any absolute path.
        if not skills_root.strip("/"):
            return candidate.startswith("/")
        root = posixpath.normpath(skills_root)
        return candidate == root or candidate.startswith(root + "/")

    def _fire_hooks(
        self,
        messages_to_summarize: list[AnyMessage],
        preserved_messages: list[AnyMessage],
        runtime: Runtime,
    ) -> None:
        """Invoke every before-summarization hook with a ``SummarizationEvent``.

        Hooks run in registration order. An exception from one hook is logged
        with the hook's name and does not stop later hooks or summarization.
        """
        if not self._before_summarization_hooks:
            return
        event = SummarizationEvent(
            messages_to_summarize=tuple(messages_to_summarize),
            preserved_messages=tuple(preserved_messages),
            thread_id=_resolve_thread_id(runtime),
            agent_name=_resolve_agent_name(runtime),
            runtime=runtime,
        )
        for hook in self._before_summarization_hooks:
            try:
                hook(event)
            except Exception:
                hook_name = getattr(hook, "__name__", None) or type(hook).__name__
                logger.exception("before_summarization hook %s failed", hook_name)