fix(chat): preserve messages after summarization (#3280)

* fix(chat): preserve messages after summarization

* make format

* fix(chat): address summarization review comments
This commit is contained in:
Nan Gao 2026-05-29 02:24:47 +02:00 committed by GitHub
parent 2ace78d1e5
commit d46a5779bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 229 additions and 12 deletions

View File

@ -476,6 +476,24 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
def test_create_summarization_middleware_uses_frontend_supported_update_key(monkeypatch):
"""LangGraph update keys use the middleware class name plus hook name."""
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
app_config.summarization = SummarizationConfig(enabled=True)
app_config.memory = MemoryConfig(enabled=False)
fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: fake_model)
middleware = lead_agent_module._create_summarization_middleware(app_config=app_config)
assert middleware is not None
update_key = f"{type(middleware).__name__}.before_model"
assert update_key == "DeerFlowSummarizationMiddleware.before_model"
def test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch):
fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)])
fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model")

View File

@ -5,7 +5,10 @@ from unittest import mock
from unittest.mock import MagicMock
import pytest
from langchain.agents import create_agent
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
@ -22,6 +25,23 @@ def _messages() -> list:
]
class _StaticChatModel(BaseChatModel):
text: str = "ok"
@property
def _llm_type(self) -> str:
return "static-test-chat-model"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=self.text))])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
return HumanMessage(
content="<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
@ -114,6 +134,32 @@ def test_before_summarization_hook_receives_messages_before_compression() -> Non
assert result["messages"][1].content.startswith("Here is a summary")
def test_summarization_middleware_emits_frontend_update_key_in_agent_stream() -> None:
middleware = DeerFlowSummarizationMiddleware(
model=_StaticChatModel(text="compressed summary"),
trigger=("messages", 4),
keep=("messages", 2),
token_counter=len,
)
agent = create_agent(
model=_StaticChatModel(text="done"),
tools=[],
middleware=[middleware],
)
chunks = list(agent.stream({"messages": _messages()}, stream_mode="updates"))
update = next(
(chunk["DeerFlowSummarizationMiddleware.before_model"] for chunk in chunks if "DeerFlowSummarizationMiddleware.before_model" in chunk),
None,
)
assert update is not None
emitted = update["messages"]
assert isinstance(emitted[0], RemoveMessage)
assert emitted[1].name == "summary"
assert emitted[1].content == ("Here is a summary of the conversation to date:\n\ncompressed summary")
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])

View File

