mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
* fix(middleware): narrow skill rescue to skill-related tool outputs * fix(summarization): address skill rescue review feedback * fix: wire summarization skill rescue config * fix: remove dead skill tool helper * fix(lint): fix format --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
348 lines
13 KiB
Python
348 lines
13 KiB
Python
"""Summarization middleware extensions for DeerFlow."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from collections.abc import Collection
|
|
from dataclasses import dataclass
|
|
from typing import Any, Protocol, runtime_checkable
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import SummarizationMiddleware
|
|
from langchain_core.messages import AIMessage, AnyMessage, RemoveMessage, ToolMessage
|
|
from langgraph.config import get_config
|
|
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
|
from langgraph.runtime import Runtime
|
|
|
|
# Module-scoped logger; named after this module so log config can target it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
class SummarizationEvent:
    """Immutable snapshot handed to ``before_summarization`` hooks.

    It is built just before older conversation history is replaced by a
    summary, so hooks can inspect (e.g. persist) what is about to be dropped.
    """

    # Messages that are about to be condensed into the summary.
    messages_to_summarize: tuple[AnyMessage, ...]
    # Messages that will be kept verbatim alongside the summary.
    preserved_messages: tuple[AnyMessage, ...]
    # Thread identifier for the run, or None if it could not be resolved.
    thread_id: str | None
    # Agent name for the run, or None if it could not be resolved.
    agent_name: str | None
    # The LangGraph runtime of the current invocation.
    runtime: Runtime
|
|
|
|
|
|
@runtime_checkable
class BeforeSummarizationHook(Protocol):
    """Structural type for callbacks run before summarization prunes state.

    Any callable taking a single :class:`SummarizationEvent` satisfies this
    protocol; because it is ``runtime_checkable``, ``isinstance`` checks only
    that the object is callable.
    """

    def __call__(self, event: SummarizationEvent) -> None:
        ...
|
|
|
|
|
|
def _resolve_thread_id(runtime: Runtime) -> str | None:
    """Return the thread ID of the current run, or None if it is unavailable.

    ``runtime.context["thread_id"]`` wins when the context is truthy and the
    value is not None. Otherwise the ``configurable`` section of the active
    LangGraph config is consulted; if ``get_config`` raises ``RuntimeError``
    (no config available), None is returned.
    """
    context = runtime.context
    found = context.get("thread_id") if context else None
    if found is not None:
        return found

    try:
        config = get_config()
    except RuntimeError:
        return None
    return config.get("configurable", {}).get("thread_id")
|
|
|
|
|
|
def _resolve_agent_name(runtime: Runtime) -> str | None:
    """Return the agent name of the current run, or None if it is unavailable.

    ``runtime.context["agent_name"]`` wins when the context is truthy and the
    value is not None. Otherwise the ``configurable`` section of the active
    LangGraph config is consulted; if ``get_config`` raises ``RuntimeError``
    (no config available), None is returned.
    """
    context = runtime.context
    found = context.get("agent_name") if context else None
    if found is not None:
        return found

    try:
        config = get_config()
    except RuntimeError:
        return None
    return config.get("configurable", {}).get("agent_name")
|
|
|
|
|
|
def _tool_call_path(tool_call: dict[str, Any]) -> str | None:
    """Return the first non-empty string path argument of a tool call, if any.

    The call's ``args`` are searched for ``path``, ``file_path`` and
    ``filepath``, in that order. None is returned when ``args`` is missing or
    not a dict, or when none of those keys holds a non-empty string.
    """
    arguments = tool_call.get("args") or {}
    if not isinstance(arguments, dict):
        return None
    candidates = (arguments.get(key) for key in ("path", "file_path", "filepath"))
    return next((value for value in candidates if isinstance(value, str) and value), None)
|
|
|
|
|
|
def _clone_ai_message(
    message: AIMessage,
    tool_calls: list[dict[str, Any]],
    *,
    content: Any | None = None,
) -> AIMessage:
    """Return a copy of ``message`` carrying ``tool_calls`` (and ``content`` if given).

    The copy is made with ``model_copy(update=...)`` (a shallow copy by
    pydantic's default), so every other field, including ``id``, is carried
    over. ``content=None`` keeps the original content; pass ``""`` to clear it.
    """
    overrides: dict[str, Any] = dict(tool_calls=tool_calls)
    if content is not None:
        overrides.update(content=content)
    return message.model_copy(update=overrides)
|
|
|
|
|
|
@dataclass
class _SkillBundle:
    """One AIMessage's skill-file reads together with their ToolMessage replies.

    Indices are positions within the message list the bundle was found in.
    """

    # Position of the AIMessage that issued the skill tool calls.
    ai_index: int
    # Positions of the matching ToolMessages, in message order.
    skill_tool_indices: tuple[int, ...]
    # IDs of the skill tool calls that have a ToolMessage reply.
    skill_tool_call_ids: frozenset[str]
    # Token count of those ToolMessages.
    skill_tool_tokens: int
    # Sorted, "|"-joined skill file paths; used to de-duplicate bundles.
    skill_key: str
|
|
|
|
|
|
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
    """Summarization middleware with pre-compression hook dispatch and skill rescue.

    On top of the base ``SummarizationMiddleware`` behaviour this class:

    * calls every ``before_summarization`` hook with a ``SummarizationEvent``
      before older messages are replaced by a summary, and
    * "rescues" recent skill-file reads (tool calls made by one of the
      configured read tools on a path under the skills container). The
      issuing AIMessage's skill tool calls and their ToolMessage replies are
      moved out of the summarized region, so they stay verbatim in history.
    """

    def __init__(
        self,
        *args,
        skills_container_path: str | None = None,
        skill_file_read_tool_names: Collection[str] | None = None,
        before_summarization: list[BeforeSummarizationHook] | None = None,
        preserve_recent_skill_count: int = 5,
        preserve_recent_skill_tokens: int = 25_000,
        preserve_recent_skill_tokens_per_skill: int = 5_000,
        **kwargs,
    ) -> None:
        """Configure the middleware.

        Args:
            *args, **kwargs: Forwarded unchanged to ``SummarizationMiddleware``.
            skills_container_path: Root directory whose files count as skill
                files. A falsy value falls back to ``"/mnt/skills"``.
            skill_file_read_tool_names: Tool names treated as file reads. A
                falsy value (None or empty) falls back to
                ``{"read_file", "read", "view", "cat"}``.
            before_summarization: Hooks called, in order, before each
                summarization. The list is stored as given, not copied.
            preserve_recent_skill_count: Maximum number of skill bundles to
                rescue. Negative values are clamped to 0, and 0 disables rescue.
            preserve_recent_skill_tokens: Total token budget for rescued skill
                tool outputs. Negative values are clamped to 0, and 0 disables
                rescue.
            preserve_recent_skill_tokens_per_skill: Bundles whose skill tool
                outputs exceed this many tokens are never rescued. Negative
                values are clamped to 0.
        """
        super().__init__(*args, **kwargs)
        self._skills_container_path = skills_container_path or "/mnt/skills"
        self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"})
        self._before_summarization_hooks = before_summarization or []
        self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
        self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
        self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)

    def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Summarize history before a model call when the base trigger fires."""
        return self._maybe_summarize(state, runtime)

    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async variant of ``before_model``."""
        return await self._amaybe_summarize(state, runtime)

    def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Run one synchronous summarization pass.

        Returns a state update that replaces the history with the summary
        followed by the preserved messages, or None when nothing is summarized.
        """
        plan = self._plan_summarization(state, runtime)
        if plan is None:
            return None
        messages_to_summarize, preserved_messages = plan
        summary = self._create_summary(messages_to_summarize)
        return self._summarization_update(summary, preserved_messages)

    async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async variant of ``_maybe_summarize``; awaits ``_acreate_summary``."""
        plan = self._plan_summarization(state, runtime)
        if plan is None:
            return None
        messages_to_summarize, preserved_messages = plan
        summary = await self._acreate_summary(messages_to_summarize)
        return self._summarization_update(summary, preserved_messages)

    def _plan_summarization(
        self,
        state: AgentState,
        runtime: Runtime,
    ) -> tuple[list[AnyMessage], list[AnyMessage]] | None:
        """Shared prelude of the sync and async summarization paths.

        Checks the base-class trigger and cutoff, partitions the history with
        skill rescue, and fires the ``before_summarization`` hooks.

        Returns:
            ``(messages_to_summarize, preserved_messages)``, or None when the
            trigger does not fire or there is nothing before the cutoff.
        """
        messages = state["messages"]
        self._ensure_message_ids(messages)

        total_tokens = self.token_counter(messages)
        if not self._should_summarize(messages, total_tokens):
            return None

        cutoff_index = self._determine_cutoff_index(messages)
        if cutoff_index <= 0:
            return None

        messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
        self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
        return messages_to_summarize, preserved_messages

    def _summarization_update(self, summary: Any, preserved_messages: list[AnyMessage]) -> dict:
        """Build the state update: drop all messages, then re-add summary + kept ones."""
        new_messages = self._build_new_messages(summary)
        return {
            "messages": [
                RemoveMessage(id=REMOVE_ALL_MESSAGES),
                *new_messages,
                *preserved_messages,
            ]
        }

    def _partition_with_skill_rescue(
        self,
        messages: list[AnyMessage],
        cutoff_index: int,
    ) -> tuple[list[AnyMessage], list[AnyMessage]]:
        """Partition like the parent, then rescue recently-loaded skill bundles.

        Rescued messages are placed ahead of the parent's preserved messages.
        Rescue is best-effort: any error while locating, selecting or
        splitting skill bundles is logged, and the parent's partition is
        returned unchanged.
        """
        to_summarize, preserved = self._partition_messages(messages, cutoff_index)

        if self._preserve_recent_skill_count == 0 or self._preserve_recent_skill_tokens == 0 or not to_summarize:
            return to_summarize, preserved

        try:
            split = self._rescue_skill_messages(to_summarize)
        except Exception:
            logger.exception("Skill-preserving summarization rescue failed; falling back to default partition")
            return to_summarize, preserved

        if split is None:
            return to_summarize, preserved

        remaining, rescued = split
        return remaining, rescued + preserved

    def _rescue_skill_messages(
        self,
        to_summarize: list[AnyMessage],
    ) -> tuple[list[AnyMessage], list[AnyMessage]] | None:
        """Split ``to_summarize`` into ``(remaining, rescued)``.

        Returns None when no bundle qualifies for rescue. An AIMessage in a
        rescued bundle is split in two:

        * a rescued copy holding only its skill tool calls, with content cleared;
        * a remaining copy holding the other tool calls. This copy is dropped
          when it would have neither tool calls nor content.
        """
        bundles = self._find_skill_bundles(to_summarize, self._skills_container_path)
        rescue_bundles = self._select_bundles_to_rescue(bundles)
        if not rescue_bundles:
            return None

        bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles}
        rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices}
        rescued: list[AnyMessage] = []
        remaining: list[AnyMessage] = []

        for i, msg in enumerate(to_summarize):
            bundle = bundles_by_ai_index.get(i)
            if bundle is not None and isinstance(msg, AIMessage):
                rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids]
                remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids]

                if rescued_tool_calls:
                    rescued.append(_clone_ai_message(msg, rescued_tool_calls, content=""))
                if remaining_tool_calls or msg.content:
                    remaining.append(_clone_ai_message(msg, remaining_tool_calls))
                continue

            if i in rescue_tool_indices:
                rescued.append(msg)
                continue

            remaining.append(msg)

        return remaining, rescued

    def _find_skill_bundles(
        self,
        messages: list[AnyMessage],
        skills_root: str,
    ) -> list[_SkillBundle]:
        """Locate AIMessage + paired ToolMessage groups that load skill files.

        For each AIMessage with tool calls, its skill tool calls (see
        ``_is_skill_tool_call``) are matched against the contiguous run of
        ToolMessages that immediately follows it. A bundle is produced only
        when at least one skill tool call has a ToolMessage reply in that run.

        Returns:
            Bundles in message order; indices are positions in ``messages``.
        """
        bundles: list[_SkillBundle] = []
        n = len(messages)
        i = 0
        while i < n:
            msg = messages[i]
            if not (isinstance(msg, AIMessage) and msg.tool_calls):
                i += 1
                continue

            # tool_call_id -> file path, for skill reads issued by this message.
            skill_paths_by_id: dict[str, str] = {}
            for tc in msg.tool_calls:
                if self._is_skill_tool_call(tc, skills_root):
                    tc_id = tc.get("id")
                    path = _tool_call_path(tc)
                    if tc_id and path:
                        skill_paths_by_id[tc_id] = path

            if not skill_paths_by_id:
                i += 1
                continue

            # End (exclusive) of the ToolMessage run directly after the AIMessage.
            j = i + 1
            while j < n and isinstance(messages[j], ToolMessage):
                j += 1

            skill_tool_tokens = 0
            skill_key_parts: list[str] = []
            skill_tool_indices: list[int] = []
            matched_skill_call_ids: set[str] = set()
            for k in range(i + 1, j):
                tool_msg = messages[k]
                if isinstance(tool_msg, ToolMessage) and tool_msg.tool_call_id in skill_paths_by_id:
                    skill_tool_tokens += self.token_counter([tool_msg])
                    skill_key_parts.append(skill_paths_by_id[tool_msg.tool_call_id])
                    skill_tool_indices.append(k)
                    matched_skill_call_ids.add(tool_msg.tool_call_id)

            if skill_tool_indices:
                bundles.append(
                    _SkillBundle(
                        ai_index=i,
                        skill_tool_indices=tuple(skill_tool_indices),
                        skill_tool_call_ids=frozenset(matched_skill_call_ids),
                        skill_tool_tokens=skill_tool_tokens,
                        skill_key="|".join(sorted(skill_key_parts)),
                    )
                )
            i = j

        return bundles

    def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]:
        """Pick bundles to keep, walking newest-first under count/token budgets.

        Three kinds of bundle are skipped:

        * a bundle whose ``skill_key`` was already selected (the newer load of
          the same files wins);
        * a bundle above the per-skill token cap;
        * a bundle that would push the running total past the overall budget.

        The selection is returned in original (oldest-first) order.
        """
        selected: list[_SkillBundle] = []
        if not bundles:
            return selected

        seen_skill_keys: set[str] = set()
        total_tokens = 0
        kept = 0

        for bundle in reversed(bundles):
            if kept >= self._preserve_recent_skill_count:
                break
            if bundle.skill_key in seen_skill_keys:
                continue
            if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill:
                continue
            if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens:
                continue

            selected.append(bundle)
            total_tokens += bundle.skill_tool_tokens
            kept += 1
            seen_skill_keys.add(bundle.skill_key)

        selected.reverse()
        return selected

    def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool:
        """Return True when ``tool_call`` reads a file under the configured skills root.

        The tool name must be one of the configured read tools, and its path
        argument must equal the skills root or lie beneath it. Both paths are
        normalized first, so ``.``/``..`` segments and repeated separators
        cannot make a path outside the root (e.g. ``/mnt/skills/../x``) look
        like a skill file.
        """
        import posixpath

        name = tool_call.get("name") or ""
        if name not in self._skill_file_read_tool_names:
            return False
        path = _tool_call_path(tool_call)
        if not path:
            return False

        # rstrip("/") makes a root of "/" match every absolute path.
        normalized_root = posixpath.normpath(skills_root).rstrip("/") if skills_root else ""
        normalized_path = posixpath.normpath(path)
        return normalized_path == normalized_root or normalized_path.startswith(normalized_root + "/")

    def _fire_hooks(
        self,
        messages_to_summarize: list[AnyMessage],
        preserved_messages: list[AnyMessage],
        runtime: Runtime,
    ) -> None:
        """Call each ``before_summarization`` hook with a ``SummarizationEvent``.

        Hooks run in registration order. An exception from one hook is
        logged and swallowed, so it blocks neither later hooks nor
        summarization.
        """
        if not self._before_summarization_hooks:
            return

        event = SummarizationEvent(
            messages_to_summarize=tuple(messages_to_summarize),
            preserved_messages=tuple(preserved_messages),
            thread_id=_resolve_thread_id(runtime),
            agent_name=_resolve_agent_name(runtime),
            runtime=runtime,
        )

        for hook in self._before_summarization_hooks:
            try:
                hook(event)
            except Exception:
                hook_name = getattr(hook, "__name__", None) or type(hook).__name__
                logger.exception("before_summarization hook %s failed", hook_name)
|