diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py
index df6a453d6..b482fcd39 100644
--- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py
+++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py
@@ -1,14 +1,16 @@
import logging
from langchain.agents import create_agent
-from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
+from langchain.agents.middleware import AgentMiddleware
from langchain_core.runnables import RunnableConfig
from deerflow.agents.lead_agent.prompt import apply_prompt_template
+from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
+from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
@@ -17,6 +19,7 @@ from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddlewar
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config
from deerflow.config.app_config import get_app_config
+from deerflow.config.memory_config import get_memory_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.models import create_chat_model
@@ -38,7 +41,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
return default_model_name
-def _create_summarization_middleware() -> SummarizationMiddleware | None:
+def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
config = get_summarization_config()
@@ -77,7 +80,11 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
if config.summary_prompt is not None:
kwargs["summary_prompt"] = config.summary_prompt
- return SummarizationMiddleware(**kwargs)
+ hooks: list[BeforeSummarizationHook] = []
+ if get_memory_config().enabled:
+ hooks.append(memory_flush_hook)
+
+ return DeerFlowSummarizationMiddleware(**kwargs, before_summarization=hooks)
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
diff --git a/backend/packages/harness/deerflow/agents/memory/message_processing.py b/backend/packages/harness/deerflow/agents/memory/message_processing.py
new file mode 100644
index 000000000..045829426
--- /dev/null
+++ b/backend/packages/harness/deerflow/agents/memory/message_processing.py
@@ -0,0 +1,109 @@
+"""Shared helpers for turning conversations into memory update inputs."""
+
+from __future__ import annotations
+
+import re
+from copy import copy
+from typing import Any
+
+_UPLOAD_BLOCK_RE = re.compile(r"[\s\S]*?\n*", re.IGNORECASE)
+_CORRECTION_PATTERNS = (
+ re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
+ re.compile(r"\byou misunderstood\b", re.IGNORECASE),
+ re.compile(r"\btry again\b", re.IGNORECASE),
+ re.compile(r"\bredo\b", re.IGNORECASE),
+ re.compile(r"不对"),
+ re.compile(r"你理解错了"),
+ re.compile(r"你理解有误"),
+ re.compile(r"重试"),
+ re.compile(r"重新来"),
+ re.compile(r"换一种"),
+ re.compile(r"改用"),
+)
+_REINFORCEMENT_PATTERNS = (
+ re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
+ re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
+ re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
+ re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
+ re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
+ re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
+ re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
+ re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
+ re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
+ re.compile(r"完全正确(?:[。!?!?.]|$)"),
+ re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
+ re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
+ re.compile(r"继续保持(?:[。!?!?.]|$)"),
+)
+
+
+def extract_message_text(message: Any) -> str:
+ """Extract plain text from message content for filtering and signal detection."""
+ content = getattr(message, "content", "")
+ if isinstance(content, list):
+ text_parts: list[str] = []
+ for part in content:
+ if isinstance(part, str):
+ text_parts.append(part)
+ elif isinstance(part, dict):
+ text_val = part.get("text")
+ if isinstance(text_val, str):
+ text_parts.append(text_val)
+ return " ".join(text_parts)
+ return str(content)
+
+
+def filter_messages_for_memory(messages: list[Any]) -> list[Any]:
+ """Keep only user inputs and final assistant responses for memory updates."""
+ filtered = []
+ skip_next_ai = False
+ for msg in messages:
+ msg_type = getattr(msg, "type", None)
+
+ if msg_type == "human":
+ content_str = extract_message_text(msg)
+ if "" in content_str:
+ stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
+ if not stripped:
+ skip_next_ai = True
+ continue
+ clean_msg = copy(msg)
+ clean_msg.content = stripped
+ filtered.append(clean_msg)
+ skip_next_ai = False
+ else:
+ filtered.append(msg)
+ skip_next_ai = False
+ elif msg_type == "ai":
+ tool_calls = getattr(msg, "tool_calls", None)
+ if not tool_calls:
+ if skip_next_ai:
+ skip_next_ai = False
+ continue
+ filtered.append(msg)
+
+ return filtered
+
+
+def detect_correction(messages: list[Any]) -> bool:
+ """Detect explicit user corrections in recent conversation turns."""
+ recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
+
+ for msg in recent_user_msgs:
+ content = extract_message_text(msg).strip()
+ if content and any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
+ return True
+
+ return False
+
+
+def detect_reinforcement(messages: list[Any]) -> bool:
+ """Detect explicit positive reinforcement signals in recent conversation turns."""
+ recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
+
+ for msg in recent_user_msgs:
+ content = extract_message_text(msg).strip()
+ if content and any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
+ return True
+
+ return False
diff --git a/backend/packages/harness/deerflow/agents/memory/queue.py b/backend/packages/harness/deerflow/agents/memory/queue.py
index 1db8c63dc..5a7686996 100644
--- a/backend/packages/harness/deerflow/agents/memory/queue.py
+++ b/backend/packages/harness/deerflow/agents/memory/queue.py
@@ -61,48 +61,88 @@ class MemoryUpdateQueue:
return
with self._lock:
- existing_context = next(
- (context for context in self._queue if context.thread_id == thread_id),
- None,
- )
- merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
- merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
- context = ConversationContext(
+ self._enqueue_locked(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
- correction_detected=merged_correction_detected,
- reinforcement_detected=merged_reinforcement_detected,
+ correction_detected=correction_detected,
+ reinforcement_detected=reinforcement_detected,
)
-
- # Check if this thread already has a pending update
- # If so, replace it with the newer one
- self._queue = [c for c in self._queue if c.thread_id != thread_id]
- self._queue.append(context)
-
- # Reset or start the debounce timer
self._reset_timer()
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
+ def add_nowait(
+ self,
+ thread_id: str,
+ messages: list[Any],
+ agent_name: str | None = None,
+ correction_detected: bool = False,
+ reinforcement_detected: bool = False,
+ ) -> None:
+ """Add a conversation and start processing immediately in the background."""
+ config = get_memory_config()
+ if not config.enabled:
+ return
+
+ with self._lock:
+ self._enqueue_locked(
+ thread_id=thread_id,
+ messages=messages,
+ agent_name=agent_name,
+ correction_detected=correction_detected,
+ reinforcement_detected=reinforcement_detected,
+ )
+ self._schedule_timer(0)
+
+ logger.info("Memory update queued for immediate processing on thread %s, queue size: %d", thread_id, len(self._queue))
+
+ def _enqueue_locked(
+ self,
+ *,
+ thread_id: str,
+ messages: list[Any],
+ agent_name: str | None,
+ correction_detected: bool,
+ reinforcement_detected: bool,
+ ) -> None:
+ existing_context = next(
+ (context for context in self._queue if context.thread_id == thread_id),
+ None,
+ )
+ merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
+ merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
+ context = ConversationContext(
+ thread_id=thread_id,
+ messages=messages,
+ agent_name=agent_name,
+ correction_detected=merged_correction_detected,
+ reinforcement_detected=merged_reinforcement_detected,
+ )
+
+ self._queue = [c for c in self._queue if c.thread_id != thread_id]
+ self._queue.append(context)
+
def _reset_timer(self) -> None:
"""Reset the debounce timer."""
config = get_memory_config()
+ self._schedule_timer(config.debounce_seconds)
+ logger.debug("Memory update timer set for %ss", config.debounce_seconds)
+
+ def _schedule_timer(self, delay_seconds: float) -> None:
+ """Schedule queue processing after the provided delay."""
# Cancel existing timer if any
if self._timer is not None:
self._timer.cancel()
- # Start new timer
self._timer = threading.Timer(
- config.debounce_seconds,
+ delay_seconds,
self._process_queue,
)
self._timer.daemon = True
self._timer.start()
- logger.debug("Memory update timer set for %ss", config.debounce_seconds)
-
def _process_queue(self) -> None:
"""Process all queued conversation contexts."""
# Import here to avoid circular dependency
@@ -110,8 +150,8 @@ class MemoryUpdateQueue:
with self._lock:
if self._processing:
- # Already processing, reschedule
- self._reset_timer()
+ # Preserve immediate flush semantics even if another worker is active.
+ self._schedule_timer(0)
return
if not self._queue:
@@ -164,6 +204,13 @@ class MemoryUpdateQueue:
self._process_queue()
+ def flush_nowait(self) -> None:
+ """Start queue processing immediately in a background thread."""
+ with self._lock:
+ # Daemon thread: queued messages may be lost if the process exits
+ # before _process_queue completes. Acceptable for best-effort memory updates.
+ self._schedule_timer(0)
+
def clear(self) -> None:
"""Clear the queue without processing.
diff --git a/backend/packages/harness/deerflow/agents/memory/summarization_hook.py b/backend/packages/harness/deerflow/agents/memory/summarization_hook.py
new file mode 100644
index 000000000..dafa7d977
--- /dev/null
+++ b/backend/packages/harness/deerflow/agents/memory/summarization_hook.py
@@ -0,0 +1,31 @@
+"""Hooks fired before summarization removes messages from state."""
+
+from __future__ import annotations
+
+from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
+from deerflow.agents.memory.queue import get_memory_queue
+from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
+from deerflow.config.memory_config import get_memory_config
+
+
+def memory_flush_hook(event: SummarizationEvent) -> None:
+ """Flush messages about to be summarized into the memory queue."""
+ if not get_memory_config().enabled or not event.thread_id:
+ return
+
+ filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize))
+ user_messages = [message for message in filtered_messages if getattr(message, "type", None) == "human"]
+ assistant_messages = [message for message in filtered_messages if getattr(message, "type", None) == "ai"]
+ if not user_messages or not assistant_messages:
+ return
+
+ correction_detected = detect_correction(filtered_messages)
+ reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
+ queue = get_memory_queue()
+ queue.add_nowait(
+ thread_id=event.thread_id,
+ messages=filtered_messages,
+ agent_name=event.agent_name,
+ correction_detected=correction_detected,
+ reinforcement_detected=reinforcement_detected,
+ )
diff --git a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py
index 5e8ca6344..f1dccf689 100644
--- a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py
+++ b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py
@@ -1,50 +1,19 @@
"""Middleware for memory mechanism."""
import logging
-import re
-from typing import Any, override
+from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
+from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config
logger = logging.getLogger(__name__)
-_UPLOAD_BLOCK_RE = re.compile(r"[\s\S]*?\n*", re.IGNORECASE)
-_CORRECTION_PATTERNS = (
- re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
- re.compile(r"\byou misunderstood\b", re.IGNORECASE),
- re.compile(r"\btry again\b", re.IGNORECASE),
- re.compile(r"\bredo\b", re.IGNORECASE),
- re.compile(r"不对"),
- re.compile(r"你理解错了"),
- re.compile(r"你理解有误"),
- re.compile(r"重试"),
- re.compile(r"重新来"),
- re.compile(r"换一种"),
- re.compile(r"改用"),
-)
-
-_REINFORCEMENT_PATTERNS = (
- re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
- re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
- re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
- re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
- re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
- re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
- re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
- re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
- re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
- re.compile(r"完全正确(?:[。!?!?.]|$)"),
- re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
- re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
- re.compile(r"继续保持(?:[。!?!?.]|$)"),
-)
-
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
@@ -52,125 +21,6 @@ class MemoryMiddlewareState(AgentState):
pass
-def _extract_message_text(message: Any) -> str:
- """Extract plain text from message content for filtering and signal detection."""
- content = getattr(message, "content", "")
- if isinstance(content, list):
- text_parts: list[str] = []
- for part in content:
- if isinstance(part, str):
- text_parts.append(part)
- elif isinstance(part, dict):
- text_val = part.get("text")
- if isinstance(text_val, str):
- text_parts.append(text_val)
- return " ".join(text_parts)
- return str(content)
-
-
-def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
- """Filter messages to keep only user inputs and final assistant responses.
-
- This filters out:
- - Tool messages (intermediate tool call results)
- - AI messages with tool_calls (intermediate steps, not final responses)
- - The block injected by UploadsMiddleware into human messages
- (file paths are session-scoped and must not persist in long-term memory).
- The user's actual question is preserved; only turns whose content is entirely
- the upload block (nothing remains after stripping) are dropped along with
- their paired assistant response.
-
- Only keeps:
- - Human messages (with the ephemeral upload block removed)
- - AI messages without tool_calls (final assistant responses), unless the
- paired human turn was upload-only and had no real user text.
-
- Args:
- messages: List of all conversation messages.
-
- Returns:
- Filtered list containing only user inputs and final assistant responses.
- """
- filtered = []
- skip_next_ai = False
- for msg in messages:
- msg_type = getattr(msg, "type", None)
-
- if msg_type == "human":
- content_str = _extract_message_text(msg)
- if "" in content_str:
- # Strip the ephemeral upload block; keep the user's real question.
- stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
- if not stripped:
- # Nothing left — the entire turn was upload bookkeeping;
- # skip it and the paired assistant response.
- skip_next_ai = True
- continue
- # Rebuild the message with cleaned content so the user's question
- # is still available for memory summarisation.
- from copy import copy
-
- clean_msg = copy(msg)
- clean_msg.content = stripped
- filtered.append(clean_msg)
- skip_next_ai = False
- else:
- filtered.append(msg)
- skip_next_ai = False
- elif msg_type == "ai":
- tool_calls = getattr(msg, "tool_calls", None)
- if not tool_calls:
- if skip_next_ai:
- skip_next_ai = False
- continue
- filtered.append(msg)
- # Skip tool messages and AI messages with tool_calls
-
- return filtered
-
-
-def detect_correction(messages: list[Any]) -> bool:
- """Detect explicit user corrections in recent conversation turns.
-
- The queue keeps only one pending context per thread, so callers pass the
- latest filtered message list. Checking only recent user turns keeps signal
- detection conservative while avoiding stale corrections from long histories.
- """
- recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
-
- for msg in recent_user_msgs:
- content = _extract_message_text(msg).strip()
- if not content:
- continue
- if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
- return True
-
- return False
-
-
-def detect_reinforcement(messages: list[Any]) -> bool:
- """Detect explicit positive reinforcement signals in recent conversation turns.
-
- Complements detect_correction() by identifying when the user confirms the
- agent's approach was correct. This allows the memory system to record what
- worked well, not just what went wrong.
-
- The queue keeps only one pending context per thread, so callers pass the
- latest filtered message list. Checking only recent user turns keeps signal
- detection conservative while avoiding stale signals from long histories.
- """
- recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
-
- for msg in recent_user_msgs:
- content = _extract_message_text(msg).strip()
- if not content:
- continue
- if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
- return True
-
- return False
-
-
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
@@ -223,7 +73,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
return None
# Filter to only keep user inputs and final assistant responses
- filtered_messages = _filter_messages_for_memory(messages)
+ filtered_messages = filter_messages_for_memory(messages)
# Only queue if there's meaningful conversation
# At minimum need one user message and one assistant response
diff --git a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py
new file mode 100644
index 000000000..fba44c215
--- /dev/null
+++ b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py
@@ -0,0 +1,151 @@
+"""Summarization middleware extensions for DeerFlow."""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import Protocol, runtime_checkable
+
+from langchain.agents import AgentState
+from langchain.agents.middleware import SummarizationMiddleware
+from langchain_core.messages import AnyMessage, RemoveMessage
+from langgraph.config import get_config
+from langgraph.graph.message import REMOVE_ALL_MESSAGES
+from langgraph.runtime import Runtime
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class SummarizationEvent:
+ """Context emitted before conversation history is summarized away."""
+
+ messages_to_summarize: tuple[AnyMessage, ...]
+ preserved_messages: tuple[AnyMessage, ...]
+ thread_id: str | None
+ agent_name: str | None
+ runtime: Runtime
+
+
+@runtime_checkable
+class BeforeSummarizationHook(Protocol):
+ """Hook invoked before summarization removes messages from state."""
+
+ def __call__(self, event: SummarizationEvent) -> None: ...
+
+
+def _resolve_thread_id(runtime: Runtime) -> str | None:
+ """Resolve the current thread ID from runtime context or LangGraph config."""
+ thread_id = runtime.context.get("thread_id") if runtime.context else None
+ if thread_id is None:
+ try:
+ config_data = get_config()
+ except RuntimeError:
+ return None
+ thread_id = config_data.get("configurable", {}).get("thread_id")
+ return thread_id
+
+
+def _resolve_agent_name(runtime: Runtime) -> str | None:
+ """Resolve the current agent name from runtime context or LangGraph config."""
+ agent_name = runtime.context.get("agent_name") if runtime.context else None
+ if agent_name is None:
+ try:
+ config_data = get_config()
+ except RuntimeError:
+ return None
+ agent_name = config_data.get("configurable", {}).get("agent_name")
+ return agent_name
+
+
+class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
+ """Summarization middleware with pre-compression hook dispatch."""
+
+ def __init__(
+ self,
+ *args,
+ before_summarization: list[BeforeSummarizationHook] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self._before_summarization_hooks = before_summarization or []
+
+ def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
+ return self._maybe_summarize(state, runtime)
+
+ async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
+ return await self._amaybe_summarize(state, runtime)
+
+ def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
+ messages = state["messages"]
+ self._ensure_message_ids(messages)
+
+ total_tokens = self.token_counter(messages)
+ if not self._should_summarize(messages, total_tokens):
+ return None
+
+ cutoff_index = self._determine_cutoff_index(messages)
+ if cutoff_index <= 0:
+ return None
+
+ messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
+ self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
+ summary = self._create_summary(messages_to_summarize)
+ new_messages = self._build_new_messages(summary)
+
+ return {
+ "messages": [
+ RemoveMessage(id=REMOVE_ALL_MESSAGES),
+ *new_messages,
+ *preserved_messages,
+ ]
+ }
+
+ async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
+ messages = state["messages"]
+ self._ensure_message_ids(messages)
+
+ total_tokens = self.token_counter(messages)
+ if not self._should_summarize(messages, total_tokens):
+ return None
+
+ cutoff_index = self._determine_cutoff_index(messages)
+ if cutoff_index <= 0:
+ return None
+
+ messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
+ self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
+ summary = await self._acreate_summary(messages_to_summarize)
+ new_messages = self._build_new_messages(summary)
+
+ return {
+ "messages": [
+ RemoveMessage(id=REMOVE_ALL_MESSAGES),
+ *new_messages,
+ *preserved_messages,
+ ]
+ }
+
+ def _fire_hooks(
+ self,
+ messages_to_summarize: list[AnyMessage],
+ preserved_messages: list[AnyMessage],
+ runtime: Runtime,
+ ) -> None:
+ if not self._before_summarization_hooks:
+ return
+
+ event = SummarizationEvent(
+ messages_to_summarize=tuple(messages_to_summarize),
+ preserved_messages=tuple(preserved_messages),
+ thread_id=_resolve_thread_id(runtime),
+ agent_name=_resolve_agent_name(runtime),
+ runtime=runtime,
+ )
+
+ for hook in self._before_summarization_hooks:
+ try:
+ hook(event)
+ except Exception:
+ hook_name = getattr(hook, "__name__", None) or type(hook).__name__
+ logger.exception("before_summarization hook %s failed", hook_name)
diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py
index 9373c2895..1987fd7c4 100644
--- a/backend/tests/test_lead_agent_model_resolution.py
+++ b/backend/tests/test_lead_agent_model_resolution.py
@@ -8,6 +8,7 @@ import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.config.app_config import AppConfig
+from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.summarization_config import SummarizationConfig
@@ -145,6 +146,7 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
"get_summarization_config",
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
)
+ monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
captured: dict[str, object] = {}
fake_model = object()
@@ -156,10 +158,32 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
return fake_model
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
- monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs)
+ monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
middleware = lead_agent_module._create_summarization_middleware()
assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False
assert middleware["model"] is fake_model
+
+
+def test_create_summarization_middleware_registers_memory_flush_hook_when_memory_enabled(monkeypatch):
+ monkeypatch.setattr(
+ lead_agent_module,
+ "get_summarization_config",
+ lambda: SummarizationConfig(enabled=True),
+ )
+ monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=True))
+ monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object())
+
+ captured: dict[str, object] = {}
+
+ def _fake_middleware(**kwargs):
+ captured.update(kwargs)
+ return kwargs
+
+ monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _fake_middleware)
+
+ lead_agent_module._create_summarization_middleware()
+
+ assert captured["before_summarization"] == [lead_agent_module.memory_flush_hook]
diff --git a/backend/tests/test_memory_queue.py b/backend/tests/test_memory_queue.py
index 204f9d16e..0d991ec0c 100644
--- a/backend/tests/test_memory_queue.py
+++ b/backend/tests/test_memory_queue.py
@@ -1,3 +1,5 @@
+import threading
+import time
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
@@ -89,3 +91,74 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
correction_detected=False,
reinforcement_detected=True,
)
+
+
+def test_flush_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
+ queue = MemoryUpdateQueue()
+ existing_timer = MagicMock()
+ queue._timer = existing_timer
+ created_timer = MagicMock()
+
+ with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls:
+ queue.flush_nowait()
+
+ existing_timer.cancel.assert_called_once_with()
+ timer_cls.assert_called_once_with(0, queue._process_queue)
+ assert created_timer.daemon is True
+ created_timer.start.assert_called_once_with()
+ assert queue._timer is created_timer
+
+
+def test_add_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
+ queue = MemoryUpdateQueue()
+ existing_timer = MagicMock()
+ queue._timer = existing_timer
+ created_timer = MagicMock()
+
+ with (
+ patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
+ patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls,
+ ):
+ queue.add_nowait(thread_id="thread-1", messages=["conversation"], agent_name="lead-agent")
+
+ existing_timer.cancel.assert_called_once_with()
+ timer_cls.assert_called_once_with(0, queue._process_queue)
+ assert queue.pending_count == 1
+ assert queue._queue[0].agent_name == "lead-agent"
+ assert created_timer.daemon is True
+ created_timer.start.assert_called_once_with()
+
+
+def test_process_queue_reschedules_immediately_when_already_processing() -> None:
+ queue = MemoryUpdateQueue()
+ queue._processing = True
+ created_timer = MagicMock()
+
+ with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls:
+ queue._process_queue()
+
+ timer_cls.assert_called_once_with(0, queue._process_queue)
+ assert created_timer.daemon is True
+ created_timer.start.assert_called_once_with()
+
+
+def test_flush_nowait_is_non_blocking() -> None:
+ queue = MemoryUpdateQueue()
+ started = threading.Event()
+ finished = threading.Event()
+
+ def _slow_process_queue() -> None:
+ started.set()
+ time.sleep(0.2)
+ finished.set()
+
+ queue._process_queue = _slow_process_queue
+
+ start = time.perf_counter()
+ queue.flush_nowait()
+ elapsed = time.perf_counter() - start
+
+ assert started.wait(0.1) is True
+ assert elapsed < 0.1
+ assert finished.is_set() is False
+ assert finished.wait(1.0) is True
diff --git a/backend/tests/test_memory_upload_filtering.py b/backend/tests/test_memory_upload_filtering.py
index 2e2308b61..6453db6a2 100644
--- a/backend/tests/test_memory_upload_filtering.py
+++ b/backend/tests/test_memory_upload_filtering.py
@@ -3,14 +3,14 @@
Covers two functions introduced to prevent ephemeral file-upload context from
persisting in long-term memory:
- - _filter_messages_for_memory (memory_middleware)
+ - filter_messages_for_memory (message_processing)
- _strip_upload_mentions_from_memory (updater)
"""
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
+from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
-from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
# ---------------------------------------------------------------------------
# Helpers
@@ -31,7 +31,7 @@ def _ai(text: str, tool_calls=None) -> AIMessage:
# ===========================================================================
-# _filter_messages_for_memory
+# filter_messages_for_memory
# ===========================================================================
@@ -45,7 +45,7 @@ class TestFilterMessagesForMemory:
_human(_UPLOAD_BLOCK),
_ai("I have read the file. It says: Hello."),
]
- result = _filter_messages_for_memory(msgs)
+ result = filter_messages_for_memory(msgs)
assert result == []
def test_upload_with_real_question_preserves_question(self):
@@ -56,7 +56,7 @@ class TestFilterMessagesForMemory:
_human(combined),
_ai("The file contains: Hello DeerFlow."),
]
- result = _filter_messages_for_memory(msgs)
+ result = filter_messages_for_memory(msgs)
assert len(result) == 2
human_result = result[0]
@@ -71,7 +71,7 @@ class TestFilterMessagesForMemory:
_human("What is the capital of France?"),
_ai("The capital of France is Paris."),
]
- result = _filter_messages_for_memory(msgs)
+ result = filter_messages_for_memory(msgs)
assert len(result) == 2
assert result[0].content == "What is the capital of France?"
assert result[1].content == "The capital of France is Paris."
@@ -84,7 +84,7 @@ class TestFilterMessagesForMemory:
ToolMessage(content="Search results", tool_call_id="1"),
_ai("Here are the results."),
]
- result = _filter_messages_for_memory(msgs)
+ result = filter_messages_for_memory(msgs)
human_msgs = [m for m in result if m.type == "human"]
ai_msgs = [m for m in result if m.type == "ai"]
assert len(human_msgs) == 1
@@ -101,7 +101,7 @@ class TestFilterMessagesForMemory:
_human("What is 2 + 2?"),
_ai("4"),
]
- result = _filter_messages_for_memory(msgs)
+ result = filter_messages_for_memory(msgs)
human_contents = [m.content for m in result if m.type == "human"]
ai_contents = [m.content for m in result if m.type == "ai"]
@@ -121,14 +121,14 @@ class TestFilterMessagesForMemory:
]
)
msgs = [msg, _ai("Done.")]
- result = _filter_messages_for_memory(msgs)
+ result = filter_messages_for_memory(msgs)
assert result == []
def test_file_path_not_in_filtered_content(self):
"""After filtering, no upload file path should appear in any message."""
combined = _UPLOAD_BLOCK + "\n\nSummarise the file please."
msgs = [_human(combined), _ai("It says hello.")]
- result = _filter_messages_for_memory(msgs)
+ result = filter_messages_for_memory(msgs)
all_content = " ".join(m.content for m in result if isinstance(m.content, str))
assert "/mnt/user-data/uploads/" not in all_content
assert "" not in all_content
diff --git a/backend/tests/test_summarization_middleware.py b/backend/tests/test_summarization_middleware.py
new file mode 100644
index 000000000..d327c94c4
--- /dev/null
+++ b/backend/tests/test_summarization_middleware.py
@@ -0,0 +1,186 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
+
+from deerflow.agents.memory.summarization_hook import memory_flush_hook
+from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
+from deerflow.config.memory_config import MemoryConfig
+
+
+def _messages() -> list:
+ return [
+ HumanMessage(content="user-1"),
+ AIMessage(content="assistant-1"),
+ HumanMessage(content="user-2"),
+ AIMessage(content="assistant-2"),
+ ]
+
+
+def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
+ context = {}
+ if thread_id is not None:
+ context["thread_id"] = thread_id
+ if agent_name is not None:
+ context["agent_name"] = agent_name
+ return SimpleNamespace(context=context)
+
+
+def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("messages", 2)) -> DeerFlowSummarizationMiddleware:
+ model = MagicMock()
+ model.invoke.return_value = SimpleNamespace(text="compressed summary")
+ return DeerFlowSummarizationMiddleware(
+ model=model,
+ trigger=trigger,
+ keep=keep,
+ token_counter=len,
+ before_summarization=before_summarization,
+ )
+
+
+def test_before_summarization_hook_receives_messages_before_compression() -> None:
+ captured: list[SummarizationEvent] = []
+ middleware = _middleware(before_summarization=[captured.append])
+
+ result = middleware.before_model({"messages": _messages()}, _runtime())
+
+ assert len(captured) == 1
+ assert [message.content for message in captured[0].messages_to_summarize] == ["user-1", "assistant-1"]
+ assert [message.content for message in captured[0].preserved_messages] == ["user-2", "assistant-2"]
+ assert captured[0].thread_id == "thread-1"
+ assert captured[0].agent_name is None
+ assert isinstance(result["messages"][0], RemoveMessage)
+ assert result["messages"][1].content.startswith("Here is a summary")
+
+
+def test_before_summarization_hook_not_called_when_threshold_not_met() -> None:
+ captured: list[SummarizationEvent] = []
+ middleware = _middleware(before_summarization=[captured.append], trigger=("messages", 10))
+
+ result = middleware.before_model({"messages": _messages()}, _runtime())
+
+ assert captured == []
+ assert result is None
+
+
+def test_before_summarization_hook_exception_does_not_block_compression(caplog: pytest.LogCaptureFixture) -> None:
+ def _broken_hook(_: SummarizationEvent) -> None:
+ raise RuntimeError("hook failure")
+
+ middleware = _middleware(before_summarization=[_broken_hook])
+
+ with caplog.at_level("ERROR"):
+ result = middleware.before_model({"messages": _messages()}, _runtime())
+
+ assert "before_summarization hook _broken_hook failed" in caplog.text
+ assert isinstance(result["messages"][0], RemoveMessage)
+
+
+def test_multiple_before_summarization_hooks_run_in_registration_order() -> None:
+ call_order: list[str] = []
+
+ def _hook(name: str):
+ return lambda _: call_order.append(name)
+
+ middleware = _middleware(before_summarization=[_hook("first"), _hook("second"), _hook("third")])
+
+ middleware.before_model({"messages": _messages()}, _runtime())
+
+ assert call_order == ["first", "second", "third"]
+
+
+@pytest.mark.anyio
+async def test_abefore_model_calls_hooks_same_as_sync() -> None:
+ captured: list[SummarizationEvent] = []
+ middleware = _middleware(before_summarization=[captured.append])
+
+ await middleware.abefore_model({"messages": _messages()}, _runtime())
+
+ assert len(captured) == 1
+ assert [message.content for message in captured[0].messages_to_summarize] == ["user-1", "assistant-1"]
+
+
+def test_memory_flush_hook_skips_when_memory_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
+ queue = MagicMock()
+ monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=False))
+ monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
+
+ memory_flush_hook(
+ SummarizationEvent(
+ messages_to_summarize=tuple(_messages()[:2]),
+ preserved_messages=(),
+ thread_id="thread-1",
+ agent_name=None,
+ runtime=_runtime(),
+ )
+ )
+
+ queue.add_nowait.assert_not_called()
+
+
+def test_memory_flush_hook_skips_when_thread_id_missing(monkeypatch: pytest.MonkeyPatch) -> None:
+ queue = MagicMock()
+ monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
+ monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
+
+ memory_flush_hook(
+ SummarizationEvent(
+ messages_to_summarize=tuple(_messages()[:2]),
+ preserved_messages=(),
+ thread_id=None,
+ agent_name=None,
+ runtime=_runtime(None),
+ )
+ )
+
+ queue.add_nowait.assert_not_called()
+
+
+def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: pytest.MonkeyPatch) -> None:
+ queue = MagicMock()
+ messages = [
+ HumanMessage(content="Question"),
+ AIMessage(content="Calling tool", tool_calls=[{"name": "search", "id": "tool-1", "args": {}}]),
+ AIMessage(content="Final answer"),
+ ]
+ monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
+ monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
+
+ memory_flush_hook(
+ SummarizationEvent(
+ messages_to_summarize=tuple(messages),
+ preserved_messages=(),
+ thread_id="thread-1",
+ agent_name=None,
+ runtime=_runtime(),
+ )
+ )
+
+ queue.add_nowait.assert_called_once()
+ add_kwargs = queue.add_nowait.call_args.kwargs
+ assert add_kwargs["thread_id"] == "thread-1"
+ assert [message.content for message in add_kwargs["messages"]] == ["Question", "Final answer"]
+ assert add_kwargs["correction_detected"] is False
+ assert add_kwargs["reinforcement_detected"] is False
+
+
+def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None:
+ queue = MagicMock()
+ monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
+ monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
+
+ memory_flush_hook(
+ SummarizationEvent(
+ messages_to_summarize=tuple(_messages()[:2]),
+ preserved_messages=(),
+ thread_id="thread-1",
+ agent_name="research-agent",
+ runtime=_runtime(agent_name="research-agent"),
+ )
+ )
+
+ queue.add_nowait.assert_called_once()
+ assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"