diff --git a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py
index af0881e88..917823612 100644
--- a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py
+++ b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py
@@ -9,8 +9,9 @@ from typing import Any, Protocol, override, runtime_checkable
 
 from langchain.agents import AgentState
 from langchain.agents.middleware import SummarizationMiddleware
-from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
+from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage, get_buffer_string
 from langgraph.config import get_config
+from langgraph.constants import TAG_NOSTREAM
 from langgraph.graph.message import REMOVE_ALL_MESSAGES
 from langgraph.runtime import Runtime
 
@@ -116,6 +117,74 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
         self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
         self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
         self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
+        # The summary LLM call runs inside a LangGraph middleware hook, so its token
+        # stream would otherwise be captured by the messages-tuple stream callback and
+        # broadcast to the frontend as a phantom AI message. Tag a dedicated model copy
+        # with TAG_NOSTREAM so the streaming handler skips it.
+        # Keep self.model untagged so the parent's profile / ls_params inspection still works.
+        #
+        # Preserve any tags already bound on the model (e.g. "middleware:summarize" set in
+        # lead_agent/agent.py for RunJournal attribution): RunnableBinding.with_config does a
+        # shallow merge that would otherwise overwrite the existing tags list entirely.
+        existing_tags = list((getattr(self.model, "config", None) or {}).get("tags") or [])
+        merged_tags = [*existing_tags, TAG_NOSTREAM] if TAG_NOSTREAM not in existing_tags else existing_tags
+        self._summary_model = self.model.with_config(tags=merged_tags)
+
+    @override
+    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
+        return self._summarize_with(messages_to_summarize)
+
+    @override
+    async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
+        return await self._asummarize_with(messages_to_summarize)
+
+    def _summarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
+        """Mirror the parent ``_create_summary`` but invoke the nostream-tagged model.
+
+        We do not swap ``self.model`` at the instance level: the agent/middleware is
+        cached and reused across concurrent runs, so a temporary swap would leak the
+        ``RunnableBinding`` to other coroutines during ``await`` and break parent logic
+        that inspects the raw model (``profile`` / ``_get_ls_params``).
+        """
+        if not messages_to_summarize:
+            return "No previous conversation history."
+        prompt = self._build_summary_prompt(messages_to_summarize)
+        if prompt is None:
+            return "Previous conversation was too long to summarize."
+        try:
+            response = self._summary_model.invoke(
+                prompt,
+                config={"metadata": {"lc_source": "summarization"}},
+            )
+            return response.text.strip()
+        except Exception as e:
+            return f"Error generating summary: {e!s}"
+
+    async def _asummarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
+        """Async counterpart of :meth:`_summarize_with` using the nostream model."""
+        if not messages_to_summarize:
+            return "No previous conversation history."
+        prompt = self._build_summary_prompt(messages_to_summarize)
+        if prompt is None:
+            return "Previous conversation was too long to summarize."
+        try:
+            response = await self._summary_model.ainvoke(
+                prompt,
+                config={"metadata": {"lc_source": "summarization"}},
+            )
+            return response.text.strip()
+        except Exception as e:
+            return f"Error generating summary: {e!s}"
+
+    def _build_summary_prompt(self, messages_to_summarize: list[AnyMessage]) -> str | None:
+        """Build the summary prompt, returning ``None`` when trimming leaves nothing."""
+        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
+        if not trimmed_messages:
+            return None
+        # Format messages to avoid token inflation from metadata when str() is called on
+        # message objects.
+        formatted_messages = get_buffer_string(trimmed_messages)
+        return self.summary_prompt.format(messages=formatted_messages).rstrip()
 
     def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
         return self._maybe_summarize(state, runtime)
diff --git a/backend/tests/test_summarization_middleware.py b/backend/tests/test_summarization_middleware.py
index ac702f470..d01c4479a 100644
--- a/backend/tests/test_summarization_middleware.py
+++ b/backend/tests/test_summarization_middleware.py
@@ -9,6 +9,7 @@ from langchain.agents import create_agent
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
 from langchain_core.outputs import ChatGeneration, ChatResult
+from langgraph.constants import TAG_NOSTREAM
 
 from deerflow.agents.memory.summarization_hook import memory_flush_hook
 from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
@@ -160,6 +161,120 @@ def test_summarization_middleware_emits_frontend_update_key_in_agent_stream() ->
     assert emitted[1].content == ("Here is a summary of the conversation to date:\n\ncompressed summary")
 
 
