mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-30 12:28:10 +00:00
fix(chat): preserve messages after summarization (#3280)
* fix(chat): preserve messages after summarization * make format * fix(chat): address summarization review comments
This commit is contained in:
parent
2ace78d1e5
commit
d46a5779bc
@ -476,6 +476,24 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
||||
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
|
||||
|
||||
|
||||
def test_create_summarization_middleware_uses_frontend_supported_update_key(monkeypatch):
|
||||
"""LangGraph update keys use the middleware class name plus hook name."""
|
||||
|
||||
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
|
||||
app_config.summarization = SummarizationConfig(enabled=True)
|
||||
app_config.memory = MemoryConfig(enabled=False)
|
||||
|
||||
fake_model = MagicMock()
|
||||
fake_model.with_config.return_value = fake_model
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
middleware = lead_agent_module._create_summarization_middleware(app_config=app_config)
|
||||
|
||||
assert middleware is not None
|
||||
update_key = f"{type(middleware).__name__}.before_model"
|
||||
assert update_key == "DeerFlowSummarizationMiddleware.before_model"
|
||||
|
||||
|
||||
def test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch):
|
||||
fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)])
|
||||
fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model")
|
||||
|
||||
@ -5,7 +5,10 @@ from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
|
||||
@ -22,6 +25,23 @@ def _messages() -> list:
|
||||
]
|
||||
|
||||
|
||||
class _StaticChatModel(BaseChatModel):
|
||||
text: str = "ok"
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "static-test-chat-model"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=self.text))])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
|
||||
return HumanMessage(
|
||||
content="<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
|
||||
@ -114,6 +134,32 @@ def test_before_summarization_hook_receives_messages_before_compression() -> Non
|
||||
assert result["messages"][1].content.startswith("Here is a summary")
|
||||
|
||||
|
||||
def test_summarization_middleware_emits_frontend_update_key_in_agent_stream() -> None:
|
||||
middleware = DeerFlowSummarizationMiddleware(
|
||||
model=_StaticChatModel(text="compressed summary"),
|
||||
trigger=("messages", 4),
|
||||
keep=("messages", 2),
|
||||
token_counter=len,
|
||||
)
|
||||
agent = create_agent(
|
||||
model=_StaticChatModel(text="done"),
|
||||
tools=[],
|
||||
middleware=[middleware],
|
||||
)
|
||||
|
||||
chunks = list(agent.stream({"messages": _messages()}, stream_mode="updates"))
|
||||
update = next(
|
||||
(chunk["DeerFlowSummarizationMiddleware.before_model"] for chunk in chunks if "DeerFlowSummarizationMiddleware.before_model" in chunk),
|
||||
None,
|
||||
)
|
||||
|
||||
assert update is not None
|
||||
emitted = update["messages"]
|
||||
assert isinstance(emitted[0], RemoveMessage)
|
||||
assert emitted[1].name == "summary"
|
||||
assert emitted[1].content == ("Here is a summary of the conversation to date:\n\ncompressed summary")
|
||||
|
||||
|
||||
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(before_summarization=[captured.append])
|
||||
|
||||
@ -16,6 +16,7 @@ import { getAPIClient } from "../api";
|
||||
import { fetch } from "../api/fetcher";
|
||||
import { getBackendBaseURL } from "../config";
|
||||
import { useI18n } from "../i18n/hooks";
|
||||
import { isHiddenFromUIMessage } from "../messages/utils";
|
||||
import type { FileInMessage } from "../messages/utils";
|
||||
import type { LocalSettings } from "../settings";
|
||||
import { useUpdateSubtask } from "../tasks/context";
|
||||
@ -54,6 +55,11 @@ function isNonEmptyString(value: string | undefined): value is string {
|
||||
return typeof value === "string" && value.length > 0;
|
||||
}
|
||||
|
||||
const SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS = new Set([
|
||||
"SummarizationMiddleware.before_model",
|
||||
"DeerFlowSummarizationMiddleware.before_model",
|
||||
]);
|
||||
|
||||
function messageIdentity(message: Message): string | undefined {
|
||||
if (
|
||||
"tool_call_id" in message &&
|
||||
@ -70,17 +76,33 @@ function messageIdentity(message: Message): string | undefined {
|
||||
|
||||
function dedupeMessagesByIdentity(messages: Message[]): Message[] {
|
||||
const lastIndexByIdentity = new Map<string, number>();
|
||||
const lastVisibleIndexByIdentity = new Map<string, number>();
|
||||
|
||||
// This is a UI-display dedupe rule, not a general LangChain message-stream
|
||||
// contract. Hidden messages that share an identity with a visible message are
|
||||
// treated as control messages for this merged view; hidden messages carrying
|
||||
// independent tracing/task semantics should use a distinct id or a custom
|
||||
// stream/state channel instead of relying on message dedupe preservation.
|
||||
messages.forEach((message, index) => {
|
||||
const identity = messageIdentity(message);
|
||||
if (identity) {
|
||||
lastIndexByIdentity.set(identity, index);
|
||||
if (!isHiddenFromUIMessage(message)) {
|
||||
lastVisibleIndexByIdentity.set(identity, index);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return messages.filter((message, index) => {
|
||||
const identity = messageIdentity(message);
|
||||
return !identity || lastIndexByIdentity.get(identity) === index;
|
||||
if (!identity) {
|
||||
return true;
|
||||
}
|
||||
const visibleIndex = lastVisibleIndexByIdentity.get(identity);
|
||||
if (visibleIndex !== undefined) {
|
||||
return visibleIndex === index;
|
||||
}
|
||||
return lastIndexByIdentity.get(identity) === index;
|
||||
});
|
||||
}
|
||||
|
||||
@ -102,8 +124,15 @@ export function mergeMessages(
|
||||
threadMessages: Message[],
|
||||
optimisticMessages: Message[],
|
||||
): Message[] {
|
||||
// Only visible live messages should trim overlapping history. Hidden messages
|
||||
// are UI control messages in this path, not observability records; any hidden
|
||||
// message that must survive as task/tracing data should use custom events or a
|
||||
// separate state channel instead of participating in this overlap heuristic.
|
||||
const threadMessageIds = new Set(
|
||||
threadMessages.map(messageIdentity).filter(isNonEmptyString),
|
||||
threadMessages
|
||||
.filter((message) => !isHiddenFromUIMessage(message))
|
||||
.map(messageIdentity)
|
||||
.filter(isNonEmptyString),
|
||||
);
|
||||
|
||||
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
|
||||
@ -154,6 +183,30 @@ export function getVisibleOptimisticMessages(
|
||||
return optimisticMessages;
|
||||
}
|
||||
|
||||
export function getSummarizationMiddlewareMessages(
|
||||
data: unknown,
|
||||
): Message[] | undefined {
|
||||
if (typeof data !== "object" || data === null) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
for (const [key, update] of Object.entries(data)) {
|
||||
if (!SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS.has(key)) {
|
||||
continue;
|
||||
}
|
||||
if (typeof update !== "object" || update === null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const messages = Reflect.get(update, "messages");
|
||||
if (Array.isArray(messages)) {
|
||||
return [...messages] as Message[];
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export function upsertThreadInSearchCache(
|
||||
queryClient: QueryClient,
|
||||
thread: AgentThread,
|
||||
@ -319,24 +372,25 @@ export function useThreadStream({
|
||||
}
|
||||
},
|
||||
onUpdateEvent(data) {
|
||||
if (data["SummarizationMiddleware.before_model"]) {
|
||||
const _messages = [
|
||||
...(data["SummarizationMiddleware.before_model"].messages ?? []),
|
||||
];
|
||||
|
||||
if (_messages.length < 2) {
|
||||
return;
|
||||
}
|
||||
const _messages = getSummarizationMiddlewareMessages(data);
|
||||
if (_messages && _messages.length >= 2) {
|
||||
for (const m of _messages) {
|
||||
if (m.name === "summary" && m.type === "human") {
|
||||
summarizedRef.current?.add(m.id ?? "");
|
||||
}
|
||||
}
|
||||
const _lastKeepMessage = _messages[2];
|
||||
const firstRetainedVisibleIdentity = _messages
|
||||
.filter((message) => message.type !== "remove")
|
||||
.filter((message) => !isHiddenFromUIMessage(message))
|
||||
.map(messageIdentity)
|
||||
.find(isNonEmptyString);
|
||||
const _currentMessages = [...messagesRef.current];
|
||||
const _movedMessages: Message[] = [];
|
||||
for (const m of _currentMessages) {
|
||||
if (m.id !== undefined && m.id === _lastKeepMessage?.id) {
|
||||
if (
|
||||
firstRetainedVisibleIdentity &&
|
||||
messageIdentity(m) === firstRetainedVisibleIdentity
|
||||
) {
|
||||
break;
|
||||
}
|
||||
if (!summarizedRef.current?.has(m.id ?? "")) {
|
||||
|
||||
@ -2,6 +2,7 @@ import type { Message } from "@langchain/langgraph-sdk";
|
||||
import { expect, test } from "vitest";
|
||||
|
||||
import {
|
||||
getSummarizationMiddlewareMessages,
|
||||
getVisibleOptimisticMessages,
|
||||
mergeMessages,
|
||||
} from "@/core/threads/hooks";
|
||||
@ -66,6 +67,104 @@ test("mergeMessages deduplicates tool messages by tool_call_id", () => {
|
||||
expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]);
|
||||
});
|
||||
|
||||
test("mergeMessages keeps a visible history message when a hidden live message reuses its id", () => {
|
||||
const historyHuman = {
|
||||
id: "human-1",
|
||||
type: "human",
|
||||
content: "visible user prompt",
|
||||
} as Message;
|
||||
const hiddenReminder = {
|
||||
id: "human-1",
|
||||
type: "human",
|
||||
content: "<system-reminder>hidden</system-reminder>",
|
||||
additional_kwargs: { hide_from_ui: true },
|
||||
} as Message;
|
||||
const liveAi = {
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "live answer",
|
||||
} as Message;
|
||||
|
||||
expect(mergeMessages([historyHuman], [hiddenReminder, liveAi], [])).toEqual([
|
||||
historyHuman,
|
||||
liveAi,
|
||||
]);
|
||||
});
|
||||
|
||||
test("mergeMessages lets a visible live message replace overlapping hidden history", () => {
|
||||
const hiddenHistoryHuman = {
|
||||
id: "human-1",
|
||||
type: "human",
|
||||
content: "<system-reminder>hidden</system-reminder>",
|
||||
additional_kwargs: { hide_from_ui: true },
|
||||
} as Message;
|
||||
const liveHuman = {
|
||||
id: "human-1",
|
||||
type: "human",
|
||||
content: "visible user prompt",
|
||||
} as Message;
|
||||
|
||||
expect(mergeMessages([hiddenHistoryHuman], [liveHuman], [])).toEqual([
|
||||
liveHuman,
|
||||
]);
|
||||
});
|
||||
|
||||
test("getSummarizationMiddlewareMessages matches DeerFlow summarization update keys", () => {
|
||||
const removeAll = {
|
||||
id: "__remove_all__",
|
||||
type: "remove",
|
||||
content: "",
|
||||
} as Message;
|
||||
const summary = {
|
||||
id: "summary-1",
|
||||
type: "human",
|
||||
name: "summary",
|
||||
content: "summary",
|
||||
} as Message;
|
||||
|
||||
expect(
|
||||
getSummarizationMiddlewareMessages({
|
||||
"DeerFlowSummarizationMiddleware.before_model": {
|
||||
messages: [removeAll, summary],
|
||||
},
|
||||
}),
|
||||
).toEqual([removeAll, summary]);
|
||||
});
|
||||
|
||||
test("getSummarizationMiddlewareMessages matches base LangChain summarization update keys", () => {
|
||||
const summary = {
|
||||
id: "summary-1",
|
||||
type: "human",
|
||||
name: "summary",
|
||||
content: "summary",
|
||||
} as Message;
|
||||
|
||||
expect(
|
||||
getSummarizationMiddlewareMessages({
|
||||
"SummarizationMiddleware.before_model": {
|
||||
messages: [summary],
|
||||
},
|
||||
}),
|
||||
).toEqual([summary]);
|
||||
});
|
||||
|
||||
test("getSummarizationMiddlewareMessages ignores unrelated suffix-sharing update keys", () => {
|
||||
const summary = {
|
||||
id: "summary-1",
|
||||
type: "human",
|
||||
name: "summary",
|
||||
content: "summary",
|
||||
} as Message;
|
||||
|
||||
expect(
|
||||
getSummarizationMiddlewareMessages({
|
||||
"OtherSummarizationMiddleware.before_model": {
|
||||
messages: [summary],
|
||||
},
|
||||
}),
|
||||
).toBeUndefined();
|
||||
});
|
||||
|
||||
test("getVisibleOptimisticMessages hides optimistic user input after server human arrives", () => {
|
||||
const optimisticHuman = {
|
||||
id: "opt-human-1",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user