mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
feat: flush memory before summarization (#2176)
* feat: flush memory before summarization * fix: keep agent-scoped memory on summarization flush * fix: harden summarization hook plumbing * fix: address summarization review feedback * style: format memory middleware
This commit is contained in:
parent
e4f896e90d
commit
4ba3167f48
@ -1,14 +1,16 @@
|
||||
import logging
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
||||
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
||||
@ -17,6 +19,7 @@ from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddlewar
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.agents_config import load_agent_config
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.config.summarization_config import get_summarization_config
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
@ -38,7 +41,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
|
||||
return default_model_name
|
||||
|
||||
|
||||
def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
||||
def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
|
||||
"""Create and configure the summarization middleware from config."""
|
||||
config = get_summarization_config()
|
||||
|
||||
@ -77,7 +80,11 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
||||
if config.summary_prompt is not None:
|
||||
kwargs["summary_prompt"] = config.summary_prompt
|
||||
|
||||
return SummarizationMiddleware(**kwargs)
|
||||
hooks: list[BeforeSummarizationHook] = []
|
||||
if get_memory_config().enabled:
|
||||
hooks.append(memory_flush_hook)
|
||||
|
||||
return DeerFlowSummarizationMiddleware(**kwargs, before_summarization=hooks)
|
||||
|
||||
|
||||
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
|
||||
|
||||
@ -0,0 +1,109 @@
|
||||
"""Shared helpers for turning conversations into memory update inputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from copy import copy
|
||||
from typing import Any
|
||||
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
_CORRECTION_PATTERNS = (
|
||||
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
||||
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
||||
re.compile(r"\btry again\b", re.IGNORECASE),
|
||||
re.compile(r"\bredo\b", re.IGNORECASE),
|
||||
re.compile(r"不对"),
|
||||
re.compile(r"你理解错了"),
|
||||
re.compile(r"你理解有误"),
|
||||
re.compile(r"重试"),
|
||||
re.compile(r"重新来"),
|
||||
re.compile(r"换一种"),
|
||||
re.compile(r"改用"),
|
||||
)
|
||||
_REINFORCEMENT_PATTERNS = (
|
||||
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
||||
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
||||
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
||||
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
||||
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
||||
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
||||
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
||||
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
||||
)
|
||||
|
||||
|
||||
def extract_message_text(message: Any) -> str:
    """Return the plain-text portion of a message's content.

    String parts are kept as-is; dict parts contribute their "text" value
    when it is a string. Anything that is not a list is coerced with str().
    """
    content = getattr(message, "content", "")
    if not isinstance(content, list):
        return str(content)

    pieces: list[str] = []
    for item in content:
        if isinstance(item, str):
            pieces.append(item)
        elif isinstance(item, dict) and isinstance(item.get("text"), str):
            pieces.append(item["text"])
    return " ".join(pieces)
|
||||
|
||||
|
||||
def filter_messages_for_memory(messages: list[Any]) -> list[Any]:
    """Keep only user inputs and final assistant responses for memory updates."""
    kept: list[Any] = []
    drop_paired_ai = False

    for message in messages:
        role = getattr(message, "type", None)

        if role == "human":
            text = extract_message_text(message)
            if "<uploaded_files>" not in text:
                kept.append(message)
                drop_paired_ai = False
                continue
            # Strip ephemeral upload bookkeeping; keep the user's real question.
            remainder = _UPLOAD_BLOCK_RE.sub("", text).strip()
            if remainder:
                cleaned = copy(message)
                cleaned.content = remainder
                kept.append(cleaned)
                drop_paired_ai = False
            else:
                # The whole turn was upload bookkeeping: drop it together
                # with its paired final assistant reply.
                drop_paired_ai = True
        elif role == "ai" and not getattr(message, "tool_calls", None):
            if drop_paired_ai:
                drop_paired_ai = False
            else:
                kept.append(message)
        # Tool messages and tool-calling AI steps are intermediate; drop them.

    return kept
|
||||
|
||||
|
||||
def detect_correction(messages: list[Any]) -> bool:
    """Return True when a recent user turn contains an explicit correction."""
    # Only the last few turns matter; stale corrections should not trigger.
    for message in messages[-6:]:
        if getattr(message, "type", None) != "human":
            continue
        text = extract_message_text(message).strip()
        if text and any(rx.search(text) for rx in _CORRECTION_PATTERNS):
            return True
    return False
|
||||
|
||||
|
||||
def detect_reinforcement(messages: list[Any]) -> bool:
    """Return True when a recent user turn explicitly confirms the agent's work."""
    # Mirror detect_correction: scan only the most recent turns.
    for message in messages[-6:]:
        if getattr(message, "type", None) != "human":
            continue
        text = extract_message_text(message).strip()
        if text and any(rx.search(text) for rx in _REINFORCEMENT_PATTERNS):
            return True
    return False
|
||||
@ -61,48 +61,88 @@ class MemoryUpdateQueue:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
existing_context = next(
|
||||
(context for context in self._queue if context.thread_id == thread_id),
|
||||
None,
|
||||
)
|
||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
|
||||
context = ConversationContext(
|
||||
self._enqueue_locked(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
correction_detected=merged_correction_detected,
|
||||
reinforcement_detected=merged_reinforcement_detected,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
|
||||
# Check if this thread already has a pending update
|
||||
# If so, replace it with the newer one
|
||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||
self._queue.append(context)
|
||||
|
||||
# Reset or start the debounce timer
|
||||
self._reset_timer()
|
||||
|
||||
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
|
||||
|
||||
def add_nowait(
    self,
    thread_id: str,
    messages: list[Any],
    agent_name: str | None = None,
    correction_detected: bool = False,
    reinforcement_detected: bool = False,
) -> None:
    """Add a conversation and start processing immediately in the background."""
    if not get_memory_config().enabled:
        # Memory subsystem disabled: silently drop the request.
        return

    with self._lock:
        self._enqueue_locked(
            thread_id=thread_id,
            messages=messages,
            agent_name=agent_name,
            correction_detected=correction_detected,
            reinforcement_detected=reinforcement_detected,
        )
        # Zero delay: trigger queue processing as soon as possible
        # instead of waiting for the debounce window.
        self._schedule_timer(0)

    logger.info("Memory update queued for immediate processing on thread %s, queue size: %d", thread_id, len(self._queue))
|
||||
|
||||
def _enqueue_locked(
    self,
    *,
    thread_id: str,
    messages: list[Any],
    agent_name: str | None,
    correction_detected: bool,
    reinforcement_detected: bool,
) -> None:
    """Replace any pending context for this thread, merging sticky signal flags.

    Caller must hold self._lock.
    """
    previous = None
    for candidate in self._queue:
        if candidate.thread_id == thread_id:
            previous = candidate
            break

    # Signal flags are sticky: once observed for a thread they survive
    # until the queued update is actually processed.
    if previous is not None:
        correction_detected = correction_detected or previous.correction_detected
        reinforcement_detected = reinforcement_detected or previous.reinforcement_detected

    replacement = ConversationContext(
        thread_id=thread_id,
        messages=messages,
        agent_name=agent_name,
        correction_detected=correction_detected,
        reinforcement_detected=reinforcement_detected,
    )

    # Keep at most one pending context per thread: newest wins.
    self._queue = [ctx for ctx in self._queue if ctx.thread_id != thread_id]
    self._queue.append(replacement)
|
||||
|
||||
def _reset_timer(self) -> None:
    """Reset the debounce timer."""
    # Restart the countdown so rapid successive updates coalesce into one flush.
    config = get_memory_config()
    self._schedule_timer(config.debounce_seconds)

    logger.debug("Memory update timer set for %ss", config.debounce_seconds)
|
||||
|
||||
def _schedule_timer(self, delay_seconds: float) -> None:
    """Schedule queue processing after the provided delay."""
    # At most one pending timer: cancel any previously scheduled flush.
    if self._timer is not None:
        self._timer.cancel()

    self._timer = threading.Timer(delay_seconds, self._process_queue)
    # Daemon timer: never block interpreter shutdown on a pending flush.
    self._timer.daemon = True
    self._timer.start()
|
||||
|
||||
def _process_queue(self) -> None:
|
||||
"""Process all queued conversation contexts."""
|
||||
# Import here to avoid circular dependency
|
||||
@ -110,8 +150,8 @@ class MemoryUpdateQueue:
|
||||
|
||||
with self._lock:
|
||||
if self._processing:
|
||||
# Already processing, reschedule
|
||||
self._reset_timer()
|
||||
# Preserve immediate flush semantics even if another worker is active.
|
||||
self._schedule_timer(0)
|
||||
return
|
||||
|
||||
if not self._queue:
|
||||
@ -164,6 +204,13 @@ class MemoryUpdateQueue:
|
||||
|
||||
self._process_queue()
|
||||
|
||||
def flush_nowait(self) -> None:
    """Start queue processing immediately in a background thread.

    Schedules a zero-delay timer rather than processing inline, so the
    caller returns without waiting for the memory update to complete.
    """
    with self._lock:
        # Daemon thread: queued messages may be lost if the process exits
        # before _process_queue completes. Acceptable for best-effort memory updates.
        self._schedule_timer(0)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the queue without processing.
|
||||
|
||||
|
||||
@ -0,0 +1,31 @@
|
||||
"""Hooks fired before summarization removes messages from state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
|
||||
def memory_flush_hook(event: SummarizationEvent) -> None:
    """Flush messages about to be summarized into the memory queue.

    Runs as a before-summarization hook so the long-term memory subsystem
    can still see messages before they are collapsed into a summary.
    No-op when memory is disabled or the thread cannot be identified.
    """
    if not get_memory_config().enabled or not event.thread_id:
        return

    # Keep only user inputs and final assistant replies; drop tool chatter.
    filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize))
    user_messages = [message for message in filtered_messages if getattr(message, "type", None) == "human"]
    assistant_messages = [message for message in filtered_messages if getattr(message, "type", None) == "ai"]
    # A meaningful exchange needs at least one turn from each side.
    if not user_messages or not assistant_messages:
        return

    correction_detected = detect_correction(filtered_messages)
    # Corrections take precedence; only look for reinforcement when none found.
    reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
    queue = get_memory_queue()
    queue.add_nowait(
        thread_id=event.thread_id,
        messages=filtered_messages,
        agent_name=event.agent_name,
        correction_detected=correction_detected,
        reinforcement_detected=reinforcement_detected,
    )
|
||||
@ -1,50 +1,19 @@
|
||||
"""Middleware for memory mechanism."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, override
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
_CORRECTION_PATTERNS = (
|
||||
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
||||
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
||||
re.compile(r"\btry again\b", re.IGNORECASE),
|
||||
re.compile(r"\bredo\b", re.IGNORECASE),
|
||||
re.compile(r"不对"),
|
||||
re.compile(r"你理解错了"),
|
||||
re.compile(r"你理解有误"),
|
||||
re.compile(r"重试"),
|
||||
re.compile(r"重新来"),
|
||||
re.compile(r"换一种"),
|
||||
re.compile(r"改用"),
|
||||
)
|
||||
|
||||
_REINFORCEMENT_PATTERNS = (
|
||||
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
||||
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
||||
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
||||
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
||||
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
||||
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
||||
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
||||
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
||||
)
|
||||
|
||||
|
||||
class MemoryMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
@ -52,125 +21,6 @@ class MemoryMiddlewareState(AgentState):
|
||||
pass
|
||||
|
||||
|
||||
def _extract_message_text(message: Any) -> str:
|
||||
"""Extract plain text from message content for filtering and signal detection."""
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
text_parts.append(part)
|
||||
elif isinstance(part, dict):
|
||||
text_val = part.get("text")
|
||||
if isinstance(text_val, str):
|
||||
text_parts.append(text_val)
|
||||
return " ".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
"""Filter messages to keep only user inputs and final assistant responses.
|
||||
|
||||
This filters out:
|
||||
- Tool messages (intermediate tool call results)
|
||||
- AI messages with tool_calls (intermediate steps, not final responses)
|
||||
- The <uploaded_files> block injected by UploadsMiddleware into human messages
|
||||
(file paths are session-scoped and must not persist in long-term memory).
|
||||
The user's actual question is preserved; only turns whose content is entirely
|
||||
the upload block (nothing remains after stripping) are dropped along with
|
||||
their paired assistant response.
|
||||
|
||||
Only keeps:
|
||||
- Human messages (with the ephemeral upload block removed)
|
||||
- AI messages without tool_calls (final assistant responses), unless the
|
||||
paired human turn was upload-only and had no real user text.
|
||||
|
||||
Args:
|
||||
messages: List of all conversation messages.
|
||||
|
||||
Returns:
|
||||
Filtered list containing only user inputs and final assistant responses.
|
||||
"""
|
||||
filtered = []
|
||||
skip_next_ai = False
|
||||
for msg in messages:
|
||||
msg_type = getattr(msg, "type", None)
|
||||
|
||||
if msg_type == "human":
|
||||
content_str = _extract_message_text(msg)
|
||||
if "<uploaded_files>" in content_str:
|
||||
# Strip the ephemeral upload block; keep the user's real question.
|
||||
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
||||
if not stripped:
|
||||
# Nothing left — the entire turn was upload bookkeeping;
|
||||
# skip it and the paired assistant response.
|
||||
skip_next_ai = True
|
||||
continue
|
||||
# Rebuild the message with cleaned content so the user's question
|
||||
# is still available for memory summarisation.
|
||||
from copy import copy
|
||||
|
||||
clean_msg = copy(msg)
|
||||
clean_msg.content = stripped
|
||||
filtered.append(clean_msg)
|
||||
skip_next_ai = False
|
||||
else:
|
||||
filtered.append(msg)
|
||||
skip_next_ai = False
|
||||
elif msg_type == "ai":
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
if skip_next_ai:
|
||||
skip_next_ai = False
|
||||
continue
|
||||
filtered.append(msg)
|
||||
# Skip tool messages and AI messages with tool_calls
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
def detect_correction(messages: list[Any]) -> bool:
|
||||
"""Detect explicit user corrections in recent conversation turns.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale corrections from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def detect_reinforcement(messages: list[Any]) -> bool:
|
||||
"""Detect explicit positive reinforcement signals in recent conversation turns.
|
||||
|
||||
Complements detect_correction() by identifying when the user confirms the
|
||||
agent's approach was correct. This allows the memory system to record what
|
||||
worked well, not just what went wrong.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale signals from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
"""Middleware that queues conversation for memory update after agent execution.
|
||||
|
||||
@ -223,7 +73,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
return None
|
||||
|
||||
# Filter to only keep user inputs and final assistant responses
|
||||
filtered_messages = _filter_messages_for_memory(messages)
|
||||
filtered_messages = filter_messages_for_memory(messages)
|
||||
|
||||
# Only queue if there's meaningful conversation
|
||||
# At minimum need one user message and one assistant response
|
||||
|
||||
@ -0,0 +1,151 @@
|
||||
"""Summarization middleware extensions for DeerFlow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import SummarizationMiddleware
|
||||
from langchain_core.messages import AnyMessage, RemoveMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SummarizationEvent:
    """Context emitted before conversation history is summarized away."""

    # Messages that are about to be collapsed into the summary.
    messages_to_summarize: tuple[AnyMessage, ...]
    # Messages that stay in state verbatim after summarization.
    preserved_messages: tuple[AnyMessage, ...]
    # Thread/agent identity, when resolvable from runtime context or config.
    thread_id: str | None
    agent_name: str | None
    # Runtime object the middleware was invoked with.
    runtime: Runtime
|
||||
|
||||
|
||||
@runtime_checkable
class BeforeSummarizationHook(Protocol):
    """Hook invoked before summarization removes messages from state.

    Implementations receive the full event; exceptions raised by a hook are
    caught and logged by the dispatching middleware.
    """

    def __call__(self, event: SummarizationEvent) -> None: ...
|
||||
|
||||
|
||||
def _resolve_thread_id(runtime: Runtime) -> str | None:
    """Resolve the current thread ID from runtime context or LangGraph config."""
    context = runtime.context or {}
    thread_id = context.get("thread_id")
    if thread_id is not None:
        return thread_id

    # Fall back to the LangGraph config; get_config raises RuntimeError
    # when called outside of a graph run.
    try:
        config_data = get_config()
    except RuntimeError:
        return None
    return config_data.get("configurable", {}).get("thread_id")
|
||||
|
||||
|
||||
def _resolve_agent_name(runtime: Runtime) -> str | None:
    """Resolve the current agent name from runtime context or LangGraph config."""
    context = runtime.context or {}
    agent_name = context.get("agent_name")
    if agent_name is not None:
        return agent_name

    # Fall back to the LangGraph config; get_config raises RuntimeError
    # when called outside of a graph run.
    try:
        config_data = get_config()
    except RuntimeError:
        return None
    return config_data.get("configurable", {}).get("agent_name")
|
||||
|
||||
|
||||
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
    """Summarization middleware with pre-compression hook dispatch.

    Before old messages are collapsed into a summary, every registered
    BeforeSummarizationHook is invoked with the messages about to be removed,
    so other subsystems (e.g. long-term memory) can capture them first.
    Hook failures are logged and swallowed; summarization always proceeds.
    """

    def __init__(
        self,
        *args,
        before_summarization: list[BeforeSummarizationHook] | None = None,
        **kwargs,
    ) -> None:
        """Accept the parent's arguments plus an optional list of hooks."""
        super().__init__(*args, **kwargs)
        self._before_summarization_hooks = before_summarization or []

    def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._maybe_summarize(state, runtime)

    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return await self._amaybe_summarize(state, runtime)

    def _prepare_summarization(
        self, state: AgentState, runtime: Runtime
    ) -> tuple[list[AnyMessage], list[AnyMessage]] | None:
        """Shared sync/async preamble: decide whether to summarize and fire hooks.

        Returns (messages_to_summarize, preserved_messages), or None when no
        summarization should happen on this step. Factored out so the sync and
        async paths cannot drift apart.
        """
        messages = state["messages"]
        self._ensure_message_ids(messages)

        total_tokens = self.token_counter(messages)
        if not self._should_summarize(messages, total_tokens):
            return None

        cutoff_index = self._determine_cutoff_index(messages)
        if cutoff_index <= 0:
            return None

        messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
        # Hooks must run before the summary replaces these messages in state.
        self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
        return messages_to_summarize, preserved_messages

    @staticmethod
    def _build_state_update(
        new_messages: list[AnyMessage], preserved_messages: list[AnyMessage]
    ) -> dict:
        """Build the state delta that swaps old history for summary + tail."""
        return {
            "messages": [
                RemoveMessage(id=REMOVE_ALL_MESSAGES),
                *new_messages,
                *preserved_messages,
            ]
        }

    def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Synchronous summarization path."""
        partition = self._prepare_summarization(state, runtime)
        if partition is None:
            return None
        messages_to_summarize, preserved_messages = partition
        summary = self._create_summary(messages_to_summarize)
        return self._build_state_update(self._build_new_messages(summary), preserved_messages)

    async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Asynchronous summarization path; differs only in summary creation."""
        partition = self._prepare_summarization(state, runtime)
        if partition is None:
            return None
        messages_to_summarize, preserved_messages = partition
        summary = await self._acreate_summary(messages_to_summarize)
        return self._build_state_update(self._build_new_messages(summary), preserved_messages)

    def _fire_hooks(
        self,
        messages_to_summarize: list[AnyMessage],
        preserved_messages: list[AnyMessage],
        runtime: Runtime,
    ) -> None:
        """Invoke all registered hooks; log and continue on any hook failure."""
        if not self._before_summarization_hooks:
            return

        event = SummarizationEvent(
            messages_to_summarize=tuple(messages_to_summarize),
            preserved_messages=tuple(preserved_messages),
            thread_id=_resolve_thread_id(runtime),
            agent_name=_resolve_agent_name(runtime),
            runtime=runtime,
        )

        for hook in self._before_summarization_hooks:
            try:
                hook(event)
            except Exception:
                # A misbehaving hook must never block summarization itself.
                hook_name = getattr(hook, "__name__", None) or type(hook).__name__
                logger.exception("before_summarization hook %s failed", hook_name)
||||
@ -8,6 +8,7 @@ import pytest
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.summarization_config import SummarizationConfig
|
||||
@ -145,6 +146,7 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
||||
"get_summarization_config",
|
||||
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
|
||||
)
|
||||
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
fake_model = object()
|
||||
@ -156,10 +158,32 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
||||
return fake_model
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||
monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
|
||||
|
||||
middleware = lead_agent_module._create_summarization_middleware()
|
||||
|
||||
assert captured["name"] == "model-masswork"
|
||||
assert captured["thinking_enabled"] is False
|
||||
assert middleware["model"] is fake_model
|
||||
|
||||
|
||||
def test_create_summarization_middleware_registers_memory_flush_hook_when_memory_enabled(monkeypatch):
    """When memory is enabled, the flush hook must be wired into the middleware."""
    # Enable both summarization and memory so the hook-registration path runs.
    monkeypatch.setattr(
        lead_agent_module,
        "get_summarization_config",
        lambda: SummarizationConfig(enabled=True),
    )
    monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=True))
    monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object())

    captured: dict[str, object] = {}

    def _fake_middleware(**kwargs):
        # Capture constructor kwargs instead of building a real middleware.
        captured.update(kwargs)
        return kwargs

    monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _fake_middleware)

    lead_agent_module._create_summarization_middleware()

    assert captured["before_summarization"] == [lead_agent_module.memory_flush_hook]
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
@ -89,3 +91,74 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
||||
correction_detected=False,
|
||||
reinforcement_detected=True,
|
||||
)
|
||||
|
||||
|
||||
def test_flush_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
    """flush_nowait must replace any pending debounce timer with a zero-delay one."""
    queue = MemoryUpdateQueue()
    existing_timer = MagicMock()
    queue._timer = existing_timer
    created_timer = MagicMock()

    with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls:
        queue.flush_nowait()

    # Old timer cancelled; new timer is immediate (delay 0), daemonized, started.
    existing_timer.cancel.assert_called_once_with()
    timer_cls.assert_called_once_with(0, queue._process_queue)
    assert created_timer.daemon is True
    created_timer.start.assert_called_once_with()
    assert queue._timer is created_timer
