diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py index a12a754c2..0522b9ae1 100644 --- a/backend/tests/test_lead_agent_model_resolution.py +++ b/backend/tests/test_lead_agent_model_resolution.py @@ -476,6 +476,24 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"]) +def test_create_summarization_middleware_uses_frontend_supported_update_key(monkeypatch): + """LangGraph update keys use the middleware class name plus hook name.""" + + app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)]) + app_config.summarization = SummarizationConfig(enabled=True) + app_config.memory = MemoryConfig(enabled=False) + + fake_model = MagicMock() + fake_model.with_config.return_value = fake_model + monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: fake_model) + + middleware = lead_agent_module._create_summarization_middleware(app_config=app_config) + + assert middleware is not None + update_key = f"{type(middleware).__name__}.before_model" + assert update_key == "DeerFlowSummarizationMiddleware.before_model" + + def test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch): fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)]) fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model") diff --git a/backend/tests/test_summarization_middleware.py b/backend/tests/test_summarization_middleware.py index 9cd4fc725..ac702f470 100644 --- a/backend/tests/test_summarization_middleware.py +++ b/backend/tests/test_summarization_middleware.py @@ -5,7 +5,10 @@ from unittest import mock from unittest.mock import MagicMock import pytest +from langchain.agents import create_agent +from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage +from langchain_core.outputs import ChatGeneration, ChatResult from deerflow.agents.memory.summarization_hook import memory_flush_hook from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware @@ -22,6 +25,23 @@ def _messages() -> list: ] +class _StaticChatModel(BaseChatModel): + text: str = "ok" + + @property + def _llm_type(self) -> str: + return "static-test-chat-model" + + def bind_tools(self, tools, **kwargs): + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + return ChatResult(generations=[ChatGeneration(message=AIMessage(content=self.text))]) + + async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs): + return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs) + + def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage: return HumanMessage( content="\n2026-05-08, Friday\n", @@ -114,6 +134,32 @@ def test_before_summarization_hook_receives_messages_before_compression() -> Non assert result["messages"][1].content.startswith("Here is a summary") +def test_summarization_middleware_emits_frontend_update_key_in_agent_stream() -> None: + middleware = DeerFlowSummarizationMiddleware( + model=_StaticChatModel(text="compressed summary"), + trigger=("messages", 4), + keep=("messages", 2), + token_counter=len, + ) + agent = create_agent( + model=_StaticChatModel(text="done"), + tools=[], + middleware=[middleware], + ) + + chunks = list(agent.stream({"messages": _messages()}, stream_mode="updates")) + update = next( + (chunk["DeerFlowSummarizationMiddleware.before_model"] for chunk in chunks if "DeerFlowSummarizationMiddleware.before_model" in chunk), + None, + ) + + assert update is not None + emitted = update["messages"] + assert isinstance(emitted[0], RemoveMessage) + assert emitted[1].name == "summary" + assert emitted[1].content == ("Here is a summary of the conversation to date:\n\ncompressed summary") + + def test_dynamic_context_reminder_is_preserved_across_summarization() -> None: captured: list[SummarizationEvent] = [] middleware = _middleware(before_summarization=[captured.append]) diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index aa0e4dfbd..a4fd93d77 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -16,6 +16,7 @@ import { getAPIClient } from "../api"; import { fetch } from "../api/fetcher"; import { getBackendBaseURL } from "../config"; import { useI18n } from "../i18n/hooks"; +import { isHiddenFromUIMessage } from "../messages/utils"; import type { FileInMessage } from "../messages/utils"; import type { LocalSettings } from "../settings"; import { useUpdateSubtask } from "../tasks/context"; @@ -54,6 +55,11 @@ function isNonEmptyString(value: string | undefined): value is string { return typeof value === "string" && value.length > 0; } +const SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS = new Set([ + "SummarizationMiddleware.before_model", + "DeerFlowSummarizationMiddleware.before_model", +]); + function messageIdentity(message: Message): string | undefined { if ( "tool_call_id" in message && @@ -70,17 +76,33 @@ function messageIdentity(message: Message): string | undefined { function dedupeMessagesByIdentity(messages: Message[]): Message[] { const lastIndexByIdentity = new Map(); + const lastVisibleIndexByIdentity = new Map(); + // This is a UI-display dedupe rule, not a general LangChain message-stream + // contract. Hidden messages that share an identity with a visible message are + // treated as control messages for this merged view; hidden messages carrying + // independent tracing/task semantics should use a distinct id or a custom + // stream/state channel instead of relying on message dedupe preservation. messages.forEach((message, index) => { const identity = messageIdentity(message); if (identity) { lastIndexByIdentity.set(identity, index); + if (!isHiddenFromUIMessage(message)) { + lastVisibleIndexByIdentity.set(identity, index); + } } }); return messages.filter((message, index) => { const identity = messageIdentity(message); - return !identity || lastIndexByIdentity.get(identity) === index; + if (!identity) { + return true; + } + const visibleIndex = lastVisibleIndexByIdentity.get(identity); + if (visibleIndex !== undefined) { + return visibleIndex === index; + } + return lastIndexByIdentity.get(identity) === index; }); } @@ -102,8 +124,15 @@ export function mergeMessages( threadMessages: Message[], optimisticMessages: Message[], ): Message[] { + // Only visible live messages should trim overlapping history. Hidden messages + // are UI control messages in this path, not observability records; any hidden + // message that must survive as task/tracing data should use custom events or a + // separate state channel instead of participating in this overlap heuristic. const threadMessageIds = new Set( - threadMessages.map(messageIdentity).filter(isNonEmptyString), + threadMessages + .filter((message) => !isHiddenFromUIMessage(message)) + .map(messageIdentity) + .filter(isNonEmptyString), ); // The overlap is a contiguous suffix of historyMessages (newest history == oldest thread). @@ -154,6 +183,30 @@ export function getVisibleOptimisticMessages( return optimisticMessages; } +export function getSummarizationMiddlewareMessages( + data: unknown, +): Message[] | undefined { + if (typeof data !== "object" || data === null) { + return undefined; + } + + for (const [key, update] of Object.entries(data)) { + if (!SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS.has(key)) { + continue; + } + if (typeof update !== "object" || update === null) { + continue; + } + + const messages = Reflect.get(update, "messages"); + if (Array.isArray(messages)) { + return [...messages] as Message[]; + } + } + + return undefined; +} + export function upsertThreadInSearchCache( queryClient: QueryClient, thread: AgentThread, @@ -319,24 +372,25 @@ export function useThreadStream({ } }, onUpdateEvent(data) { - if (data["SummarizationMiddleware.before_model"]) { - const _messages = [ - ...(data["SummarizationMiddleware.before_model"].messages ?? []), - ]; - - if (_messages.length < 2) { - return; - } + const _messages = getSummarizationMiddlewareMessages(data); + if (_messages && _messages.length >= 2) { for (const m of _messages) { if (m.name === "summary" && m.type === "human") { summarizedRef.current?.add(m.id ?? ""); } } - const _lastKeepMessage = _messages[2]; + const firstRetainedVisibleIdentity = _messages + .filter((message) => message.type !== "remove") + .filter((message) => !isHiddenFromUIMessage(message)) + .map(messageIdentity) + .find(isNonEmptyString); const _currentMessages = [...messagesRef.current]; const _movedMessages: Message[] = []; for (const m of _currentMessages) { - if (m.id !== undefined && m.id === _lastKeepMessage?.id) { + if ( + firstRetainedVisibleIdentity && + messageIdentity(m) === firstRetainedVisibleIdentity + ) { break; } if (!summarizedRef.current?.has(m.id ?? "")) { diff --git a/frontend/tests/unit/core/threads/message-merge.test.ts b/frontend/tests/unit/core/threads/message-merge.test.ts index 2afca1eef..a6e612bfd 100644 --- a/frontend/tests/unit/core/threads/message-merge.test.ts +++ b/frontend/tests/unit/core/threads/message-merge.test.ts @@ -2,6 +2,7 @@ import type { Message } from "@langchain/langgraph-sdk"; import { expect, test } from "vitest"; import { + getSummarizationMiddlewareMessages, getVisibleOptimisticMessages, mergeMessages, } from "@/core/threads/hooks"; @@ -66,6 +67,104 @@ test("mergeMessages deduplicates tool messages by tool_call_id", () => { expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]); }); +test("mergeMessages keeps a visible history message when a hidden live message reuses its id", () => { + const historyHuman = { + id: "human-1", + type: "human", + content: "visible user prompt", + } as Message; + const hiddenReminder = { + id: "human-1", + type: "human", + content: "hidden", + additional_kwargs: { hide_from_ui: true }, + } as Message; + const liveAi = { + id: "ai-1", + type: "ai", + content: "live answer", + } as Message; + + expect(mergeMessages([historyHuman], [hiddenReminder, liveAi], [])).toEqual([ + historyHuman, + liveAi, + ]); +}); + +test("mergeMessages lets a visible live message replace overlapping hidden history", () => { + const hiddenHistoryHuman = { + id: "human-1", + type: "human", + content: "hidden", + additional_kwargs: { hide_from_ui: true }, + } as Message; + const liveHuman = { + id: "human-1", + type: "human", + content: "visible user prompt", + } as Message; + + expect(mergeMessages([hiddenHistoryHuman], [liveHuman], [])).toEqual([ + liveHuman, + ]); +}); + +test("getSummarizationMiddlewareMessages matches DeerFlow summarization update keys", () => { + const removeAll = { + id: "__remove_all__", + type: "remove", + content: "", + } as Message; + const summary = { + id: "summary-1", + type: "human", + name: "summary", + content: "summary", + } as Message; + + expect( + getSummarizationMiddlewareMessages({ + "DeerFlowSummarizationMiddleware.before_model": { + messages: [removeAll, summary], + }, + }), + ).toEqual([removeAll, summary]); +}); + +test("getSummarizationMiddlewareMessages matches base LangChain summarization update keys", () => { + const summary = { + id: "summary-1", + type: "human", + name: "summary", + content: "summary", + } as Message; + + expect( + getSummarizationMiddlewareMessages({ + "SummarizationMiddleware.before_model": { + messages: [summary], + }, + }), + ).toEqual([summary]); +}); + +test("getSummarizationMiddlewareMessages ignores unrelated suffix-sharing update keys", () => { + const summary = { + id: "summary-1", + type: "human", + name: "summary", + content: "summary", + } as Message; + + expect( + getSummarizationMiddlewareMessages({ + "OtherSummarizationMiddleware.before_model": { + messages: [summary], + }, + }), + ).toBeUndefined(); +}); + test("getVisibleOptimisticMessages hides optimistic user input after server human arrives", () => { const optimisticHuman = { id: "opt-human-1",