@ -16,6 +16,7 @@ import { getAPIClient } from "../api";
import { fetch } from "../api/fetcher";
import { getBackendBaseURL } from "../config";
import { useI18n } from "../i18n/hooks";
import { isHiddenFromUIMessage } from "../messages/utils";
import type { FileInMessage } from "../messages/utils";
import type { LocalSettings } from "../settings";
import { useUpdateSubtask } from "../tasks/context";
@ -54,6 +55,11 @@ function isNonEmptyString(value: string | undefined): value is string {
return typeof value === "string" && value.length > 0;
}
const SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS = new Set([
"SummarizationMiddleware.before_model",
"DeerFlowSummarizationMiddleware.before_model",
]);
function messageIdentity(message: Message): string | undefined {
if (
"tool_call_id" in message &&
@ -70,17 +76,33 @@ function messageIdentity(message: Message): string | undefined {
function dedupeMessagesByIdentity(messages: Message[]): Message[] {
const lastIndexByIdentity = new Map<string, number>();
const lastVisibleIndexByIdentity = new Map<string, number>();
// This is a UI-display dedupe rule, not a general LangChain message-stream
// contract. Hidden messages that share an identity with a visible message are
// treated as control messages for this merged view; hidden messages carrying
// independent tracing/task semantics should use a distinct id or a custom
// stream/state channel instead of relying on message dedupe preservation.
messages.forEach((message, index) => {
const identity = messageIdentity(message);
if (identity) {
lastIndexByIdentity.set(identity, index);
if (!isHiddenFromUIMessage(message)) {
lastVisibleIndexByIdentity.set(identity, index);
}
}
});
return messages.filter((message, index) => {
const identity = messageIdentity(message);
return !identity || lastIndexByIdentity.get(identity) === index;
if (!identity) {
return true;
}
const visibleIndex = lastVisibleIndexByIdentity.get(identity);
if (visibleIndex !== undefined) {
return visibleIndex === index;
}
return lastIndexByIdentity.get(identity) === index;
});
}
@ -102,8 +124,15 @@ export function mergeMessages(
threadMessages: Message[],
optimisticMessages: Message[],
): Message[] {
// Only visible live messages should trim overlapping history. Hidden messages
// are UI control messages in this path, not observability records; any hidden
// message that must survive as task/tracing data should use custom events or a
// separate state channel instead of participating in this overlap heuristic.
const threadMessageIds = new Set(
threadMessages.map(messageIdentity).filter(isNonEmptyString),
threadMessages
.filter((message) => !isHiddenFromUIMessage(message))
.map(messageIdentity)
.filter(isNonEmptyString),
);
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
@ -154,6 +183,30 @@ export function getVisibleOptimisticMessages(
return optimisticMessages;
}
export function getSummarizationMiddlewareMessages(
data: unknown,
): Message[] | undefined {
if (typeof data !== "object" || data === null) {
return undefined;
}
for (const [key, update] of Object.entries(data)) {
if (!SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS.has(key)) {
continue;
}
if (typeof update !== "object" || update === null) {
continue;
}
const messages = Reflect.get(update, "messages");
if (Array.isArray(messages)) {
return [...messages] as Message[];
}
}
return undefined;
}
export function upsertThreadInSearchCache(
queryClient: QueryClient,
thread: AgentThread,
@ -319,24 +372,25 @@ export function useThreadStream({
}
},
onUpdateEvent(data) {
if (data["SummarizationMiddleware.before_model"]) {
const _messages = [
...(data["SummarizationMiddleware.before_model"].messages ?? []),
];
if (_messages.length < 2) {
return;
}
const _messages = getSummarizationMiddlewareMessages(data);
if (_messages && _messages.length >= 2) {
for (const m of _messages) {
if (m.name === "summary" && m.type === "human") {
summarizedRef.current?.add(m.id ?? "");
}
}
const _lastKeepMessage = _messages[2];
const firstRetainedVisibleIdentity = _messages
.filter((message) => message.type !== "remove")
.filter((message) => !isHiddenFromUIMessage(message))
.map(messageIdentity)
.find(isNonEmptyString);
const _currentMessages = [...messagesRef.current];
const _movedMessages: Message[] = [];
for (const m of _currentMessages) {
if (m.id !== undefined && m.id === _lastKeepMessage?.id) {
if (
firstRetainedVisibleIdentity &&
messageIdentity(m) === firstRetainedVisibleIdentity
) {
break;
}
if (!summarizedRef.current?.has(m.id ?? "")) {

View File

@ -2,6 +2,7 @@ import type { Message } from "@langchain/langgraph-sdk";
import { expect, test } from "vitest";
import {
getSummarizationMiddlewareMessages,
getVisibleOptimisticMessages,
mergeMessages,
} from "@/core/threads/hooks";
@ -66,6 +67,104 @@ test("mergeMessages deduplicates tool messages by tool_call_id", () => {
expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]);
});
test("mergeMessages keeps a visible history message when a hidden live message reuses its id", () => {
const historyHuman = {
id: "human-1",
type: "human",
content: "visible user prompt",
} as Message;
const hiddenReminder = {
id: "human-1",
type: "human",
content: "<system-reminder>hidden</system-reminder>",
additional_kwargs: { hide_from_ui: true },
} as Message;
const liveAi = {
id: "ai-1",
type: "ai",
content: "live answer",
} as Message;
expect(mergeMessages([historyHuman], [hiddenReminder, liveAi], [])).toEqual([
historyHuman,
liveAi,
]);
});
test("mergeMessages lets a visible live message replace overlapping hidden history", () => {
const hiddenHistoryHuman = {
id: "human-1",
type: "human",
content: "<system-reminder>hidden</system-reminder>",
additional_kwargs: { hide_from_ui: true },
} as Message;
const liveHuman = {
id: "human-1",
type: "human",
content: "visible user prompt",
} as Message;
expect(mergeMessages([hiddenHistoryHuman], [liveHuman], [])).toEqual([
liveHuman,
]);
});
test("getSummarizationMiddlewareMessages matches DeerFlow summarization update keys", () => {
const removeAll = {
id: "__remove_all__",
type: "remove",
content: "",
} as Message;
const summary = {
id: "summary-1",
type: "human",
name: "summary",
content: "summary",
} as Message;
expect(
getSummarizationMiddlewareMessages({
"DeerFlowSummarizationMiddleware.before_model": {
messages: [removeAll, summary],
},
}),
).toEqual([removeAll, summary]);
});
test("getSummarizationMiddlewareMessages matches base LangChain summarization update keys", () => {
const summary = {
id: "summary-1",
type: "human",
name: "summary",
content: "summary",
} as Message;
expect(
getSummarizationMiddlewareMessages({
"SummarizationMiddleware.before_model": {
messages: [summary],
},
}),
).toEqual([summary]);
});
test("getSummarizationMiddlewareMessages ignores unrelated suffix-sharing update keys", () => {
const summary = {
id: "summary-1",
type: "human",
name: "summary",
content: "summary",
} as Message;
expect(
getSummarizationMiddlewareMessages({
"OtherSummarizationMiddleware.before_model": {
messages: [summary],
},
}),
).toBeUndefined();
});
test("getVisibleOptimisticMessages hides optimistic user input after server human arrives", () => {
const optimisticHuman = {
id: "opt-human-1",