fix(summarization): tag summary LLM calls nostream to stop phantom stream messages (#2503) (#3378)

* fix(summarization): tag summary LLM calls nostream to stop phantom stream messages (#2503)

The SummarizationMiddleware runs its summary LLM call inside a before_model
hook. Without a nostream tag the summary tokens were captured by LangGraph's
messages-tuple stream callback and broadcast to the frontend as a phantom AI
message.

Generate a dedicated summary model copy tagged with "nostream" (merged on top
of any existing tags such as "middleware:summarize" so RunJournal attribution
is preserved) and override _create_summary / _acreate_summary to invoke it
directly. This avoids temporarily swapping the shared self.model, which would
otherwise leak the RunnableBinding across concurrent runs and break parent
logic that inspects the raw model (profile / _get_ls_params).

Add regression tests covering nostream tagging, concurrent-run isolation, raw
model preservation, and existing-tag merge.

* fix(summarization): address nostream review feedback
This commit is contained in:
Ryker_Feng 2026-06-07 17:55:04 +08:00 committed by GitHub
parent 88e36d9686
commit d133b1119a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 185 additions and 1 deletions

View File

@ -9,8 +9,9 @@ from typing import Any, Protocol, override, runtime_checkable
from langchain.agents import AgentState
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage, get_buffer_string
from langgraph.config import get_config
from langgraph.constants import TAG_NOSTREAM
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
@ -116,6 +117,74 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
# The summary LLM call runs inside a LangGraph middleware hook, so its token
# stream would otherwise be captured by the messages-tuple stream callback and
# broadcast to the frontend as a phantom AI message. Tag a dedicated model copy
# with TAG_NOSTREAM so the streaming handler skips it.
# Keep self.model untagged so the parent's profile / ls_params inspection still works.
#
# Preserve any tags already bound on the model (e.g. "middleware:summarize" set in
# lead_agent/agent.py for RunJournal attribution): RunnableBinding.with_config does a
# shallow merge that would otherwise overwrite the existing tags list entirely.
existing_tags = list((getattr(self.model, "config", None) or {}).get("tags") or [])
merged_tags = [*existing_tags, TAG_NOSTREAM] if TAG_NOSTREAM not in existing_tags else existing_tags
self._summary_model = self.model.with_config(tags=merged_tags)
@override
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
return self._summarize_with(messages_to_summarize)
@override
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
return await self._asummarize_with(messages_to_summarize)
def _summarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Mirror the parent ``_create_summary`` but invoke the nostream-tagged model.
We do not swap ``self.model`` at the instance level: the agent/middleware is
cached and reused across concurrent runs, so a temporary swap would leak the
``RunnableBinding`` to other coroutines during ``await`` and break parent logic
that inspects the raw model (``profile`` / ``_get_ls_params``).
"""
if not messages_to_summarize:
return "No previous conversation history."
prompt = self._build_summary_prompt(messages_to_summarize)
if prompt is None:
return "Previous conversation was too long to summarize."
try:
response = self._summary_model.invoke(
prompt,
config={"metadata": {"lc_source": "summarization"}},
)
return response.text.strip()
except Exception as e:
return f"Error generating summary: {e!s}"
async def _asummarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Async counterpart of :meth:`_summarize_with` using the nostream model."""
if not messages_to_summarize:
return "No previous conversation history."
prompt = self._build_summary_prompt(messages_to_summarize)
if prompt is None:
return "Previous conversation was too long to summarize."
try:
response = await self._summary_model.ainvoke(
prompt,
config={"metadata": {"lc_source": "summarization"}},
)
return response.text.strip()
except Exception as e:
return f"Error generating summary: {e!s}"
def _build_summary_prompt(self, messages_to_summarize: list[AnyMessage]) -> str | None:
"""Build the summary prompt, returning ``None`` when trimming leaves nothing."""
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed_messages:
return None
# Format messages to avoid token inflation from metadata when str() is called on
# message objects.
formatted_messages = get_buffer_string(trimmed_messages)
return self.summary_prompt.format(messages=formatted_messages).rstrip()
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._maybe_summarize(state, runtime)

View File

@ -9,6 +9,7 @@ from langchain.agents import create_agent
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.constants import TAG_NOSTREAM
from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
@ -160,6 +161,120 @@ def test_summarization_middleware_emits_frontend_update_key_in_agent_stream() ->
assert emitted[1].content == ("Here is a summary of the conversation to date:\n\ncompressed summary")
def test_summary_model_is_tagged_nostream_to_avoid_stream_pollution() -> None:
tags_during_summary: list[list[str]] = []
class _RecordingChatModel(_StaticChatModel):
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
tags_during_summary.append(list(run_manager.tags) if run_manager else [])
return super()._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
model = _RecordingChatModel(text="compressed summary")
middleware = DeerFlowSummarizationMiddleware(
model=model,
trigger=("messages", 4),
keep=("messages", 2),
token_counter=len,
)
# The dedicated summary model must carry TAG_NOSTREAM so LangGraph's
# messages-tuple stream handler skips its tokens, while the raw model used by
# the parent for profile / token inspection stays untagged.
assert TAG_NOSTREAM in (middleware._summary_model.config.get("tags") or [])
assert TAG_NOSTREAM not in (getattr(middleware.model, "config", {}).get("tags") or [])
result = middleware.before_model({"messages": _messages()}, _runtime())
# The summary LLM call must actually run with the nostream tag (this is what the
# stream handler inspects), and the shared self.model must remain the raw,
# untagged model so parent logic (profile / _get_ls_params) keeps working.
assert tags_during_summary == [[TAG_NOSTREAM]]
assert middleware.model is model
assert result["messages"][1].content.startswith("Here is a summary")
def test_summarization_does_not_mutate_shared_model_across_concurrent_runs() -> None:
"""Concurrent runs must not observe a swapped-out self.model during summarization.
The agent/middleware instance is cached and reused, so summarization must never
temporarily replace the shared self.model: doing so would leak the nostream
RunnableBinding to other coroutines mid-flight and break parent logic that
inspects the raw model (profile / _get_ls_params).
"""
import asyncio
observed_models: list[object] = []
started = asyncio.Event()
release = asyncio.Event()
class _BlockingChatModel(_StaticChatModel):
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
# Hold the summary call open so a concurrent run can inspect self.model.
started.set()
await release.wait()
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
model = _BlockingChatModel(text="compressed summary")
middleware = DeerFlowSummarizationMiddleware(
model=model,
trigger=("messages", 4),
keep=("messages", 2),
token_counter=len,
)
async def _run() -> None:
summarizing = asyncio.create_task(middleware.abefore_model({"messages": _messages()}, _runtime()))
# Wait until the summary task reaches the blocked LLM call.
await started.wait()
# A concurrent run reads the shared model while summarization is in flight.
observed_models.append(middleware.model)
release.set()
await summarizing
asyncio.run(_run())
assert observed_models == [model]
def test_raw_model_is_preserved_for_parent_profile_inspection() -> None:
"""self.model must stay the original model so attribute access does not drift."""
model = _StaticChatModel(text="compressed summary")
middleware = DeerFlowSummarizationMiddleware(
model=model,
trigger=("messages", 4),
keep=("messages", 2),
token_counter=len,
)
middleware.before_model({"messages": _messages()}, _runtime())
# The shared field is never reassigned to the RunnableBinding.
assert middleware.model is model
assert middleware._summary_model is not model
def test_summary_model_preserves_existing_tags_when_adding_nostream() -> None:
"""Adding TAG_NOSTREAM must not clobber tags already bound on the model.
lead_agent/agent.py binds "middleware:summarize" for RunJournal attribution. Because
RunnableBinding.with_config shallow-merges config, the summary model must explicitly
preserve existing tags instead of overwriting them with just [TAG_NOSTREAM].
"""
tagged_model = _StaticChatModel(text="compressed summary").with_config(tags=["middleware:summarize"])
middleware = DeerFlowSummarizationMiddleware(
model=tagged_model,
trigger=("messages", 4),
keep=("messages", 2),
token_counter=len,
)
summary_tags = middleware._summary_model.config.get("tags") or []
assert "middleware:summarize" in summary_tags
assert TAG_NOSTREAM in summary_tags
# No duplicate TAG_NOSTREAM even if invoked when one was already present.
assert summary_tags.count(TAG_NOSTREAM) == 1
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])