diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py index df6a453d6..b482fcd39 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py @@ -1,14 +1,16 @@ import logging from langchain.agents import create_agent -from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware +from langchain.agents.middleware import AgentMiddleware from langchain_core.runnables import RunnableConfig from deerflow.agents.lead_agent.prompt import apply_prompt_template +from deerflow.agents.memory.summarization_hook import memory_flush_hook from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware +from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware from deerflow.agents.middlewares.title_middleware import TitleMiddleware from deerflow.agents.middlewares.todo_middleware import TodoMiddleware from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware @@ -17,6 +19,7 @@ from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddlewar from deerflow.agents.thread_state import ThreadState from deerflow.config.agents_config import load_agent_config from deerflow.config.app_config import get_app_config +from deerflow.config.memory_config import get_memory_config from deerflow.config.summarization_config import get_summarization_config from deerflow.models import create_chat_model @@ -38,7 +41,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str: return default_model_name -def 
_create_summarization_middleware() -> SummarizationMiddleware | None: +def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None: """Create and configure the summarization middleware from config.""" config = get_summarization_config() @@ -77,7 +80,11 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None: if config.summary_prompt is not None: kwargs["summary_prompt"] = config.summary_prompt - return SummarizationMiddleware(**kwargs) + hooks: list[BeforeSummarizationHook] = [] + if get_memory_config().enabled: + hooks.append(memory_flush_hook) + + return DeerFlowSummarizationMiddleware(**kwargs, before_summarization=hooks) def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None: diff --git a/backend/packages/harness/deerflow/agents/memory/message_processing.py b/backend/packages/harness/deerflow/agents/memory/message_processing.py new file mode 100644 index 000000000..045829426 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/memory/message_processing.py @@ -0,0 +1,109 @@ +"""Shared helpers for turning conversations into memory update inputs.""" + +from __future__ import annotations + +import re +from copy import copy +from typing import Any + +_UPLOAD_BLOCK_RE = re.compile(r"[\s\S]*?\n*", re.IGNORECASE) +_CORRECTION_PATTERNS = ( + re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE), + re.compile(r"\byou misunderstood\b", re.IGNORECASE), + re.compile(r"\btry again\b", re.IGNORECASE), + re.compile(r"\bredo\b", re.IGNORECASE), + re.compile(r"不对"), + re.compile(r"你理解错了"), + re.compile(r"你理解有误"), + re.compile(r"重试"), + re.compile(r"重新来"), + re.compile(r"换一种"), + re.compile(r"改用"), +) +_REINFORCEMENT_PATTERNS = ( + re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE), + re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE), + re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE), + re.compile(r"\bthat(?:'s| 
is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE), + re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE), + re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE), + re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE), + re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE), + re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"), + re.compile(r"完全正确(?:[。!?!?.]|$)"), + re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"), + re.compile(r"正是我想要的(?:[。!?!?.]|$)"), + re.compile(r"继续保持(?:[。!?!?.]|$)"), +) + + +def extract_message_text(message: Any) -> str: + """Extract plain text from message content for filtering and signal detection.""" + content = getattr(message, "content", "") + if isinstance(content, list): + text_parts: list[str] = [] + for part in content: + if isinstance(part, str): + text_parts.append(part) + elif isinstance(part, dict): + text_val = part.get("text") + if isinstance(text_val, str): + text_parts.append(text_val) + return " ".join(text_parts) + return str(content) + + +def filter_messages_for_memory(messages: list[Any]) -> list[Any]: + """Keep only user inputs and final assistant responses for memory updates.""" + filtered = [] + skip_next_ai = False + for msg in messages: + msg_type = getattr(msg, "type", None) + + if msg_type == "human": + content_str = extract_message_text(msg) + if "" in content_str: + stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip() + if not stripped: + skip_next_ai = True + continue + clean_msg = copy(msg) + clean_msg.content = stripped + filtered.append(clean_msg) + skip_next_ai = False + else: + filtered.append(msg) + skip_next_ai = False + elif msg_type == "ai": + tool_calls = getattr(msg, "tool_calls", None) + if not tool_calls: + if skip_next_ai: + skip_next_ai = False + continue + filtered.append(msg) + + return filtered + + +def detect_correction(messages: list[Any]) -> bool: + """Detect explicit user corrections in recent 
conversation turns.""" + recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"] + + for msg in recent_user_msgs: + content = extract_message_text(msg).strip() + if content and any(pattern.search(content) for pattern in _CORRECTION_PATTERNS): + return True + + return False + + +def detect_reinforcement(messages: list[Any]) -> bool: + """Detect explicit positive reinforcement signals in recent conversation turns.""" + recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"] + + for msg in recent_user_msgs: + content = extract_message_text(msg).strip() + if content and any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS): + return True + + return False diff --git a/backend/packages/harness/deerflow/agents/memory/queue.py b/backend/packages/harness/deerflow/agents/memory/queue.py index 1db8c63dc..5a7686996 100644 --- a/backend/packages/harness/deerflow/agents/memory/queue.py +++ b/backend/packages/harness/deerflow/agents/memory/queue.py @@ -61,48 +61,88 @@ class MemoryUpdateQueue: return with self._lock: - existing_context = next( - (context for context in self._queue if context.thread_id == thread_id), - None, - ) - merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False) - merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False) - context = ConversationContext( + self._enqueue_locked( thread_id=thread_id, messages=messages, agent_name=agent_name, - correction_detected=merged_correction_detected, - reinforcement_detected=merged_reinforcement_detected, + correction_detected=correction_detected, + reinforcement_detected=reinforcement_detected, ) - - # Check if this thread already has a pending update - # If so, replace it with the newer one - self._queue = [c for c in self._queue if c.thread_id != thread_id] - 
self._queue.append(context) - - # Reset or start the debounce timer self._reset_timer() logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue)) + def add_nowait( + self, + thread_id: str, + messages: list[Any], + agent_name: str | None = None, + correction_detected: bool = False, + reinforcement_detected: bool = False, + ) -> None: + """Add a conversation and start processing immediately in the background.""" + config = get_memory_config() + if not config.enabled: + return + + with self._lock: + self._enqueue_locked( + thread_id=thread_id, + messages=messages, + agent_name=agent_name, + correction_detected=correction_detected, + reinforcement_detected=reinforcement_detected, + ) + self._schedule_timer(0) + + logger.info("Memory update queued for immediate processing on thread %s, queue size: %d", thread_id, len(self._queue)) + + def _enqueue_locked( + self, + *, + thread_id: str, + messages: list[Any], + agent_name: str | None, + correction_detected: bool, + reinforcement_detected: bool, + ) -> None: + existing_context = next( + (context for context in self._queue if context.thread_id == thread_id), + None, + ) + merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False) + merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False) + context = ConversationContext( + thread_id=thread_id, + messages=messages, + agent_name=agent_name, + correction_detected=merged_correction_detected, + reinforcement_detected=merged_reinforcement_detected, + ) + + self._queue = [c for c in self._queue if c.thread_id != thread_id] + self._queue.append(context) + def _reset_timer(self) -> None: """Reset the debounce timer.""" config = get_memory_config() + self._schedule_timer(config.debounce_seconds) + logger.debug("Memory update timer set for %ss", config.debounce_seconds) + + def 
_schedule_timer(self, delay_seconds: float) -> None: + """Schedule queue processing after the provided delay.""" # Cancel existing timer if any if self._timer is not None: self._timer.cancel() - # Start new timer self._timer = threading.Timer( - config.debounce_seconds, + delay_seconds, self._process_queue, ) self._timer.daemon = True self._timer.start() - logger.debug("Memory update timer set for %ss", config.debounce_seconds) - def _process_queue(self) -> None: """Process all queued conversation contexts.""" # Import here to avoid circular dependency @@ -110,8 +150,8 @@ class MemoryUpdateQueue: with self._lock: if self._processing: - # Already processing, reschedule - self._reset_timer() + # Preserve immediate flush semantics even if another worker is active. + self._schedule_timer(0) return if not self._queue: @@ -164,6 +204,13 @@ class MemoryUpdateQueue: self._process_queue() + def flush_nowait(self) -> None: + """Start queue processing immediately in a background thread.""" + with self._lock: + # Daemon thread: queued messages may be lost if the process exits + # before _process_queue completes. Acceptable for best-effort memory updates. + self._schedule_timer(0) + def clear(self) -> None: """Clear the queue without processing. 
diff --git a/backend/packages/harness/deerflow/agents/memory/summarization_hook.py b/backend/packages/harness/deerflow/agents/memory/summarization_hook.py new file mode 100644 index 000000000..dafa7d977 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/memory/summarization_hook.py @@ -0,0 +1,31 @@ +"""Hooks fired before summarization removes messages from state.""" + +from __future__ import annotations + +from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory +from deerflow.agents.memory.queue import get_memory_queue +from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent +from deerflow.config.memory_config import get_memory_config + + +def memory_flush_hook(event: SummarizationEvent) -> None: + """Flush messages about to be summarized into the memory queue.""" + if not get_memory_config().enabled or not event.thread_id: + return + + filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize)) + user_messages = [message for message in filtered_messages if getattr(message, "type", None) == "human"] + assistant_messages = [message for message in filtered_messages if getattr(message, "type", None) == "ai"] + if not user_messages or not assistant_messages: + return + + correction_detected = detect_correction(filtered_messages) + reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages) + queue = get_memory_queue() + queue.add_nowait( + thread_id=event.thread_id, + messages=filtered_messages, + agent_name=event.agent_name, + correction_detected=correction_detected, + reinforcement_detected=reinforcement_detected, + ) diff --git a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py index 5e8ca6344..f1dccf689 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py +++ 
b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py @@ -1,50 +1,19 @@ """Middleware for memory mechanism.""" import logging -import re -from typing import Any, override +from typing import override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langgraph.config import get_config from langgraph.runtime import Runtime +from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory from deerflow.agents.memory.queue import get_memory_queue from deerflow.config.memory_config import get_memory_config logger = logging.getLogger(__name__) -_UPLOAD_BLOCK_RE = re.compile(r"[\s\S]*?\n*", re.IGNORECASE) -_CORRECTION_PATTERNS = ( - re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE), - re.compile(r"\byou misunderstood\b", re.IGNORECASE), - re.compile(r"\btry again\b", re.IGNORECASE), - re.compile(r"\bredo\b", re.IGNORECASE), - re.compile(r"不对"), - re.compile(r"你理解错了"), - re.compile(r"你理解有误"), - re.compile(r"重试"), - re.compile(r"重新来"), - re.compile(r"换一种"), - re.compile(r"改用"), -) - -_REINFORCEMENT_PATTERNS = ( - re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE), - re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE), - re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE), - re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE), - re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE), - re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE), - re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE), - re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE), - re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"), - re.compile(r"完全正确(?:[。!?!?.]|$)"), - re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"), - re.compile(r"正是我想要的(?:[。!?!?.]|$)"), - re.compile(r"继续保持(?:[。!?!?.]|$)"), -) 
- class MemoryMiddlewareState(AgentState): """Compatible with the `ThreadState` schema.""" @@ -52,125 +21,6 @@ class MemoryMiddlewareState(AgentState): pass -def _extract_message_text(message: Any) -> str: - """Extract plain text from message content for filtering and signal detection.""" - content = getattr(message, "content", "") - if isinstance(content, list): - text_parts: list[str] = [] - for part in content: - if isinstance(part, str): - text_parts.append(part) - elif isinstance(part, dict): - text_val = part.get("text") - if isinstance(text_val, str): - text_parts.append(text_val) - return " ".join(text_parts) - return str(content) - - -def _filter_messages_for_memory(messages: list[Any]) -> list[Any]: - """Filter messages to keep only user inputs and final assistant responses. - - This filters out: - - Tool messages (intermediate tool call results) - - AI messages with tool_calls (intermediate steps, not final responses) - - The block injected by UploadsMiddleware into human messages - (file paths are session-scoped and must not persist in long-term memory). - The user's actual question is preserved; only turns whose content is entirely - the upload block (nothing remains after stripping) are dropped along with - their paired assistant response. - - Only keeps: - - Human messages (with the ephemeral upload block removed) - - AI messages without tool_calls (final assistant responses), unless the - paired human turn was upload-only and had no real user text. - - Args: - messages: List of all conversation messages. - - Returns: - Filtered list containing only user inputs and final assistant responses. - """ - filtered = [] - skip_next_ai = False - for msg in messages: - msg_type = getattr(msg, "type", None) - - if msg_type == "human": - content_str = _extract_message_text(msg) - if "" in content_str: - # Strip the ephemeral upload block; keep the user's real question. 
- stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip() - if not stripped: - # Nothing left — the entire turn was upload bookkeeping; - # skip it and the paired assistant response. - skip_next_ai = True - continue - # Rebuild the message with cleaned content so the user's question - # is still available for memory summarisation. - from copy import copy - - clean_msg = copy(msg) - clean_msg.content = stripped - filtered.append(clean_msg) - skip_next_ai = False - else: - filtered.append(msg) - skip_next_ai = False - elif msg_type == "ai": - tool_calls = getattr(msg, "tool_calls", None) - if not tool_calls: - if skip_next_ai: - skip_next_ai = False - continue - filtered.append(msg) - # Skip tool messages and AI messages with tool_calls - - return filtered - - -def detect_correction(messages: list[Any]) -> bool: - """Detect explicit user corrections in recent conversation turns. - - The queue keeps only one pending context per thread, so callers pass the - latest filtered message list. Checking only recent user turns keeps signal - detection conservative while avoiding stale corrections from long histories. - """ - recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"] - - for msg in recent_user_msgs: - content = _extract_message_text(msg).strip() - if not content: - continue - if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS): - return True - - return False - - -def detect_reinforcement(messages: list[Any]) -> bool: - """Detect explicit positive reinforcement signals in recent conversation turns. - - Complements detect_correction() by identifying when the user confirms the - agent's approach was correct. This allows the memory system to record what - worked well, not just what went wrong. - - The queue keeps only one pending context per thread, so callers pass the - latest filtered message list. 
Checking only recent user turns keeps signal - detection conservative while avoiding stale signals from long histories. - """ - recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"] - - for msg in recent_user_msgs: - content = _extract_message_text(msg).strip() - if not content: - continue - if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS): - return True - - return False - - class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): """Middleware that queues conversation for memory update after agent execution. @@ -223,7 +73,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): return None # Filter to only keep user inputs and final assistant responses - filtered_messages = _filter_messages_for_memory(messages) + filtered_messages = filter_messages_for_memory(messages) # Only queue if there's meaningful conversation # At minimum need one user message and one assistant response diff --git a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py new file mode 100644 index 000000000..fba44c215 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py @@ -0,0 +1,151 @@ +"""Summarization middleware extensions for DeerFlow.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Protocol, runtime_checkable + +from langchain.agents import AgentState +from langchain.agents.middleware import SummarizationMiddleware +from langchain_core.messages import AnyMessage, RemoveMessage +from langgraph.config import get_config +from langgraph.graph.message import REMOVE_ALL_MESSAGES +from langgraph.runtime import Runtime + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class SummarizationEvent: + """Context emitted before conversation history is summarized away.""" + + 
messages_to_summarize: tuple[AnyMessage, ...] + preserved_messages: tuple[AnyMessage, ...] + thread_id: str | None + agent_name: str | None + runtime: Runtime + + +@runtime_checkable +class BeforeSummarizationHook(Protocol): + """Hook invoked before summarization removes messages from state.""" + + def __call__(self, event: SummarizationEvent) -> None: ... + + +def _resolve_thread_id(runtime: Runtime) -> str | None: + """Resolve the current thread ID from runtime context or LangGraph config.""" + thread_id = runtime.context.get("thread_id") if runtime.context else None + if thread_id is None: + try: + config_data = get_config() + except RuntimeError: + return None + thread_id = config_data.get("configurable", {}).get("thread_id") + return thread_id + + +def _resolve_agent_name(runtime: Runtime) -> str | None: + """Resolve the current agent name from runtime context or LangGraph config.""" + agent_name = runtime.context.get("agent_name") if runtime.context else None + if agent_name is None: + try: + config_data = get_config() + except RuntimeError: + return None + agent_name = config_data.get("configurable", {}).get("agent_name") + return agent_name + + +class DeerFlowSummarizationMiddleware(SummarizationMiddleware): + """Summarization middleware with pre-compression hook dispatch.""" + + def __init__( + self, + *args, + before_summarization: list[BeforeSummarizationHook] | None = None, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self._before_summarization_hooks = before_summarization or [] + + def before_model(self, state: AgentState, runtime: Runtime) -> dict | None: + return self._maybe_summarize(state, runtime) + + async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None: + return await self._amaybe_summarize(state, runtime) + + def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None: + messages = state["messages"] + self._ensure_message_ids(messages) + + total_tokens = 
self.token_counter(messages) + if not self._should_summarize(messages, total_tokens): + return None + + cutoff_index = self._determine_cutoff_index(messages) + if cutoff_index <= 0: + return None + + messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index) + self._fire_hooks(messages_to_summarize, preserved_messages, runtime) + summary = self._create_summary(messages_to_summarize) + new_messages = self._build_new_messages(summary) + + return { + "messages": [ + RemoveMessage(id=REMOVE_ALL_MESSAGES), + *new_messages, + *preserved_messages, + ] + } + + async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None: + messages = state["messages"] + self._ensure_message_ids(messages) + + total_tokens = self.token_counter(messages) + if not self._should_summarize(messages, total_tokens): + return None + + cutoff_index = self._determine_cutoff_index(messages) + if cutoff_index <= 0: + return None + + messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index) + self._fire_hooks(messages_to_summarize, preserved_messages, runtime) + summary = await self._acreate_summary(messages_to_summarize) + new_messages = self._build_new_messages(summary) + + return { + "messages": [ + RemoveMessage(id=REMOVE_ALL_MESSAGES), + *new_messages, + *preserved_messages, + ] + } + + def _fire_hooks( + self, + messages_to_summarize: list[AnyMessage], + preserved_messages: list[AnyMessage], + runtime: Runtime, + ) -> None: + if not self._before_summarization_hooks: + return + + event = SummarizationEvent( + messages_to_summarize=tuple(messages_to_summarize), + preserved_messages=tuple(preserved_messages), + thread_id=_resolve_thread_id(runtime), + agent_name=_resolve_agent_name(runtime), + runtime=runtime, + ) + + for hook in self._before_summarization_hooks: + try: + hook(event) + except Exception: + hook_name = getattr(hook, "__name__", None) or type(hook).__name__ + 
logger.exception("before_summarization hook %s failed", hook_name) diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py index 9373c2895..1987fd7c4 100644 --- a/backend/tests/test_lead_agent_model_resolution.py +++ b/backend/tests/test_lead_agent_model_resolution.py @@ -8,6 +8,7 @@ import pytest from deerflow.agents.lead_agent import agent as lead_agent_module from deerflow.config.app_config import AppConfig +from deerflow.config.memory_config import MemoryConfig from deerflow.config.model_config import ModelConfig from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.summarization_config import SummarizationConfig @@ -145,6 +146,7 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch "get_summarization_config", lambda: SummarizationConfig(enabled=True, model_name="model-masswork"), ) + monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False)) captured: dict[str, object] = {} fake_model = object() @@ -156,10 +158,32 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch return fake_model monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model) - monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs) + monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs) middleware = lead_agent_module._create_summarization_middleware() assert captured["name"] == "model-masswork" assert captured["thinking_enabled"] is False assert middleware["model"] is fake_model + + +def test_create_summarization_middleware_registers_memory_flush_hook_when_memory_enabled(monkeypatch): + monkeypatch.setattr( + lead_agent_module, + "get_summarization_config", + lambda: SummarizationConfig(enabled=True), + ) + monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=True)) + 
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object()) + + captured: dict[str, object] = {} + + def _fake_middleware(**kwargs): + captured.update(kwargs) + return kwargs + + monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _fake_middleware) + + lead_agent_module._create_summarization_middleware() + + assert captured["before_summarization"] == [lead_agent_module.memory_flush_hook] diff --git a/backend/tests/test_memory_queue.py b/backend/tests/test_memory_queue.py index 204f9d16e..0d991ec0c 100644 --- a/backend/tests/test_memory_queue.py +++ b/backend/tests/test_memory_queue.py @@ -1,3 +1,5 @@ +import threading +import time from unittest.mock import MagicMock, patch from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue @@ -89,3 +91,74 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None: correction_detected=False, reinforcement_detected=True, ) + + +def test_flush_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None: + queue = MemoryUpdateQueue() + existing_timer = MagicMock() + queue._timer = existing_timer + created_timer = MagicMock() + + with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls: + queue.flush_nowait() + + existing_timer.cancel.assert_called_once_with() + timer_cls.assert_called_once_with(0, queue._process_queue) + assert created_timer.daemon is True + created_timer.start.assert_called_once_with() + assert queue._timer is created_timer + + +def test_add_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None: + queue = MemoryUpdateQueue() + existing_timer = MagicMock() + queue._timer = existing_timer + created_timer = MagicMock() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls, + ): + 
queue.add_nowait(thread_id="thread-1", messages=["conversation"], agent_name="lead-agent") + + existing_timer.cancel.assert_called_once_with() + timer_cls.assert_called_once_with(0, queue._process_queue) + assert queue.pending_count == 1 + assert queue._queue[0].agent_name == "lead-agent" + assert created_timer.daemon is True + created_timer.start.assert_called_once_with() + + +def test_process_queue_reschedules_immediately_when_already_processing() -> None: + queue = MemoryUpdateQueue() + queue._processing = True + created_timer = MagicMock() + + with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls: + queue._process_queue() + + timer_cls.assert_called_once_with(0, queue._process_queue) + assert created_timer.daemon is True + created_timer.start.assert_called_once_with() + + +def test_flush_nowait_is_non_blocking() -> None: + queue = MemoryUpdateQueue() + started = threading.Event() + finished = threading.Event() + + def _slow_process_queue() -> None: + started.set() + time.sleep(0.2) + finished.set() + + queue._process_queue = _slow_process_queue + + start = time.perf_counter() + queue.flush_nowait() + elapsed = time.perf_counter() - start + + assert started.wait(0.1) is True + assert elapsed < 0.1 + assert finished.is_set() is False + assert finished.wait(1.0) is True diff --git a/backend/tests/test_memory_upload_filtering.py b/backend/tests/test_memory_upload_filtering.py index 2e2308b61..6453db6a2 100644 --- a/backend/tests/test_memory_upload_filtering.py +++ b/backend/tests/test_memory_upload_filtering.py @@ -3,14 +3,14 @@ Covers two functions introduced to prevent ephemeral file-upload context from persisting in long-term memory: - - _filter_messages_for_memory (memory_middleware) + - filter_messages_for_memory (message_processing) - _strip_upload_mentions_from_memory (updater) """ from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from deerflow.agents.memory.message_processing import 
detect_correction, detect_reinforcement, filter_messages_for_memory from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory -from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement # --------------------------------------------------------------------------- # Helpers @@ -31,7 +31,7 @@ def _ai(text: str, tool_calls=None) -> AIMessage: # =========================================================================== -# _filter_messages_for_memory +# filter_messages_for_memory # =========================================================================== @@ -45,7 +45,7 @@ class TestFilterMessagesForMemory: _human(_UPLOAD_BLOCK), _ai("I have read the file. It says: Hello."), ] - result = _filter_messages_for_memory(msgs) + result = filter_messages_for_memory(msgs) assert result == [] def test_upload_with_real_question_preserves_question(self): @@ -56,7 +56,7 @@ class TestFilterMessagesForMemory: _human(combined), _ai("The file contains: Hello DeerFlow."), ] - result = _filter_messages_for_memory(msgs) + result = filter_messages_for_memory(msgs) assert len(result) == 2 human_result = result[0] @@ -71,7 +71,7 @@ class TestFilterMessagesForMemory: _human("What is the capital of France?"), _ai("The capital of France is Paris."), ] - result = _filter_messages_for_memory(msgs) + result = filter_messages_for_memory(msgs) assert len(result) == 2 assert result[0].content == "What is the capital of France?" assert result[1].content == "The capital of France is Paris." 
@@ -84,7 +84,7 @@ class TestFilterMessagesForMemory: ToolMessage(content="Search results", tool_call_id="1"), _ai("Here are the results."), ] - result = _filter_messages_for_memory(msgs) + result = filter_messages_for_memory(msgs) human_msgs = [m for m in result if m.type == "human"] ai_msgs = [m for m in result if m.type == "ai"] assert len(human_msgs) == 1 @@ -101,7 +101,7 @@ class TestFilterMessagesForMemory: _human("What is 2 + 2?"), _ai("4"), ] - result = _filter_messages_for_memory(msgs) + result = filter_messages_for_memory(msgs) human_contents = [m.content for m in result if m.type == "human"] ai_contents = [m.content for m in result if m.type == "ai"] @@ -121,14 +121,14 @@ class TestFilterMessagesForMemory: ] ) msgs = [msg, _ai("Done.")] - result = _filter_messages_for_memory(msgs) + result = filter_messages_for_memory(msgs) assert result == [] def test_file_path_not_in_filtered_content(self): """After filtering, no upload file path should appear in any message.""" combined = _UPLOAD_BLOCK + "\n\nSummarise the file please." 
        msgs = [_human(combined), _ai("It says hello.")]
-        result = _filter_messages_for_memory(msgs)
+        result = filter_messages_for_memory(msgs)
        all_content = " ".join(m.content for m in result if isinstance(m.content, str))
        assert "/mnt/user-data/uploads/" not in all_content
        assert _UPLOAD_BLOCK not in all_content
diff --git a/backend/tests/test_summarization_middleware.py b/backend/tests/test_summarization_middleware.py
new file mode 100644
index 000000000..d327c94c4
--- /dev/null
+++ b/backend/tests/test_summarization_middleware.py
@@ -0,0 +1,186 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
+
+from deerflow.agents.memory.summarization_hook import memory_flush_hook
+from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
+from deerflow.config.memory_config import MemoryConfig
+
+
+def _messages() -> list:
+    return [
+        HumanMessage(content="user-1"),
+        AIMessage(content="assistant-1"),
+        HumanMessage(content="user-2"),
+        AIMessage(content="assistant-2"),
+    ]
+
+
+def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
+    context = {}
+    if thread_id is not None:
+        context["thread_id"] = thread_id
+    if agent_name is not None:
+        context["agent_name"] = agent_name
+    return SimpleNamespace(context=context)
+
+
+def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("messages", 2)) -> DeerFlowSummarizationMiddleware:
+    model = MagicMock()
+    model.invoke.return_value = SimpleNamespace(text="compressed summary")
+    return DeerFlowSummarizationMiddleware(
+        model=model,
+        trigger=trigger,
+        keep=keep,
+        token_counter=len,
+        before_summarization=before_summarization,
+    )
+
+
+def test_before_summarization_hook_receives_messages_before_compression() -> None:
+    captured: list[SummarizationEvent] = []
+    middleware
= _middleware(before_summarization=[captured.append]) + + result = middleware.before_model({"messages": _messages()}, _runtime()) + + assert len(captured) == 1 + assert [message.content for message in captured[0].messages_to_summarize] == ["user-1", "assistant-1"] + assert [message.content for message in captured[0].preserved_messages] == ["user-2", "assistant-2"] + assert captured[0].thread_id == "thread-1" + assert captured[0].agent_name is None + assert isinstance(result["messages"][0], RemoveMessage) + assert result["messages"][1].content.startswith("Here is a summary") + + +def test_before_summarization_hook_not_called_when_threshold_not_met() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware(before_summarization=[captured.append], trigger=("messages", 10)) + + result = middleware.before_model({"messages": _messages()}, _runtime()) + + assert captured == [] + assert result is None + + +def test_before_summarization_hook_exception_does_not_block_compression(caplog: pytest.LogCaptureFixture) -> None: + def _broken_hook(_: SummarizationEvent) -> None: + raise RuntimeError("hook failure") + + middleware = _middleware(before_summarization=[_broken_hook]) + + with caplog.at_level("ERROR"): + result = middleware.before_model({"messages": _messages()}, _runtime()) + + assert "before_summarization hook _broken_hook failed" in caplog.text + assert isinstance(result["messages"][0], RemoveMessage) + + +def test_multiple_before_summarization_hooks_run_in_registration_order() -> None: + call_order: list[str] = [] + + def _hook(name: str): + return lambda _: call_order.append(name) + + middleware = _middleware(before_summarization=[_hook("first"), _hook("second"), _hook("third")]) + + middleware.before_model({"messages": _messages()}, _runtime()) + + assert call_order == ["first", "second", "third"] + + +@pytest.mark.anyio +async def test_abefore_model_calls_hooks_same_as_sync() -> None: + captured: list[SummarizationEvent] = [] + middleware = 
_middleware(before_summarization=[captured.append]) + + await middleware.abefore_model({"messages": _messages()}, _runtime()) + + assert len(captured) == 1 + assert [message.content for message in captured[0].messages_to_summarize] == ["user-1", "assistant-1"] + + +def test_memory_flush_hook_skips_when_memory_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + queue = MagicMock() + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=False)) + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue) + + memory_flush_hook( + SummarizationEvent( + messages_to_summarize=tuple(_messages()[:2]), + preserved_messages=(), + thread_id="thread-1", + agent_name=None, + runtime=_runtime(), + ) + ) + + queue.add_nowait.assert_not_called() + + +def test_memory_flush_hook_skips_when_thread_id_missing(monkeypatch: pytest.MonkeyPatch) -> None: + queue = MagicMock() + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True)) + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue) + + memory_flush_hook( + SummarizationEvent( + messages_to_summarize=tuple(_messages()[:2]), + preserved_messages=(), + thread_id=None, + agent_name=None, + runtime=_runtime(None), + ) + ) + + queue.add_nowait.assert_not_called() + + +def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: pytest.MonkeyPatch) -> None: + queue = MagicMock() + messages = [ + HumanMessage(content="Question"), + AIMessage(content="Calling tool", tool_calls=[{"name": "search", "id": "tool-1", "args": {}}]), + AIMessage(content="Final answer"), + ] + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True)) + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue) + + memory_flush_hook( + 
SummarizationEvent( + messages_to_summarize=tuple(messages), + preserved_messages=(), + thread_id="thread-1", + agent_name=None, + runtime=_runtime(), + ) + ) + + queue.add_nowait.assert_called_once() + add_kwargs = queue.add_nowait.call_args.kwargs + assert add_kwargs["thread_id"] == "thread-1" + assert [message.content for message in add_kwargs["messages"]] == ["Question", "Final answer"] + assert add_kwargs["correction_detected"] is False + assert add_kwargs["reinforcement_detected"] is False + + +def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None: + queue = MagicMock() + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True)) + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue) + + memory_flush_hook( + SummarizationEvent( + messages_to_summarize=tuple(_messages()[:2]), + preserved_messages=(), + thread_id="thread-1", + agent_name="research-agent", + runtime=_runtime(agent_name="research-agent"), + ) + ) + + queue.add_nowait.assert_called_once() + assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"