mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
feat: flush memory before summarization (#2176)
* feat: flush memory before summarization * fix: keep agent-scoped memory on summarization flush * fix: harden summarization hook plumbing * fix: address summarization review feedback * style: format memory middleware
This commit is contained in:
parent
e4f896e90d
commit
4ba3167f48
@ -1,14 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||||
|
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||||
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||||
|
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
|
||||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||||
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
||||||
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
||||||
@ -17,6 +19,7 @@ from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddlewar
|
|||||||
from deerflow.agents.thread_state import ThreadState
|
from deerflow.agents.thread_state import ThreadState
|
||||||
from deerflow.config.agents_config import load_agent_config
|
from deerflow.config.agents_config import load_agent_config
|
||||||
from deerflow.config.app_config import get_app_config
|
from deerflow.config.app_config import get_app_config
|
||||||
|
from deerflow.config.memory_config import get_memory_config
|
||||||
from deerflow.config.summarization_config import get_summarization_config
|
from deerflow.config.summarization_config import get_summarization_config
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
@ -38,7 +41,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
|
|||||||
return default_model_name
|
return default_model_name
|
||||||
|
|
||||||
|
|
||||||
def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
|
||||||
"""Create and configure the summarization middleware from config."""
|
"""Create and configure the summarization middleware from config."""
|
||||||
config = get_summarization_config()
|
config = get_summarization_config()
|
||||||
|
|
||||||
@ -77,7 +80,11 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
|||||||
if config.summary_prompt is not None:
|
if config.summary_prompt is not None:
|
||||||
kwargs["summary_prompt"] = config.summary_prompt
|
kwargs["summary_prompt"] = config.summary_prompt
|
||||||
|
|
||||||
return SummarizationMiddleware(**kwargs)
|
hooks: list[BeforeSummarizationHook] = []
|
||||||
|
if get_memory_config().enabled:
|
||||||
|
hooks.append(memory_flush_hook)
|
||||||
|
|
||||||
|
return DeerFlowSummarizationMiddleware(**kwargs, before_summarization=hooks)
|
||||||
|
|
||||||
|
|
||||||
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
|
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
|
||||||
|
|||||||
@ -0,0 +1,109 @@
|
|||||||
|
"""Shared helpers for turning conversations into memory update inputs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from copy import copy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||||
|
_CORRECTION_PATTERNS = (
|
||||||
|
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\btry again\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bredo\b", re.IGNORECASE),
|
||||||
|
re.compile(r"不对"),
|
||||||
|
re.compile(r"你理解错了"),
|
||||||
|
re.compile(r"你理解有误"),
|
||||||
|
re.compile(r"重试"),
|
||||||
|
re.compile(r"重新来"),
|
||||||
|
re.compile(r"换一种"),
|
||||||
|
re.compile(r"改用"),
|
||||||
|
)
|
||||||
|
_REINFORCEMENT_PATTERNS = (
|
||||||
|
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
||||||
|
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
||||||
|
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
||||||
|
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_message_text(message: Any) -> str:
|
||||||
|
"""Extract plain text from message content for filtering and signal detection."""
|
||||||
|
content = getattr(message, "content", "")
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_parts: list[str] = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, str):
|
||||||
|
text_parts.append(part)
|
||||||
|
elif isinstance(part, dict):
|
||||||
|
text_val = part.get("text")
|
||||||
|
if isinstance(text_val, str):
|
||||||
|
text_parts.append(text_val)
|
||||||
|
return " ".join(text_parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||||
|
"""Keep only user inputs and final assistant responses for memory updates."""
|
||||||
|
filtered = []
|
||||||
|
skip_next_ai = False
|
||||||
|
for msg in messages:
|
||||||
|
msg_type = getattr(msg, "type", None)
|
||||||
|
|
||||||
|
if msg_type == "human":
|
||||||
|
content_str = extract_message_text(msg)
|
||||||
|
if "<uploaded_files>" in content_str:
|
||||||
|
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
||||||
|
if not stripped:
|
||||||
|
skip_next_ai = True
|
||||||
|
continue
|
||||||
|
clean_msg = copy(msg)
|
||||||
|
clean_msg.content = stripped
|
||||||
|
filtered.append(clean_msg)
|
||||||
|
skip_next_ai = False
|
||||||
|
else:
|
||||||
|
filtered.append(msg)
|
||||||
|
skip_next_ai = False
|
||||||
|
elif msg_type == "ai":
|
||||||
|
tool_calls = getattr(msg, "tool_calls", None)
|
||||||
|
if not tool_calls:
|
||||||
|
if skip_next_ai:
|
||||||
|
skip_next_ai = False
|
||||||
|
continue
|
||||||
|
filtered.append(msg)
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def detect_correction(messages: list[Any]) -> bool:
|
||||||
|
"""Detect explicit user corrections in recent conversation turns."""
|
||||||
|
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||||
|
|
||||||
|
for msg in recent_user_msgs:
|
||||||
|
content = extract_message_text(msg).strip()
|
||||||
|
if content and any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def detect_reinforcement(messages: list[Any]) -> bool:
|
||||||
|
"""Detect explicit positive reinforcement signals in recent conversation turns."""
|
||||||
|
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||||
|
|
||||||
|
for msg in recent_user_msgs:
|
||||||
|
content = extract_message_text(msg).strip()
|
||||||
|
if content and any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
@ -61,48 +61,88 @@ class MemoryUpdateQueue:
|
|||||||
return
|
return
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
existing_context = next(
|
self._enqueue_locked(
|
||||||
(context for context in self._queue if context.thread_id == thread_id),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
|
||||||
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
|
|
||||||
context = ConversationContext(
|
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
correction_detected=merged_correction_detected,
|
correction_detected=correction_detected,
|
||||||
reinforcement_detected=merged_reinforcement_detected,
|
reinforcement_detected=reinforcement_detected,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if this thread already has a pending update
|
|
||||||
# If so, replace it with the newer one
|
|
||||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
|
||||||
self._queue.append(context)
|
|
||||||
|
|
||||||
# Reset or start the debounce timer
|
|
||||||
self._reset_timer()
|
self._reset_timer()
|
||||||
|
|
||||||
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
|
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
|
||||||
|
|
||||||
|
def add_nowait(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
messages: list[Any],
|
||||||
|
agent_name: str | None = None,
|
||||||
|
correction_detected: bool = False,
|
||||||
|
reinforcement_detected: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Add a conversation and start processing immediately in the background."""
|
||||||
|
config = get_memory_config()
|
||||||
|
if not config.enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self._enqueue_locked(
|
||||||
|
thread_id=thread_id,
|
||||||
|
messages=messages,
|
||||||
|
agent_name=agent_name,
|
||||||
|
correction_detected=correction_detected,
|
||||||
|
reinforcement_detected=reinforcement_detected,
|
||||||
|
)
|
||||||
|
self._schedule_timer(0)
|
||||||
|
|
||||||
|
logger.info("Memory update queued for immediate processing on thread %s, queue size: %d", thread_id, len(self._queue))
|
||||||
|
|
||||||
|
def _enqueue_locked(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
messages: list[Any],
|
||||||
|
agent_name: str | None,
|
||||||
|
correction_detected: bool,
|
||||||
|
reinforcement_detected: bool,
|
||||||
|
) -> None:
|
||||||
|
existing_context = next(
|
||||||
|
(context for context in self._queue if context.thread_id == thread_id),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||||
|
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
|
||||||
|
context = ConversationContext(
|
||||||
|
thread_id=thread_id,
|
||||||
|
messages=messages,
|
||||||
|
agent_name=agent_name,
|
||||||
|
correction_detected=merged_correction_detected,
|
||||||
|
reinforcement_detected=merged_reinforcement_detected,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||||
|
self._queue.append(context)
|
||||||
|
|
||||||
def _reset_timer(self) -> None:
|
def _reset_timer(self) -> None:
|
||||||
"""Reset the debounce timer."""
|
"""Reset the debounce timer."""
|
||||||
config = get_memory_config()
|
config = get_memory_config()
|
||||||
|
self._schedule_timer(config.debounce_seconds)
|
||||||
|
|
||||||
|
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
|
||||||
|
|
||||||
|
def _schedule_timer(self, delay_seconds: float) -> None:
|
||||||
|
"""Schedule queue processing after the provided delay."""
|
||||||
# Cancel existing timer if any
|
# Cancel existing timer if any
|
||||||
if self._timer is not None:
|
if self._timer is not None:
|
||||||
self._timer.cancel()
|
self._timer.cancel()
|
||||||
|
|
||||||
# Start new timer
|
|
||||||
self._timer = threading.Timer(
|
self._timer = threading.Timer(
|
||||||
config.debounce_seconds,
|
delay_seconds,
|
||||||
self._process_queue,
|
self._process_queue,
|
||||||
)
|
)
|
||||||
self._timer.daemon = True
|
self._timer.daemon = True
|
||||||
self._timer.start()
|
self._timer.start()
|
||||||
|
|
||||||
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
|
|
||||||
|
|
||||||
def _process_queue(self) -> None:
|
def _process_queue(self) -> None:
|
||||||
"""Process all queued conversation contexts."""
|
"""Process all queued conversation contexts."""
|
||||||
# Import here to avoid circular dependency
|
# Import here to avoid circular dependency
|
||||||
@ -110,8 +150,8 @@ class MemoryUpdateQueue:
|
|||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._processing:
|
if self._processing:
|
||||||
# Already processing, reschedule
|
# Preserve immediate flush semantics even if another worker is active.
|
||||||
self._reset_timer()
|
self._schedule_timer(0)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self._queue:
|
if not self._queue:
|
||||||
@ -164,6 +204,13 @@ class MemoryUpdateQueue:
|
|||||||
|
|
||||||
self._process_queue()
|
self._process_queue()
|
||||||
|
|
||||||
|
def flush_nowait(self) -> None:
|
||||||
|
"""Start queue processing immediately in a background thread."""
|
||||||
|
with self._lock:
|
||||||
|
# Daemon thread: queued messages may be lost if the process exits
|
||||||
|
# before _process_queue completes. Acceptable for best-effort memory updates.
|
||||||
|
self._schedule_timer(0)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear the queue without processing.
|
"""Clear the queue without processing.
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,31 @@
|
|||||||
|
"""Hooks fired before summarization removes messages from state."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
||||||
|
from deerflow.agents.memory.queue import get_memory_queue
|
||||||
|
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
|
||||||
|
from deerflow.config.memory_config import get_memory_config
|
||||||
|
|
||||||
|
|
||||||
|
def memory_flush_hook(event: SummarizationEvent) -> None:
|
||||||
|
"""Flush messages about to be summarized into the memory queue."""
|
||||||
|
if not get_memory_config().enabled or not event.thread_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize))
|
||||||
|
user_messages = [message for message in filtered_messages if getattr(message, "type", None) == "human"]
|
||||||
|
assistant_messages = [message for message in filtered_messages if getattr(message, "type", None) == "ai"]
|
||||||
|
if not user_messages or not assistant_messages:
|
||||||
|
return
|
||||||
|
|
||||||
|
correction_detected = detect_correction(filtered_messages)
|
||||||
|
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||||
|
queue = get_memory_queue()
|
||||||
|
queue.add_nowait(
|
||||||
|
thread_id=event.thread_id,
|
||||||
|
messages=filtered_messages,
|
||||||
|
agent_name=event.agent_name,
|
||||||
|
correction_detected=correction_detected,
|
||||||
|
reinforcement_detected=reinforcement_detected,
|
||||||
|
)
|
||||||
@ -1,50 +1,19 @@
|
|||||||
"""Middleware for memory mechanism."""
|
"""Middleware for memory mechanism."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
from typing import override
|
||||||
from typing import Any, override
|
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langgraph.config import get_config
|
from langgraph.config import get_config
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
||||||
from deerflow.agents.memory.queue import get_memory_queue
|
from deerflow.agents.memory.queue import get_memory_queue
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
|
||||||
_CORRECTION_PATTERNS = (
|
|
||||||
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\btry again\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bredo\b", re.IGNORECASE),
|
|
||||||
re.compile(r"不对"),
|
|
||||||
re.compile(r"你理解错了"),
|
|
||||||
re.compile(r"你理解有误"),
|
|
||||||
re.compile(r"重试"),
|
|
||||||
re.compile(r"重新来"),
|
|
||||||
re.compile(r"换一种"),
|
|
||||||
re.compile(r"改用"),
|
|
||||||
)
|
|
||||||
|
|
||||||
_REINFORCEMENT_PATTERNS = (
|
|
||||||
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
|
||||||
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
|
||||||
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
|
||||||
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
|
||||||
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
|
||||||
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
|
||||||
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
|
||||||
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
|
||||||
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryMiddlewareState(AgentState):
|
class MemoryMiddlewareState(AgentState):
|
||||||
"""Compatible with the `ThreadState` schema."""
|
"""Compatible with the `ThreadState` schema."""
|
||||||
@ -52,125 +21,6 @@ class MemoryMiddlewareState(AgentState):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _extract_message_text(message: Any) -> str:
|
|
||||||
"""Extract plain text from message content for filtering and signal detection."""
|
|
||||||
content = getattr(message, "content", "")
|
|
||||||
if isinstance(content, list):
|
|
||||||
text_parts: list[str] = []
|
|
||||||
for part in content:
|
|
||||||
if isinstance(part, str):
|
|
||||||
text_parts.append(part)
|
|
||||||
elif isinstance(part, dict):
|
|
||||||
text_val = part.get("text")
|
|
||||||
if isinstance(text_val, str):
|
|
||||||
text_parts.append(text_val)
|
|
||||||
return " ".join(text_parts)
|
|
||||||
return str(content)
|
|
||||||
|
|
||||||
|
|
||||||
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
|
||||||
"""Filter messages to keep only user inputs and final assistant responses.
|
|
||||||
|
|
||||||
This filters out:
|
|
||||||
- Tool messages (intermediate tool call results)
|
|
||||||
- AI messages with tool_calls (intermediate steps, not final responses)
|
|
||||||
- The <uploaded_files> block injected by UploadsMiddleware into human messages
|
|
||||||
(file paths are session-scoped and must not persist in long-term memory).
|
|
||||||
The user's actual question is preserved; only turns whose content is entirely
|
|
||||||
the upload block (nothing remains after stripping) are dropped along with
|
|
||||||
their paired assistant response.
|
|
||||||
|
|
||||||
Only keeps:
|
|
||||||
- Human messages (with the ephemeral upload block removed)
|
|
||||||
- AI messages without tool_calls (final assistant responses), unless the
|
|
||||||
paired human turn was upload-only and had no real user text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: List of all conversation messages.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Filtered list containing only user inputs and final assistant responses.
|
|
||||||
"""
|
|
||||||
filtered = []
|
|
||||||
skip_next_ai = False
|
|
||||||
for msg in messages:
|
|
||||||
msg_type = getattr(msg, "type", None)
|
|
||||||
|
|
||||||
if msg_type == "human":
|
|
||||||
content_str = _extract_message_text(msg)
|
|
||||||
if "<uploaded_files>" in content_str:
|
|
||||||
# Strip the ephemeral upload block; keep the user's real question.
|
|
||||||
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
|
||||||
if not stripped:
|
|
||||||
# Nothing left — the entire turn was upload bookkeeping;
|
|
||||||
# skip it and the paired assistant response.
|
|
||||||
skip_next_ai = True
|
|
||||||
continue
|
|
||||||
# Rebuild the message with cleaned content so the user's question
|
|
||||||
# is still available for memory summarisation.
|
|
||||||
from copy import copy
|
|
||||||
|
|
||||||
clean_msg = copy(msg)
|
|
||||||
clean_msg.content = stripped
|
|
||||||
filtered.append(clean_msg)
|
|
||||||
skip_next_ai = False
|
|
||||||
else:
|
|
||||||
filtered.append(msg)
|
|
||||||
skip_next_ai = False
|
|
||||||
elif msg_type == "ai":
|
|
||||||
tool_calls = getattr(msg, "tool_calls", None)
|
|
||||||
if not tool_calls:
|
|
||||||
if skip_next_ai:
|
|
||||||
skip_next_ai = False
|
|
||||||
continue
|
|
||||||
filtered.append(msg)
|
|
||||||
# Skip tool messages and AI messages with tool_calls
|
|
||||||
|
|
||||||
return filtered
|
|
||||||
|
|
||||||
|
|
||||||
def detect_correction(messages: list[Any]) -> bool:
|
|
||||||
"""Detect explicit user corrections in recent conversation turns.
|
|
||||||
|
|
||||||
The queue keeps only one pending context per thread, so callers pass the
|
|
||||||
latest filtered message list. Checking only recent user turns keeps signal
|
|
||||||
detection conservative while avoiding stale corrections from long histories.
|
|
||||||
"""
|
|
||||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
|
||||||
|
|
||||||
for msg in recent_user_msgs:
|
|
||||||
content = _extract_message_text(msg).strip()
|
|
||||||
if not content:
|
|
||||||
continue
|
|
||||||
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def detect_reinforcement(messages: list[Any]) -> bool:
|
|
||||||
"""Detect explicit positive reinforcement signals in recent conversation turns.
|
|
||||||
|
|
||||||
Complements detect_correction() by identifying when the user confirms the
|
|
||||||
agent's approach was correct. This allows the memory system to record what
|
|
||||||
worked well, not just what went wrong.
|
|
||||||
|
|
||||||
The queue keeps only one pending context per thread, so callers pass the
|
|
||||||
latest filtered message list. Checking only recent user turns keeps signal
|
|
||||||
detection conservative while avoiding stale signals from long histories.
|
|
||||||
"""
|
|
||||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
|
||||||
|
|
||||||
for msg in recent_user_msgs:
|
|
||||||
content = _extract_message_text(msg).strip()
|
|
||||||
if not content:
|
|
||||||
continue
|
|
||||||
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||||
"""Middleware that queues conversation for memory update after agent execution.
|
"""Middleware that queues conversation for memory update after agent execution.
|
||||||
|
|
||||||
@ -223,7 +73,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Filter to only keep user inputs and final assistant responses
|
# Filter to only keep user inputs and final assistant responses
|
||||||
filtered_messages = _filter_messages_for_memory(messages)
|
filtered_messages = filter_messages_for_memory(messages)
|
||||||
|
|
||||||
# Only queue if there's meaningful conversation
|
# Only queue if there's meaningful conversation
|
||||||
# At minimum need one user message and one assistant response
|
# At minimum need one user message and one assistant response
|
||||||
|
|||||||
@ -0,0 +1,151 @@
|
|||||||
|
"""Summarization middleware extensions for DeerFlow."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from langchain.agents import AgentState
|
||||||
|
from langchain.agents.middleware import SummarizationMiddleware
|
||||||
|
from langchain_core.messages import AnyMessage, RemoveMessage
|
||||||
|
from langgraph.config import get_config
|
||||||
|
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SummarizationEvent:
|
||||||
|
"""Context emitted before conversation history is summarized away."""
|
||||||
|
|
||||||
|
messages_to_summarize: tuple[AnyMessage, ...]
|
||||||
|
preserved_messages: tuple[AnyMessage, ...]
|
||||||
|
thread_id: str | None
|
||||||
|
agent_name: str | None
|
||||||
|
runtime: Runtime
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class BeforeSummarizationHook(Protocol):
|
||||||
|
"""Hook invoked before summarization removes messages from state."""
|
||||||
|
|
||||||
|
def __call__(self, event: SummarizationEvent) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_thread_id(runtime: Runtime) -> str | None:
|
||||||
|
"""Resolve the current thread ID from runtime context or LangGraph config."""
|
||||||
|
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||||
|
if thread_id is None:
|
||||||
|
try:
|
||||||
|
config_data = get_config()
|
||||||
|
except RuntimeError:
|
||||||
|
return None
|
||||||
|
thread_id = config_data.get("configurable", {}).get("thread_id")
|
||||||
|
return thread_id
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_agent_name(runtime: Runtime) -> str | None:
|
||||||
|
"""Resolve the current agent name from runtime context or LangGraph config."""
|
||||||
|
agent_name = runtime.context.get("agent_name") if runtime.context else None
|
||||||
|
if agent_name is None:
|
||||||
|
try:
|
||||||
|
config_data = get_config()
|
||||||
|
except RuntimeError:
|
||||||
|
return None
|
||||||
|
agent_name = config_data.get("configurable", {}).get("agent_name")
|
||||||
|
return agent_name
|
||||||
|
|
||||||
|
|
||||||
|
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||||
|
"""Summarization middleware with pre-compression hook dispatch."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
before_summarization: list[BeforeSummarizationHook] | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._before_summarization_hooks = before_summarization or []
|
||||||
|
|
||||||
|
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||||
|
return self._maybe_summarize(state, runtime)
|
||||||
|
|
||||||
|
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||||
|
return await self._amaybe_summarize(state, runtime)
|
||||||
|
|
||||||
|
def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||||
|
messages = state["messages"]
|
||||||
|
self._ensure_message_ids(messages)
|
||||||
|
|
||||||
|
total_tokens = self.token_counter(messages)
|
||||||
|
if not self._should_summarize(messages, total_tokens):
|
||||||
|
return None
|
||||||
|
|
||||||
|
cutoff_index = self._determine_cutoff_index(messages)
|
||||||
|
if cutoff_index <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
|
||||||
|
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
||||||
|
summary = self._create_summary(messages_to_summarize)
|
||||||
|
new_messages = self._build_new_messages(summary)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"messages": [
|
||||||
|
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
||||||
|
*new_messages,
|
||||||
|
*preserved_messages,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||||
|
messages = state["messages"]
|
||||||
|
self._ensure_message_ids(messages)
|
||||||
|
|
||||||
|
total_tokens = self.token_counter(messages)
|
||||||
|
if not self._should_summarize(messages, total_tokens):
|
||||||
|
return None
|
||||||
|
|
||||||
|
cutoff_index = self._determine_cutoff_index(messages)
|
||||||
|
if cutoff_index <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
|
||||||
|
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
||||||
|
summary = await self._acreate_summary(messages_to_summarize)
|
||||||
|
new_messages = self._build_new_messages(summary)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"messages": [
|
||||||
|
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
||||||
|
*new_messages,
|
||||||
|
*preserved_messages,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _fire_hooks(
    self,
    messages_to_summarize: list[AnyMessage],
    preserved_messages: list[AnyMessage],
    runtime: Runtime,
) -> None:
    """Invoke every registered before-summarization hook with one shared event.

    Hook exceptions are logged and swallowed so a misbehaving hook can
    never block the summarization itself.
    """
    hooks = self._before_summarization_hooks
    if not hooks:
        return

    event = SummarizationEvent(
        messages_to_summarize=tuple(messages_to_summarize),
        preserved_messages=tuple(preserved_messages),
        thread_id=_resolve_thread_id(runtime),
        agent_name=_resolve_agent_name(runtime),
        runtime=runtime,
    )

    for hook in hooks:
        try:
            hook(event)
        except Exception:
            label = getattr(hook, "__name__", None) or type(hook).__name__
            logger.exception("before_summarization hook %s failed", label)
|
||||||
@ -8,6 +8,7 @@ import pytest
|
|||||||
|
|
||||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||||
from deerflow.config.app_config import AppConfig
|
from deerflow.config.app_config import AppConfig
|
||||||
|
from deerflow.config.memory_config import MemoryConfig
|
||||||
from deerflow.config.model_config import ModelConfig
|
from deerflow.config.model_config import ModelConfig
|
||||||
from deerflow.config.sandbox_config import SandboxConfig
|
from deerflow.config.sandbox_config import SandboxConfig
|
||||||
from deerflow.config.summarization_config import SummarizationConfig
|
from deerflow.config.summarization_config import SummarizationConfig
|
||||||
@ -145,6 +146,7 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
|||||||
"get_summarization_config",
|
"get_summarization_config",
|
||||||
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
|
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
|
||||||
)
|
)
|
||||||
|
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
|
||||||
|
|
||||||
captured: dict[str, object] = {}
|
captured: dict[str, object] = {}
|
||||||
fake_model = object()
|
fake_model = object()
|
||||||
@ -156,10 +158,32 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
|||||||
return fake_model
|
return fake_model
|
||||||
|
|
||||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||||
monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs)
|
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
|
||||||
|
|
||||||
middleware = lead_agent_module._create_summarization_middleware()
|
middleware = lead_agent_module._create_summarization_middleware()
|
||||||
|
|
||||||
assert captured["name"] == "model-masswork"
|
assert captured["name"] == "model-masswork"
|
||||||
assert captured["thinking_enabled"] is False
|
assert captured["thinking_enabled"] is False
|
||||||
assert middleware["model"] is fake_model
|
assert middleware["model"] is fake_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_summarization_middleware_registers_memory_flush_hook_when_memory_enabled(monkeypatch):
    """With memory enabled, the factory must wire in the memory flush hook."""
    monkeypatch.setattr(
        lead_agent_module,
        "get_summarization_config",
        lambda: SummarizationConfig(enabled=True),
    )
    monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=True))
    monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object())

    recorded: dict[str, object] = {}

    def _spy_middleware(**kwargs):
        # Record everything the factory passes to the middleware constructor.
        recorded.update(kwargs)
        return kwargs

    monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _spy_middleware)

    lead_agent_module._create_summarization_middleware()

    assert recorded["before_summarization"] == [lead_agent_module.memory_flush_hook]
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||||
@ -89,3 +91,74 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
|||||||
correction_detected=False,
|
correction_detected=False,
|
||||||
reinforcement_detected=True,
|
reinforcement_detected=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_flush_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
    """flush_nowait replaces any pending timer with an immediate (0s) one."""
    update_queue = MemoryUpdateQueue()
    old_timer = MagicMock()
    update_queue._timer = old_timer
    new_timer = MagicMock()

    with patch("deerflow.agents.memory.queue.threading.Timer", return_value=new_timer) as timer_cls:
        update_queue.flush_nowait()

    old_timer.cancel.assert_called_once_with()
    timer_cls.assert_called_once_with(0, update_queue._process_queue)
    assert new_timer.daemon is True
    new_timer.start.assert_called_once_with()
    assert update_queue._timer is new_timer
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
    """add_nowait enqueues the item and schedules immediate processing."""
    update_queue = MemoryUpdateQueue()
    old_timer = MagicMock()
    update_queue._timer = old_timer
    new_timer = MagicMock()

    with (
        patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
        patch("deerflow.agents.memory.queue.threading.Timer", return_value=new_timer) as timer_cls,
    ):
        update_queue.add_nowait(thread_id="thread-1", messages=["conversation"], agent_name="lead-agent")

    old_timer.cancel.assert_called_once_with()
    timer_cls.assert_called_once_with(0, update_queue._process_queue)
    assert update_queue.pending_count == 1
    assert update_queue._queue[0].agent_name == "lead-agent"
    assert new_timer.daemon is True
    new_timer.start.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_queue_reschedules_immediately_when_already_processing() -> None:
    """A concurrent _process_queue call defers itself via a fresh 0s timer."""
    update_queue = MemoryUpdateQueue()
    update_queue._processing = True
    new_timer = MagicMock()

    with patch("deerflow.agents.memory.queue.threading.Timer", return_value=new_timer) as timer_cls:
        update_queue._process_queue()

    timer_cls.assert_called_once_with(0, update_queue._process_queue)
    assert new_timer.daemon is True
    new_timer.start.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
|
def test_flush_nowait_is_non_blocking() -> None:
    """flush_nowait must return immediately even if processing is slow."""
    update_queue = MemoryUpdateQueue()
    began = threading.Event()
    done = threading.Event()

    def _slow_process_queue() -> None:
        began.set()
        time.sleep(0.2)
        done.set()

    update_queue._process_queue = _slow_process_queue

    t0 = time.perf_counter()
    update_queue.flush_nowait()
    duration = time.perf_counter() - t0

    # Processing started in the background, but the caller was not blocked.
    assert began.wait(0.1) is True
    assert duration < 0.1
    assert done.is_set() is False
    assert done.wait(1.0) is True
|
||||||
|
|||||||
@ -3,14 +3,14 @@
|
|||||||
Covers two functions introduced to prevent ephemeral file-upload context from
|
Covers two functions introduced to prevent ephemeral file-upload context from
|
||||||
persisting in long-term memory:
|
persisting in long-term memory:
|
||||||
|
|
||||||
- _filter_messages_for_memory (memory_middleware)
|
- filter_messages_for_memory (message_processing)
|
||||||
- _strip_upload_mentions_from_memory (updater)
|
- _strip_upload_mentions_from_memory (updater)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||||
|
|
||||||
|
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
||||||
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
|
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
|
||||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
@ -31,7 +31,7 @@ def _ai(text: str, tool_calls=None) -> AIMessage:
|
|||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# _filter_messages_for_memory
|
# filter_messages_for_memory
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
|
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ class TestFilterMessagesForMemory:
|
|||||||
_human(_UPLOAD_BLOCK),
|
_human(_UPLOAD_BLOCK),
|
||||||
_ai("I have read the file. It says: Hello."),
|
_ai("I have read the file. It says: Hello."),
|
||||||
]
|
]
|
||||||
result = _filter_messages_for_memory(msgs)
|
result = filter_messages_for_memory(msgs)
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
def test_upload_with_real_question_preserves_question(self):
|
def test_upload_with_real_question_preserves_question(self):
|
||||||
@ -56,7 +56,7 @@ class TestFilterMessagesForMemory:
|
|||||||
_human(combined),
|
_human(combined),
|
||||||
_ai("The file contains: Hello DeerFlow."),
|
_ai("The file contains: Hello DeerFlow."),
|
||||||
]
|
]
|
||||||
result = _filter_messages_for_memory(msgs)
|
result = filter_messages_for_memory(msgs)
|
||||||
|
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
human_result = result[0]
|
human_result = result[0]
|
||||||
@ -71,7 +71,7 @@ class TestFilterMessagesForMemory:
|
|||||||
_human("What is the capital of France?"),
|
_human("What is the capital of France?"),
|
||||||
_ai("The capital of France is Paris."),
|
_ai("The capital of France is Paris."),
|
||||||
]
|
]
|
||||||
result = _filter_messages_for_memory(msgs)
|
result = filter_messages_for_memory(msgs)
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert result[0].content == "What is the capital of France?"
|
assert result[0].content == "What is the capital of France?"
|
||||||
assert result[1].content == "The capital of France is Paris."
|
assert result[1].content == "The capital of France is Paris."
|
||||||
@ -84,7 +84,7 @@ class TestFilterMessagesForMemory:
|
|||||||
ToolMessage(content="Search results", tool_call_id="1"),
|
ToolMessage(content="Search results", tool_call_id="1"),
|
||||||
_ai("Here are the results."),
|
_ai("Here are the results."),
|
||||||
]
|
]
|
||||||
result = _filter_messages_for_memory(msgs)
|
result = filter_messages_for_memory(msgs)
|
||||||
human_msgs = [m for m in result if m.type == "human"]
|
human_msgs = [m for m in result if m.type == "human"]
|
||||||
ai_msgs = [m for m in result if m.type == "ai"]
|
ai_msgs = [m for m in result if m.type == "ai"]
|
||||||
assert len(human_msgs) == 1
|
assert len(human_msgs) == 1
|
||||||
@ -101,7 +101,7 @@ class TestFilterMessagesForMemory:
|
|||||||
_human("What is 2 + 2?"),
|
_human("What is 2 + 2?"),
|
||||||
_ai("4"),
|
_ai("4"),
|
||||||
]
|
]
|
||||||
result = _filter_messages_for_memory(msgs)
|
result = filter_messages_for_memory(msgs)
|
||||||
human_contents = [m.content for m in result if m.type == "human"]
|
human_contents = [m.content for m in result if m.type == "human"]
|
||||||
ai_contents = [m.content for m in result if m.type == "ai"]
|
ai_contents = [m.content for m in result if m.type == "ai"]
|
||||||
|
|
||||||
@ -121,14 +121,14 @@ class TestFilterMessagesForMemory:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
msgs = [msg, _ai("Done.")]
|
msgs = [msg, _ai("Done.")]
|
||||||
result = _filter_messages_for_memory(msgs)
|
result = filter_messages_for_memory(msgs)
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
def test_file_path_not_in_filtered_content(self):
|
def test_file_path_not_in_filtered_content(self):
|
||||||
"""After filtering, no upload file path should appear in any message."""
|
"""After filtering, no upload file path should appear in any message."""
|
||||||
combined = _UPLOAD_BLOCK + "\n\nSummarise the file please."
|
combined = _UPLOAD_BLOCK + "\n\nSummarise the file please."
|
||||||
msgs = [_human(combined), _ai("It says hello.")]
|
msgs = [_human(combined), _ai("It says hello.")]
|
||||||
result = _filter_messages_for_memory(msgs)
|
result = filter_messages_for_memory(msgs)
|
||||||
all_content = " ".join(m.content for m in result if isinstance(m.content, str))
|
all_content = " ".join(m.content for m in result if isinstance(m.content, str))
|
||||||
assert "/mnt/user-data/uploads/" not in all_content
|
assert "/mnt/user-data/uploads/" not in all_content
|
||||||
assert "<uploaded_files>" not in all_content
|
assert "<uploaded_files>" not in all_content
|
||||||
|
|||||||
186
backend/tests/test_summarization_middleware.py
Normal file
186
backend/tests/test_summarization_middleware.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
|
||||||
|
|
||||||
|
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||||
|
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
|
||||||
|
from deerflow.config.memory_config import MemoryConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _messages() -> list:
    """Four-message fixture: two alternating user/assistant exchanges."""
    turns = [("user-1", "assistant-1"), ("user-2", "assistant-2")]
    conversation: list = []
    for user_text, assistant_text in turns:
        conversation.append(HumanMessage(content=user_text))
        conversation.append(AIMessage(content=assistant_text))
    return conversation
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
|
||||||
|
context = {}
|
||||||
|
if thread_id is not None:
|
||||||
|
context["thread_id"] = thread_id
|
||||||
|
if agent_name is not None:
|
||||||
|
context["agent_name"] = agent_name
|
||||||
|
return SimpleNamespace(context=context)
|
||||||
|
|
||||||
|
|
||||||
|
def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("messages", 2)) -> DeerFlowSummarizationMiddleware:
    """Create a middleware under test with a stub model returning a fixed summary."""
    fake_model = MagicMock()
    fake_model.invoke.return_value = SimpleNamespace(text="compressed summary")
    return DeerFlowSummarizationMiddleware(
        model=fake_model,
        token_counter=len,
        trigger=trigger,
        keep=keep,
        before_summarization=before_summarization,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_before_summarization_hook_receives_messages_before_compression() -> None:
    """The hook sees the exact split of summarized vs. preserved messages."""
    events: list[SummarizationEvent] = []
    middleware = _middleware(before_summarization=[events.append])

    result = middleware.before_model({"messages": _messages()}, _runtime())

    assert len(events) == 1
    event = events[0]
    assert [m.content for m in event.messages_to_summarize] == ["user-1", "assistant-1"]
    assert [m.content for m in event.preserved_messages] == ["user-2", "assistant-2"]
    assert event.thread_id == "thread-1"
    assert event.agent_name is None
    # Compression still happened: history removed and replaced by a summary.
    assert isinstance(result["messages"][0], RemoveMessage)
    assert result["messages"][1].content.startswith("Here is a summary")
|
||||||
|
|
||||||
|
|
||||||
|
def test_before_summarization_hook_not_called_when_threshold_not_met() -> None:
    """Below the trigger threshold, hooks must stay silent and state untouched."""
    events: list[SummarizationEvent] = []
    middleware = _middleware(before_summarization=[events.append], trigger=("messages", 10))

    result = middleware.before_model({"messages": _messages()}, _runtime())

    assert events == []
    assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_before_summarization_hook_exception_does_not_block_compression(caplog: pytest.LogCaptureFixture) -> None:
    """A raising hook is logged, and compression proceeds regardless."""
    def _broken_hook(_: SummarizationEvent) -> None:
        raise RuntimeError("hook failure")

    middleware = _middleware(before_summarization=[_broken_hook])

    with caplog.at_level("ERROR"):
        result = middleware.before_model({"messages": _messages()}, _runtime())

    assert "before_summarization hook _broken_hook failed" in caplog.text
    assert isinstance(result["messages"][0], RemoveMessage)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_before_summarization_hooks_run_in_registration_order() -> None:
    """Hooks fire in exactly the order they were registered."""
    call_order: list[str] = []

    def _hook(name: str):
        # Each hook ignores the event and records its own name.
        return lambda _: call_order.append(name)

    hooks = [_hook("first"), _hook("second"), _hook("third")]
    middleware = _middleware(before_summarization=hooks)

    middleware.before_model({"messages": _messages()}, _runtime())

    assert call_order == ["first", "second", "third"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_abefore_model_calls_hooks_same_as_sync() -> None:
    """The async entry point triggers hooks identically to the sync one."""
    events: list[SummarizationEvent] = []
    middleware = _middleware(before_summarization=[events.append])

    await middleware.abefore_model({"messages": _messages()}, _runtime())

    assert len(events) == 1
    assert [m.content for m in events[0].messages_to_summarize] == ["user-1", "assistant-1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_flush_hook_skips_when_memory_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """With memory disabled, the flush hook must not touch the queue."""
    memory_queue = MagicMock()
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=False))
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: memory_queue)

    event = SummarizationEvent(
        messages_to_summarize=tuple(_messages()[:2]),
        preserved_messages=(),
        thread_id="thread-1",
        agent_name=None,
        runtime=_runtime(),
    )
    memory_flush_hook(event)

    memory_queue.add_nowait.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_flush_hook_skips_when_thread_id_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """Without a thread id there is nowhere to attribute the flush, so skip."""
    memory_queue = MagicMock()
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: memory_queue)

    event = SummarizationEvent(
        messages_to_summarize=tuple(_messages()[:2]),
        preserved_messages=(),
        thread_id=None,
        agent_name=None,
        runtime=_runtime(None),
    )
    memory_flush_hook(event)

    memory_queue.add_nowait.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: pytest.MonkeyPatch) -> None:
    """Messages are filtered (tool-call turns dropped) before being enqueued."""
    memory_queue = MagicMock()
    conversation = [
        HumanMessage(content="Question"),
        AIMessage(content="Calling tool", tool_calls=[{"name": "search", "id": "tool-1", "args": {}}]),
        AIMessage(content="Final answer"),
    ]
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: memory_queue)

    event = SummarizationEvent(
        messages_to_summarize=tuple(conversation),
        preserved_messages=(),
        thread_id="thread-1",
        agent_name=None,
        runtime=_runtime(),
    )
    memory_flush_hook(event)

    memory_queue.add_nowait.assert_called_once()
    kwargs = memory_queue.add_nowait.call_args.kwargs
    assert kwargs["thread_id"] == "thread-1"
    assert [m.content for m in kwargs["messages"]] == ["Question", "Final answer"]
    assert kwargs["correction_detected"] is False
    assert kwargs["reinforcement_detected"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None:
    """The agent name on the event is propagated to the queued update."""
    memory_queue = MagicMock()
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: memory_queue)

    event = SummarizationEvent(
        messages_to_summarize=tuple(_messages()[:2]),
        preserved_messages=(),
        thread_id="thread-1",
        agent_name="research-agent",
        runtime=_runtime(agent_name="research-agent"),
    )
    memory_flush_hook(event)

    memory_queue.add_nowait.assert_called_once()
    assert memory_queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
|
||||||
Loading…
x
Reference in New Issue
Block a user