|
||||
|
||||
|
||||
def test_add_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
    """add_nowait must enqueue the context and schedule an immediate flush."""
    queue = MemoryUpdateQueue()
    existing_timer = MagicMock()
    queue._timer = existing_timer
    created_timer = MagicMock()

    with (
        patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
        patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls,
    ):
        queue.add_nowait(thread_id="thread-1", messages=["conversation"], agent_name="lead-agent")

    # Pending timer replaced by a zero-delay one; the context landed in the queue.
    existing_timer.cancel.assert_called_once_with()
    timer_cls.assert_called_once_with(0, queue._process_queue)
    assert queue.pending_count == 1
    assert queue._queue[0].agent_name == "lead-agent"
    assert created_timer.daemon is True
    created_timer.start.assert_called_once_with()
|
||||
|
||||
|
||||
def test_process_queue_reschedules_immediately_when_already_processing() -> None:
    """A flush arriving during processing must be rescheduled with zero delay, not debounced."""
    queue = MemoryUpdateQueue()
    # Simulate a concurrent worker already draining the queue.
    queue._processing = True
    created_timer = MagicMock()

    with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls:
        queue._process_queue()

    timer_cls.assert_called_once_with(0, queue._process_queue)
    assert created_timer.daemon is True
    created_timer.start.assert_called_once_with()
