mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
* feat(memory): add structured reflection and correction detection * fix(memory): align sourceError schema and prompt guidance --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
parent
3e461d9d08
commit
0cdecf7b30
@ -49,6 +49,7 @@ class Fact(BaseModel):
|
||||
confidence: float = Field(default=0.5, description="Confidence score (0-1)")
|
||||
createdAt: str = Field(default="", description="Creation timestamp")
|
||||
source: str = Field(default="unknown", description="Source thread ID")
|
||||
sourceError: str | None = Field(default=None, description="Optional description of the prior mistake or wrong approach")
|
||||
|
||||
|
||||
class MemoryResponse(BaseModel):
|
||||
@ -108,6 +109,7 @@ class MemoryStatusResponse(BaseModel):
|
||||
@router.get(
|
||||
"/memory",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Get Memory Data",
|
||||
description="Retrieve the current global memory data including user context, history, and facts.",
|
||||
)
|
||||
@ -152,6 +154,7 @@ async def get_memory() -> MemoryResponse:
|
||||
@router.post(
|
||||
"/memory/reload",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Reload Memory Data",
|
||||
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
||||
)
|
||||
@ -171,6 +174,7 @@ async def reload_memory() -> MemoryResponse:
|
||||
@router.delete(
|
||||
"/memory",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Clear All Memory Data",
|
||||
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
||||
)
|
||||
@ -187,6 +191,7 @@ async def clear_memory() -> MemoryResponse:
|
||||
@router.post(
|
||||
"/memory/facts",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Create Memory Fact",
|
||||
description="Create a single saved memory fact manually.",
|
||||
)
|
||||
@ -209,6 +214,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
||||
@router.delete(
|
||||
"/memory/facts/{fact_id}",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Delete Memory Fact",
|
||||
description="Delete a single saved memory fact by its fact id.",
|
||||
)
|
||||
@ -227,6 +233,7 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
||||
@router.patch(
|
||||
"/memory/facts/{fact_id}",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Patch Memory Fact",
|
||||
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
||||
)
|
||||
@ -252,6 +259,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
||||
@router.get(
|
||||
"/memory/export",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Export Memory Data",
|
||||
description="Export the current global memory data as JSON for backup or transfer.",
|
||||
)
|
||||
@ -264,6 +272,7 @@ async def export_memory() -> MemoryResponse:
|
||||
@router.post(
|
||||
"/memory/import",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Import Memory Data",
|
||||
description="Import and overwrite the current global memory data from a JSON payload.",
|
||||
)
|
||||
@ -317,6 +326,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
@router.get(
|
||||
"/memory/status",
|
||||
response_model=MemoryStatusResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Get Memory Status",
|
||||
description="Retrieve both memory configuration and current data in a single request.",
|
||||
)
|
||||
|
||||
@ -29,6 +29,17 @@ Instructions:
|
||||
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
|
||||
3. Update the memory sections as needed following the detailed length guidelines below
|
||||
|
||||
Before extracting facts, perform a structured reflection on the conversation:
|
||||
1. Error/Retry Detection: Did the agent encounter errors, require retries, or produce incorrect results?
|
||||
If yes, record the root cause and correct approach as a high-confidence fact with category "correction".
|
||||
2. User Correction Detection: Did the user correct the agent's direction, understanding, or output?
|
||||
If yes, record the correct interpretation or approach as a high-confidence fact with category "correction".
|
||||
Include what went wrong in "sourceError" only when category is "correction" and the mistake is explicit in the conversation.
|
||||
3. Project Constraint Discovery: Were any project-specific constraints discovered during the conversation?
|
||||
If yes, record them as facts with the most appropriate category and confidence.
|
||||
|
||||
{correction_hint}
|
||||
|
||||
Memory Section Guidelines:
|
||||
|
||||
**User Context** (Current state - concise summaries):
|
||||
@ -62,6 +73,7 @@ Memory Section Guidelines:
|
||||
* context: Background facts (job title, projects, locations, languages)
|
||||
* behavior: Working patterns, communication habits, problem-solving approaches
|
||||
* goal: Stated objectives, learning targets, project ambitions
|
||||
* correction: Explicit agent mistakes or user corrections, including the correct approach
|
||||
- Confidence levels:
|
||||
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
|
||||
* 0.7-0.8: Strongly implied from actions/discussions
|
||||
@ -94,7 +106,7 @@ Output Format (JSON):
|
||||
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
|
||||
}},
|
||||
"newFacts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||
],
|
||||
"factsToRemove": ["fact_id_1", "fact_id_2"]
|
||||
}}
|
||||
@ -104,6 +116,8 @@ Important Rules:
|
||||
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
|
||||
- Include specific metrics, version numbers, and proper nouns in facts
|
||||
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
|
||||
- Use category "correction" for explicit agent mistakes or user corrections; assign confidence >= 0.95 when the correction is explicit
|
||||
- Include "sourceError" only for explicit correction facts when the prior mistake or wrong approach is clearly stated; omit it otherwise
|
||||
- Remove facts that are contradicted by new information
|
||||
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
|
||||
Keep 3-5 concurrent focus themes that are still active and relevant
|
||||
@ -126,7 +140,7 @@ Message:
|
||||
Extract facts in this JSON format:
|
||||
{{
|
||||
"facts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||
]
|
||||
}}
|
||||
|
||||
@ -136,6 +150,7 @@ Categories:
|
||||
- context: Background context (location, job, projects)
|
||||
- behavior: Behavioral patterns
|
||||
- goal: User's goals or objectives
|
||||
- correction: Explicit corrections or mistakes to avoid repeating
|
||||
|
||||
Rules:
|
||||
- Only extract clear, specific facts
|
||||
@ -262,7 +277,11 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
continue
|
||||
category = str(fact.get("category", "context")).strip() or "context"
|
||||
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
|
||||
line = f"- [{category} | {confidence:.2f}] {content}"
|
||||
source_error = fact.get("sourceError")
|
||||
if category == "correction" and isinstance(source_error, str) and source_error.strip():
|
||||
line = f"- [{category} | {confidence:.2f}] {content} (avoid: {source_error.strip()})"
|
||||
else:
|
||||
line = f"- [{category} | {confidence:.2f}] {content}"
|
||||
|
||||
# Each additional line is preceded by a newline (except the first).
|
||||
line_text = ("\n" + line) if fact_lines else line
|
||||
|
||||
@ -20,6 +20,7 @@ class ConversationContext:
|
||||
messages: list[Any]
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
agent_name: str | None = None
|
||||
correction_detected: bool = False
|
||||
|
||||
|
||||
class MemoryUpdateQueue:
|
||||
@ -37,25 +38,38 @@ class MemoryUpdateQueue:
|
||||
self._timer: threading.Timer | None = None
|
||||
self._processing = False
|
||||
|
||||
def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None:
|
||||
def add(
|
||||
self,
|
||||
thread_id: str,
|
||||
messages: list[Any],
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
) -> None:
|
||||
"""Add a conversation to the update queue.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
return
|
||||
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
existing_context = next(
|
||||
(context for context in self._queue if context.thread_id == thread_id),
|
||||
None,
|
||||
)
|
||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
correction_detected=merged_correction_detected,
|
||||
)
|
||||
|
||||
# Check if this thread already has a pending update
|
||||
# If so, replace it with the newer one
|
||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||
@ -115,6 +129,7 @@ class MemoryUpdateQueue:
|
||||
messages=context.messages,
|
||||
thread_id=context.thread_id,
|
||||
agent_name=context.agent_name,
|
||||
correction_detected=context.correction_detected,
|
||||
)
|
||||
if success:
|
||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||
|
||||
@ -266,13 +266,20 @@ class MemoryUpdater:
|
||||
model_name = self._model_name or config.model_name
|
||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||
|
||||
def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
def update_memory(
|
||||
self,
|
||||
messages: list[Any],
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID for tracking source.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
@ -295,9 +302,19 @@ class MemoryUpdater:
|
||||
return False
|
||||
|
||||
# Build prompt
|
||||
correction_hint = ""
|
||||
if correction_detected:
|
||||
correction_hint = (
|
||||
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
||||
"Pay special attention to what the agent got wrong, what the user corrected, "
|
||||
"and record the correct approach as a fact with category "
|
||||
'"correction" and confidence >= 0.95 when appropriate.'
|
||||
)
|
||||
|
||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||
current_memory=json.dumps(current_memory, indent=2),
|
||||
conversation=conversation_text,
|
||||
correction_hint=correction_hint,
|
||||
)
|
||||
|
||||
# Call LLM
|
||||
@ -383,6 +400,8 @@ class MemoryUpdater:
|
||||
confidence = fact.get("confidence", 0.5)
|
||||
if confidence >= config.fact_confidence_threshold:
|
||||
raw_content = fact.get("content", "")
|
||||
if not isinstance(raw_content, str):
|
||||
continue
|
||||
normalized_content = raw_content.strip()
|
||||
fact_key = _fact_content_key(normalized_content)
|
||||
if fact_key is not None and fact_key in existing_fact_keys:
|
||||
@ -396,6 +415,11 @@ class MemoryUpdater:
|
||||
"createdAt": now,
|
||||
"source": thread_id or "unknown",
|
||||
}
|
||||
source_error = fact.get("sourceError")
|
||||
if isinstance(source_error, str):
|
||||
normalized_source_error = source_error.strip()
|
||||
if normalized_source_error:
|
||||
fact_entry["sourceError"] = normalized_source_error
|
||||
current_memory["facts"].append(fact_entry)
|
||||
if fact_key is not None:
|
||||
existing_fact_keys.add(fact_key)
|
||||
@ -412,16 +436,22 @@ class MemoryUpdater:
|
||||
return current_memory
|
||||
|
||||
|
||||
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
def update_memory_from_conversation(
|
||||
messages: list[Any],
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name)
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
|
||||
|
||||
@ -14,6 +14,21 @@ from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
_CORRECTION_PATTERNS = (
|
||||
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
||||
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
||||
re.compile(r"\btry again\b", re.IGNORECASE),
|
||||
re.compile(r"\bredo\b", re.IGNORECASE),
|
||||
re.compile(r"不对"),
|
||||
re.compile(r"你理解错了"),
|
||||
re.compile(r"你理解有误"),
|
||||
re.compile(r"重试"),
|
||||
re.compile(r"重新来"),
|
||||
re.compile(r"换一种"),
|
||||
re.compile(r"改用"),
|
||||
)
|
||||
|
||||
|
||||
class MemoryMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
@ -21,6 +36,22 @@ class MemoryMiddlewareState(AgentState):
|
||||
pass
|
||||
|
||||
|
||||
def _extract_message_text(message: Any) -> str:
|
||||
"""Extract plain text from message content for filtering and signal detection."""
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
text_parts.append(part)
|
||||
elif isinstance(part, dict):
|
||||
text_val = part.get("text")
|
||||
if isinstance(text_val, str):
|
||||
text_parts.append(text_val)
|
||||
return " ".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
"""Filter messages to keep only user inputs and final assistant responses.
|
||||
|
||||
@ -44,18 +75,13 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
Returns:
|
||||
Filtered list containing only user inputs and final assistant responses.
|
||||
"""
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
|
||||
filtered = []
|
||||
skip_next_ai = False
|
||||
for msg in messages:
|
||||
msg_type = getattr(msg, "type", None)
|
||||
|
||||
if msg_type == "human":
|
||||
content = getattr(msg, "content", "")
|
||||
if isinstance(content, list):
|
||||
content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
|
||||
content_str = str(content)
|
||||
content_str = _extract_message_text(msg)
|
||||
if "<uploaded_files>" in content_str:
|
||||
# Strip the ephemeral upload block; keep the user's real question.
|
||||
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
||||
@ -87,6 +113,25 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
return filtered
|
||||
|
||||
|
||||
def detect_correction(messages: list[Any]) -> bool:
|
||||
"""Detect explicit user corrections in recent conversation turns.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale corrections from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
"""Middleware that queues conversation for memory update after agent execution.
|
||||
|
||||
@ -150,7 +195,13 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
return None
|
||||
|
||||
# Queue the filtered conversation for memory update
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
queue = get_memory_queue()
|
||||
queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name)
|
||||
queue.add(
|
||||
thread_id=thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=self._agent_name,
|
||||
correction_detected=correction_detected,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@ -119,3 +119,38 @@ def test_format_memory_skips_non_string_content_facts() -> None:
|
||||
# The formatted line for a list content would be "- [knowledge | 0.85] ['list']".
|
||||
assert "| 0.85]" not in result
|
||||
assert "Valid fact" in result
|
||||
|
||||
|
||||
def test_format_memory_renders_correction_source_error() -> None:
|
||||
memory_data = {
|
||||
"facts": [
|
||||
{
|
||||
"content": "Use make dev for local development.",
|
||||
"category": "correction",
|
||||
"confidence": 0.95,
|
||||
"sourceError": "The agent previously suggested npm start.",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
result = format_memory_for_injection(memory_data, max_tokens=2000)
|
||||
|
||||
assert "Use make dev for local development." in result
|
||||
assert "avoid: The agent previously suggested npm start." in result
|
||||
|
||||
|
||||
def test_format_memory_renders_correction_without_source_error_normally() -> None:
|
||||
memory_data = {
|
||||
"facts": [
|
||||
{
|
||||
"content": "Use make dev for local development.",
|
||||
"category": "correction",
|
||||
"confidence": 0.95,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
result = format_memory_for_injection(memory_data, max_tokens=2000)
|
||||
|
||||
assert "Use make dev for local development." in result
|
||||
assert "avoid:" not in result
|
||||
|
||||
50
backend/tests/test_memory_queue.py
Normal file
50
backend/tests/test_memory_queue.py
Normal file
@ -0,0 +1,50 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
|
||||
def _memory_config(**overrides: object) -> MemoryConfig:
|
||||
config = MemoryConfig()
|
||||
for key, value in overrides.items():
|
||||
setattr(config, key, value)
|
||||
return config
|
||||
|
||||
|
||||
def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(queue, "_reset_timer"),
|
||||
):
|
||||
queue.add(thread_id="thread-1", messages=["first"], correction_detected=True)
|
||||
queue.add(thread_id="thread-1", messages=["second"], correction_detected=False)
|
||||
|
||||
assert len(queue._queue) == 1
|
||||
assert queue._queue[0].messages == ["second"]
|
||||
assert queue._queue[0].correction_detected is True
|
||||
|
||||
|
||||
def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue._queue = [
|
||||
ConversationContext(
|
||||
thread_id="thread-1",
|
||||
messages=["conversation"],
|
||||
agent_name="lead_agent",
|
||||
correction_detected=True,
|
||||
)
|
||||
]
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.update_memory.return_value = True
|
||||
|
||||
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
|
||||
queue._process_queue()
|
||||
|
||||
mock_updater.update_memory.assert_called_once_with(
|
||||
messages=["conversation"],
|
||||
thread_id="thread-1",
|
||||
agent_name="lead_agent",
|
||||
correction_detected=True,
|
||||
)
|
||||
@ -72,6 +72,56 @@ def test_import_memory_route_returns_imported_memory() -> None:
|
||||
assert response.json()["facts"] == imported_memory["facts"]
|
||||
|
||||
|
||||
def test_export_memory_route_preserves_source_error() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
exported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
"id": "fact_correction",
|
||||
"content": "Use make dev for local development.",
|
||||
"category": "correction",
|
||||
"confidence": 0.95,
|
||||
"createdAt": "2026-03-20T00:00:00Z",
|
||||
"source": "thread-1",
|
||||
"sourceError": "The agent previously suggested npm start.",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
with patch("app.gateway.routers.memory.get_memory_data", return_value=exported_memory):
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/memory/export")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start."
|
||||
|
||||
|
||||
def test_import_memory_route_preserves_source_error() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
imported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
"id": "fact_correction",
|
||||
"content": "Use make dev for local development.",
|
||||
"category": "correction",
|
||||
"confidence": 0.95,
|
||||
"createdAt": "2026-03-20T00:00:00Z",
|
||||
"source": "thread-1",
|
||||
"sourceError": "The agent previously suggested npm start.",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
with patch("app.gateway.routers.memory.import_memory_data", return_value=imported_memory):
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/memory/import", json=imported_memory)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start."
|
||||
|
||||
|
||||
def test_clear_memory_route_returns_cleared_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
|
||||
@ -146,6 +146,53 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
|
||||
assert result["facts"][1]["source"] == "thread-9"
|
||||
|
||||
|
||||
def test_apply_updates_preserves_source_error() -> None:
|
||||
updater = MemoryUpdater()
|
||||
current_memory = _make_memory()
|
||||
update_data = {
|
||||
"newFacts": [
|
||||
{
|
||||
"content": "Use make dev for local development.",
|
||||
"category": "correction",
|
||||
"confidence": 0.95,
|
||||
"sourceError": "The agent previously suggested npm start.",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
|
||||
|
||||
assert result["facts"][0]["sourceError"] == "The agent previously suggested npm start."
|
||||
assert result["facts"][0]["category"] == "correction"
|
||||
|
||||
|
||||
def test_apply_updates_ignores_empty_source_error() -> None:
|
||||
updater = MemoryUpdater()
|
||||
current_memory = _make_memory()
|
||||
update_data = {
|
||||
"newFacts": [
|
||||
{
|
||||
"content": "Use make dev for local development.",
|
||||
"category": "correction",
|
||||
"confidence": 0.95,
|
||||
"sourceError": " ",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
|
||||
|
||||
assert "sourceError" not in result["facts"][0]
|
||||
|
||||
|
||||
def test_clear_memory_data_resets_all_sections() -> None:
|
||||
with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True):
|
||||
result = clear_memory_data()
|
||||
@ -522,3 +569,53 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg])
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_correction_hint_injected_when_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "No, that's wrong."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Understood"
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
|
||||
def test_correction_hint_empty_when_not_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Let's talk about memory."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Sure"
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" not in prompt
|
||||
|
||||
@ -10,7 +10,7 @@ persisting in long-term memory:
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
|
||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory
|
||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@ -134,6 +134,64 @@ class TestFilterMessagesForMemory:
|
||||
assert "<uploaded_files>" not in all_content
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# detect_correction
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestDetectCorrection:
|
||||
def test_detects_english_correction_signal(self):
|
||||
msgs = [
|
||||
_human("Please help me run the project."),
|
||||
_ai("Use npm start."),
|
||||
_human("That's wrong, use make dev instead."),
|
||||
_ai("Understood."),
|
||||
]
|
||||
|
||||
assert detect_correction(msgs) is True
|
||||
|
||||
def test_detects_chinese_correction_signal(self):
|
||||
msgs = [
|
||||
_human("帮我启动项目"),
|
||||
_ai("用 npm start"),
|
||||
_human("不对,改用 make dev"),
|
||||
_ai("明白了"),
|
||||
]
|
||||
|
||||
assert detect_correction(msgs) is True
|
||||
|
||||
def test_returns_false_without_signal(self):
|
||||
msgs = [
|
||||
_human("Please explain the build setup."),
|
||||
_ai("Here is the build setup."),
|
||||
_human("Thanks, that makes sense."),
|
||||
]
|
||||
|
||||
assert detect_correction(msgs) is False
|
||||
|
||||
def test_only_checks_recent_messages(self):
|
||||
msgs = [
|
||||
_human("That is wrong, use make dev instead."),
|
||||
_ai("Noted."),
|
||||
_human("Let's discuss tests."),
|
||||
_ai("Sure."),
|
||||
_human("What about linting?"),
|
||||
_ai("Use ruff."),
|
||||
_human("And formatting?"),
|
||||
_ai("Use make format."),
|
||||
]
|
||||
|
||||
assert detect_correction(msgs) is False
|
||||
|
||||
def test_handles_list_content(self):
|
||||
msgs = [
|
||||
HumanMessage(content=["That is wrong,", {"type": "text", "text": "use make dev instead."}]),
|
||||
_ai("Updated."),
|
||||
]
|
||||
|
||||
assert detect_correction(msgs) is True
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _strip_upload_mentions_from_memory
|
||||
# ===========================================================================
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user