+def test_summary_model_is_tagged_nostream_to_avoid_stream_pollution() -> None:
+    tags_during_summary: list[list[str]] = []
+
+    class _RecordingChatModel(_StaticChatModel):
+        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
+            tags_during_summary.append(list(run_manager.tags) if run_manager else [])
+            return super()._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
+
+    model = _RecordingChatModel(text="compressed summary")
+    middleware = DeerFlowSummarizationMiddleware(
+        model=model,
+        trigger=("messages", 4),
+        keep=("messages", 2),
+        token_counter=len,
+    )
+
+    # The dedicated summary model must carry TAG_NOSTREAM so LangGraph's
+    # messages-tuple stream handler skips its tokens, while the raw model used by
+    # the parent for profile / token inspection stays untagged.
+    assert TAG_NOSTREAM in (middleware._summary_model.config.get("tags") or [])
+    assert TAG_NOSTREAM not in (getattr(middleware.model, "config", {}).get("tags") or [])
+
+    result = middleware.before_model({"messages": _messages()}, _runtime())
+
+    # The summary LLM call must actually run with the nostream tag (this is what the
+    # stream handler inspects), and the shared self.model must remain the raw,
+    # untagged model so parent logic (profile / _get_ls_params) keeps working.
+    assert tags_during_summary == [[TAG_NOSTREAM]]
+    assert middleware.model is model
+    assert result["messages"][1].content.startswith("Here is a summary")
+
+
+def test_summarization_does_not_mutate_shared_model_across_concurrent_runs() -> None:
+    """Concurrent runs must not observe a swapped-out self.model during summarization.
+
+    The agent/middleware instance is cached and reused, so summarization must never
+    temporarily replace the shared self.model: doing so would leak the nostream
+    RunnableBinding to other coroutines mid-flight and break parent logic that
+    inspects the raw model (profile / _get_ls_params).
+    """
+    import asyncio
+
+    observed_models: list[object] = []
+    started = asyncio.Event()
+    release = asyncio.Event()
+
+    class _BlockingChatModel(_StaticChatModel):
+        async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
+            # Hold the summary call open so a concurrent run can inspect self.model.
+            started.set()
+            await release.wait()
+            return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
+
+    model = _BlockingChatModel(text="compressed summary")
+    middleware = DeerFlowSummarizationMiddleware(
+        model=model,
+        trigger=("messages", 4),
+        keep=("messages", 2),
+        token_counter=len,
+    )
+
+    async def _run() -> None:
+        summarizing = asyncio.create_task(middleware.abefore_model({"messages": _messages()}, _runtime()))
+        # Wait until the summary task reaches the blocked LLM call.
+        await started.wait()
+        # A concurrent run reads the shared model while summarization is in flight.
+        observed_models.append(middleware.model)
+        release.set()
+        await summarizing
+
+    asyncio.run(_run())
+
+    assert observed_models == [model]
+
+
+def test_raw_model_is_preserved_for_parent_profile_inspection() -> None:
+    """self.model must stay the original model so attribute access does not drift."""
+    model = _StaticChatModel(text="compressed summary")
+    middleware = DeerFlowSummarizationMiddleware(
+        model=model,
+        trigger=("messages", 4),
+        keep=("messages", 2),
+        token_counter=len,
+    )
+
+    middleware.before_model({"messages": _messages()}, _runtime())
+
+    # The shared field is never reassigned to the RunnableBinding.
+    assert middleware.model is model
+    assert middleware._summary_model is not model
+
+
+def test_summary_model_preserves_existing_tags_when_adding_nostream() -> None:
+    """Adding TAG_NOSTREAM must not clobber tags already bound on the model.
+
+    lead_agent/agent.py binds "middleware:summarize" for RunJournal attribution. Because
+    RunnableBinding.with_config shallow-merges config, the summary model must explicitly
+    preserve existing tags instead of overwriting them with just [TAG_NOSTREAM].
+    """
+    tagged_model = _StaticChatModel(text="compressed summary").with_config(tags=["middleware:summarize"])
+    middleware = DeerFlowSummarizationMiddleware(
+        model=tagged_model,
+        trigger=("messages", 4),
+        keep=("messages", 2),
+        token_counter=len,
+    )
+
+    summary_tags = middleware._summary_model.config.get("tags") or []
+    assert "middleware:summarize" in summary_tags
+    assert TAG_NOSTREAM in summary_tags
+    # No duplicate TAG_NOSTREAM even if invoked when one was already present.
+    assert summary_tags.count(TAG_NOSTREAM) == 1
+
+
 def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
     captured: list[SummarizationEvent] = []
     middleware = _middleware(before_summarization=[captured.append])