|
||||
|
||||
|
||||
def test_flush_nowait_is_non_blocking() -> None:
    """flush_nowait must return immediately while the drain runs on a timer thread."""
    update_queue = MemoryUpdateQueue()
    drain_started = threading.Event()
    drain_finished = threading.Event()

    def _slow_drain() -> None:
        # Simulate a drain that takes noticeably longer than the caller should wait.
        drain_started.set()
        time.sleep(0.2)
        drain_finished.set()

    update_queue._process_queue = _slow_drain

    begin = time.perf_counter()
    update_queue.flush_nowait()
    elapsed = time.perf_counter() - begin

    # The drain really started on a background thread…
    assert drain_started.wait(0.1) is True
    # …while the caller returned well before the 0.2 s sleep completed.
    assert elapsed < 0.1
    assert drain_finished.is_set() is False
    assert drain_finished.wait(1.0) is True
|
||||
|
||||
@ -3,14 +3,14 @@
|
||||
Covers two functions introduced to prevent ephemeral file-upload context from
|
||||
persisting in long-term memory:
|
||||
|
||||
- filter_messages_for_memory (message_processing)
|
||||
- _strip_upload_mentions_from_memory (updater)
|
||||
"""
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
||||
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
|
||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@ -31,7 +31,7 @@ def _ai(text: str, tool_calls=None) -> AIMessage:
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _filter_messages_for_memory
|
||||
# filter_messages_for_memory
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ class TestFilterMessagesForMemory:
|
||||
_human(_UPLOAD_BLOCK),
|
||||
_ai("I have read the file. It says: Hello."),
|
||||
]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
assert result == []
|
||||
|
||||
def test_upload_with_real_question_preserves_question(self):
|
||||
@ -56,7 +56,7 @@ class TestFilterMessagesForMemory:
|
||||
_human(combined),
|
||||
_ai("The file contains: Hello DeerFlow."),
|
||||
]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
|
||||
assert len(result) == 2
|
||||
human_result = result[0]
|
||||
@ -71,7 +71,7 @@ class TestFilterMessagesForMemory:
|
||||
_human("What is the capital of France?"),
|
||||
_ai("The capital of France is Paris."),
|
||||
]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
assert len(result) == 2
|
||||
assert result[0].content == "What is the capital of France?"
|
||||
assert result[1].content == "The capital of France is Paris."
|
||||
@ -84,7 +84,7 @@ class TestFilterMessagesForMemory:
|
||||
ToolMessage(content="Search results", tool_call_id="1"),
|
||||
_ai("Here are the results."),
|
||||
]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
human_msgs = [m for m in result if m.type == "human"]
|
||||
ai_msgs = [m for m in result if m.type == "ai"]
|
||||
assert len(human_msgs) == 1
|
||||
@ -101,7 +101,7 @@ class TestFilterMessagesForMemory:
|
||||
_human("What is 2 + 2?"),
|
||||
_ai("4"),
|
||||
]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
human_contents = [m.content for m in result if m.type == "human"]
|
||||
ai_contents = [m.content for m in result if m.type == "ai"]
|
||||
|
||||
@ -121,14 +121,14 @@ class TestFilterMessagesForMemory:
|
||||
]
|
||||
)
|
||||
msgs = [msg, _ai("Done.")]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
assert result == []
|
||||
|
||||
def test_file_path_not_in_filtered_content(self):
|
||||
"""After filtering, no upload file path should appear in any message."""
|
||||
combined = _UPLOAD_BLOCK + "\n\nSummarise the file please."
|
||||
msgs = [_human(combined), _ai("It says hello.")]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
all_content = " ".join(m.content for m in result if isinstance(m.content, str))
|
||||
assert "/mnt/user-data/uploads/" not in all_content
|
||||
assert "<uploaded_files>" not in all_content
|
||||
|
||||
186
backend/tests/test_summarization_middleware.py
Normal file
186
backend/tests/test_summarization_middleware.py
Normal file
@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
|
||||
|
||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
|
||||
def _messages() -> list:
    """Return a fixed two-round user/assistant conversation for the tests."""
    transcript: list = []
    for turn in (1, 2):
        transcript.append(HumanMessage(content=f"user-{turn}"))
        transcript.append(AIMessage(content=f"assistant-{turn}"))
    return transcript
|
||||
|
||||
|
||||
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
|
||||
context = {}
|
||||
if thread_id is not None:
|
||||
context["thread_id"] = thread_id
|
||||
if agent_name is not None:
|
||||
context["agent_name"] = agent_name
|
||||
return SimpleNamespace(context=context)
|
||||
|
||||
|
||||
def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("messages", 2)) -> DeerFlowSummarizationMiddleware:
    """Create a middleware whose summarizer model always emits a fixed summary.

    ``token_counter=len`` makes the trigger count messages rather than tokens,
    keeping the threshold arithmetic trivial for the tests.
    """
    summarizer = MagicMock()
    summarizer.invoke.return_value = SimpleNamespace(text="compressed summary")
    return DeerFlowSummarizationMiddleware(
        model=summarizer,
        trigger=trigger,
        keep=keep,
        token_counter=len,
        before_summarization=before_summarization,
    )
|
||||
|
||||
|
||||
def test_before_summarization_hook_receives_messages_before_compression() -> None:
    """Hooks observe the exact split between summarized and preserved turns."""
    seen_events: list[SummarizationEvent] = []
    mw = _middleware(before_summarization=[seen_events.append])

    result = mw.before_model({"messages": _messages()}, _runtime())

    assert len(seen_events) == 1
    event = seen_events[0]
    # trigger=4 / keep=2: the first two turns are compressed, the last two kept.
    assert [message.content for message in event.messages_to_summarize] == ["user-1", "assistant-1"]
    assert [message.content for message in event.preserved_messages] == ["user-2", "assistant-2"]
    assert event.thread_id == "thread-1"
    assert event.agent_name is None
    # Compression still happened: old messages removed, summary inserted.
    assert isinstance(result["messages"][0], RemoveMessage)
    assert result["messages"][1].content.startswith("Here is a summary")
||||
|
||||
|
||||
def test_before_summarization_hook_not_called_when_threshold_not_met() -> None:
    """Below the trigger threshold neither the hooks nor compression run."""
    seen_events: list[SummarizationEvent] = []
    mw = _middleware(before_summarization=[seen_events.append], trigger=("messages", 10))

    result = mw.before_model({"messages": _messages()}, _runtime())

    assert seen_events == []
    assert result is None
|
||||
|
||||
|
||||
def test_before_summarization_hook_exception_does_not_block_compression(caplog: pytest.LogCaptureFixture) -> None:
    """A failing hook is logged at ERROR but never prevents summarization."""

    def _broken_hook(_: SummarizationEvent) -> None:
        raise RuntimeError("hook failure")

    mw = _middleware(before_summarization=[_broken_hook])

    with caplog.at_level("ERROR"):
        result = mw.before_model({"messages": _messages()}, _runtime())

    # The failure is surfaced in the log by hook name…
    assert "before_summarization hook _broken_hook failed" in caplog.text
    # …and compression still produced its RemoveMessage marker.
    assert isinstance(result["messages"][0], RemoveMessage)
|
||||
|
||||
|
||||
def test_multiple_before_summarization_hooks_run_in_registration_order() -> None:
    """Registered hooks fire strictly in registration order."""
    call_order: list[str] = []

    def _recorder(label: str):
        # Bind the label via closure so each hook appends its own name.
        def _hook(_: SummarizationEvent) -> None:
            call_order.append(label)

        return _hook

    mw = _middleware(before_summarization=[_recorder("first"), _recorder("second"), _recorder("third")])

    mw.before_model({"messages": _messages()}, _runtime())

    assert call_order == ["first", "second", "third"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_abefore_model_calls_hooks_same_as_sync() -> None:
    """The async entry point invokes hooks exactly like the sync path."""
    seen_events: list[SummarizationEvent] = []
    mw = _middleware(before_summarization=[seen_events.append])

    await mw.abefore_model({"messages": _messages()}, _runtime())

    assert len(seen_events) == 1
    assert [message.content for message in seen_events[0].messages_to_summarize] == ["user-1", "assistant-1"]
|
||||
|
||||
|
||||
def test_memory_flush_hook_skips_when_memory_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """When long-term memory is disabled the hook must not enqueue anything."""
    update_queue = MagicMock()
    monkeypatch.setattr(
        "deerflow.agents.memory.summarization_hook.get_memory_config",
        lambda: MemoryConfig(enabled=False),
    )
    monkeypatch.setattr(
        "deerflow.agents.memory.summarization_hook.get_memory_queue",
        lambda: update_queue,
    )

    event = SummarizationEvent(
        messages_to_summarize=tuple(_messages()[:2]),
        preserved_messages=(),
        thread_id="thread-1",
        agent_name=None,
        runtime=_runtime(),
    )
    memory_flush_hook(event)

    update_queue.add_nowait.assert_not_called()
|
||||
|
||||
|
||||
def test_memory_flush_hook_skips_when_thread_id_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """Without a thread id there is nothing to scope the memory update to."""
    update_queue = MagicMock()
    monkeypatch.setattr(
        "deerflow.agents.memory.summarization_hook.get_memory_config",
        lambda: MemoryConfig(enabled=True),
    )
    monkeypatch.setattr(
        "deerflow.agents.memory.summarization_hook.get_memory_queue",
        lambda: update_queue,
    )

    event = SummarizationEvent(
        messages_to_summarize=tuple(_messages()[:2]),
        preserved_messages=(),
        thread_id=None,
        agent_name=None,
        runtime=_runtime(None),
    )
    memory_flush_hook(event)

    update_queue.add_nowait.assert_not_called()
|
||||
|
||||
|
||||
def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: pytest.MonkeyPatch) -> None:
    """Tool-calling turns are filtered out before the flush is enqueued."""
    update_queue = MagicMock()
    conversation = [
        HumanMessage(content="Question"),
        AIMessage(content="Calling tool", tool_calls=[{"name": "search", "id": "tool-1", "args": {}}]),
        AIMessage(content="Final answer"),
    ]
    monkeypatch.setattr(
        "deerflow.agents.memory.summarization_hook.get_memory_config",
        lambda: MemoryConfig(enabled=True),
    )
    monkeypatch.setattr(
        "deerflow.agents.memory.summarization_hook.get_memory_queue",
        lambda: update_queue,
    )

    memory_flush_hook(
        SummarizationEvent(
            messages_to_summarize=tuple(conversation),
            preserved_messages=(),
            thread_id="thread-1",
            agent_name=None,
            runtime=_runtime(),
        )
    )

    update_queue.add_nowait.assert_called_once()
    enqueued = update_queue.add_nowait.call_args.kwargs
    assert enqueued["thread_id"] == "thread-1"
    # The AI turn that only calls a tool is dropped from the memory payload.
    assert [message.content for message in enqueued["messages"]] == ["Question", "Final answer"]
    assert enqueued["correction_detected"] is False
    assert enqueued["reinforcement_detected"] is False
|
||||
|
||||
|
||||
def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None:
    """An agent-scoped flush keeps the agent name on the queued update."""
    update_queue = MagicMock()
    monkeypatch.setattr(
        "deerflow.agents.memory.summarization_hook.get_memory_config",
        lambda: MemoryConfig(enabled=True),
    )
    monkeypatch.setattr(
        "deerflow.agents.memory.summarization_hook.get_memory_queue",
        lambda: update_queue,
    )

    memory_flush_hook(
        SummarizationEvent(
            messages_to_summarize=tuple(_messages()[:2]),
            preserved_messages=(),
            thread_id="thread-1",
            agent_name="research-agent",
            runtime=_runtime(agent_name="research-agent"),
        )
    )

    update_queue.add_nowait.assert_called_once()
    assert update_queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
|
||||
Loading…
x
Reference in New Issue
Block a user