diff --git a/backend/app/gateway/routers/memory.py b/backend/app/gateway/routers/memory.py index 7a074100a..6ee546924 100644 --- a/backend/app/gateway/routers/memory.py +++ b/backend/app/gateway/routers/memory.py @@ -49,6 +49,7 @@ class Fact(BaseModel): confidence: float = Field(default=0.5, description="Confidence score (0-1)") createdAt: str = Field(default="", description="Creation timestamp") source: str = Field(default="unknown", description="Source thread ID") + sourceError: str | None = Field(default=None, description="Optional description of the prior mistake or wrong approach") class MemoryResponse(BaseModel): @@ -108,6 +109,7 @@ class MemoryStatusResponse(BaseModel): @router.get( "/memory", response_model=MemoryResponse, + response_model_exclude_none=True, summary="Get Memory Data", description="Retrieve the current global memory data including user context, history, and facts.", ) @@ -152,6 +154,7 @@ async def get_memory() -> MemoryResponse: @router.post( "/memory/reload", response_model=MemoryResponse, + response_model_exclude_none=True, summary="Reload Memory Data", description="Reload memory data from the storage file, refreshing the in-memory cache.", ) @@ -171,6 +174,7 @@ async def reload_memory() -> MemoryResponse: @router.delete( "/memory", response_model=MemoryResponse, + response_model_exclude_none=True, summary="Clear All Memory Data", description="Delete all saved memory data and reset the memory structure to an empty state.", ) @@ -187,6 +191,7 @@ async def clear_memory() -> MemoryResponse: @router.post( "/memory/facts", response_model=MemoryResponse, + response_model_exclude_none=True, summary="Create Memory Fact", description="Create a single saved memory fact manually.", ) @@ -209,6 +214,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo @router.delete( "/memory/facts/{fact_id}", response_model=MemoryResponse, + response_model_exclude_none=True, summary="Delete Memory Fact", description="Delete a single saved memory fact by its fact id.", ) @@ -227,6 +233,7 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse: @router.patch( "/memory/facts/{fact_id}", response_model=MemoryResponse, + response_model_exclude_none=True, summary="Patch Memory Fact", description="Partially update a single saved memory fact by its fact id while preserving omitted fields.", ) @@ -252,6 +259,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) - @router.get( "/memory/export", response_model=MemoryResponse, + response_model_exclude_none=True, summary="Export Memory Data", description="Export the current global memory data as JSON for backup or transfer.", ) @@ -264,6 +272,7 @@ async def export_memory() -> MemoryResponse: @router.post( "/memory/import", response_model=MemoryResponse, + response_model_exclude_none=True, summary="Import Memory Data", description="Import and overwrite the current global memory data from a JSON payload.", ) @@ -317,6 +326,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse: @router.get( "/memory/status", response_model=MemoryStatusResponse, + response_model_exclude_none=True, summary="Get Memory Status", description="Retrieve both memory configuration and current data in a single request.", ) diff --git a/backend/packages/harness/deerflow/agents/memory/prompt.py b/backend/packages/harness/deerflow/agents/memory/prompt.py index 0d4e86d97..e0c04b77e 100644 --- a/backend/packages/harness/deerflow/agents/memory/prompt.py +++ b/backend/packages/harness/deerflow/agents/memory/prompt.py @@ -29,6 +29,17 @@ Instructions: 2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies) 3. Update the memory sections as needed following the detailed length guidelines below +Before extracting facts, perform a structured reflection on the conversation: +1. Error/Retry Detection: Did the agent encounter errors, require retries, or produce incorrect results? + If yes, record the root cause and correct approach as a high-confidence fact with category "correction". +2. User Correction Detection: Did the user correct the agent's direction, understanding, or output? + If yes, record the correct interpretation or approach as a high-confidence fact with category "correction". + Include what went wrong in "sourceError" only when category is "correction" and the mistake is explicit in the conversation. +3. Project Constraint Discovery: Were any project-specific constraints discovered during the conversation? + If yes, record them as facts with the most appropriate category and confidence. + +{correction_hint} + Memory Section Guidelines: **User Context** (Current state - concise summaries): @@ -62,6 +73,7 @@ Memory Section Guidelines: * context: Background facts (job title, projects, locations, languages) * behavior: Working patterns, communication habits, problem-solving approaches * goal: Stated objectives, learning targets, project ambitions + * correction: Explicit agent mistakes or user corrections, including the correct approach - Confidence levels: * 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y") * 0.7-0.8: Strongly implied from actions/discussions @@ -94,7 +106,7 @@ Output Format (JSON): "longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }} }}, "newFacts": [ - {{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }} + {{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }} ], "factsToRemove": ["fact_id_1", "fact_id_2"] }} @@ -104,6 +116,8 @@ Important Rules: - Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs) - Include specific metrics, version numbers, and proper nouns in facts - Only add facts that are clearly stated (0.9+) or strongly implied (0.7+) +- Use category "correction" for explicit agent mistakes or user corrections; assign confidence >= 0.95 when the correction is explicit +- Include "sourceError" only for explicit correction facts when the prior mistake or wrong approach is clearly stated; omit it otherwise - Remove facts that are contradicted by new information - When updating topOfMind, integrate new focus areas while removing completed/abandoned ones Keep 3-5 concurrent focus themes that are still active and relevant @@ -126,7 +140,7 @@ Message: Extract facts in this JSON format: {{ "facts": [ - {{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }} + {{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }} ] }} @@ -136,6 +150,7 @@ Categories: - context: Background context (location, job, projects) - behavior: Behavioral patterns - goal: User's goals or objectives +- correction: Explicit corrections or mistakes to avoid repeating Rules: - Only extract clear, specific facts @@ -262,7 +277,11 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2 continue category = str(fact.get("category", "context")).strip() or "context" confidence = _coerce_confidence(fact.get("confidence"), default=0.0) - line = f"- [{category} | {confidence:.2f}] {content}" + source_error = fact.get("sourceError") + if category == "correction" and isinstance(source_error, str) and source_error.strip(): + line = f"- [{category} | {confidence:.2f}] {content} (avoid: {source_error.strip()})" + else: + line = f"- [{category} | {confidence:.2f}] {content}" # Each additional line is preceded by a newline (except the first). line_text = ("\n" + line) if fact_lines else line diff --git a/backend/packages/harness/deerflow/agents/memory/queue.py b/backend/packages/harness/deerflow/agents/memory/queue.py index a9a683860..6d777a67e 100644 --- a/backend/packages/harness/deerflow/agents/memory/queue.py +++ b/backend/packages/harness/deerflow/agents/memory/queue.py @@ -20,6 +20,7 @@ class ConversationContext: messages: list[Any] timestamp: datetime = field(default_factory=datetime.utcnow) agent_name: str | None = None + correction_detected: bool = False class MemoryUpdateQueue: @@ -37,25 +38,38 @@ class MemoryUpdateQueue: self._timer: threading.Timer | None = None self._processing = False - def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None: + def add( + self, + thread_id: str, + messages: list[Any], + agent_name: str | None = None, + correction_detected: bool = False, + ) -> None: """Add a conversation to the update queue. Args: thread_id: The thread ID. messages: The conversation messages. agent_name: If provided, memory is stored per-agent. If None, uses global memory. + correction_detected: Whether recent turns include an explicit correction signal. """ config = get_memory_config() if not config.enabled: return - context = ConversationContext( - thread_id=thread_id, - messages=messages, - agent_name=agent_name, - ) - with self._lock: + existing_context = next( + (context for context in self._queue if context.thread_id == thread_id), + None, + ) + merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False) + context = ConversationContext( + thread_id=thread_id, + messages=messages, + agent_name=agent_name, + correction_detected=merged_correction_detected, + ) + # Check if this thread already has a pending update # If so, replace it with the newer one self._queue = [c for c in self._queue if c.thread_id != thread_id] @@ -115,6 +129,7 @@ class MemoryUpdateQueue: messages=context.messages, thread_id=context.thread_id, agent_name=context.agent_name, + correction_detected=context.correction_detected, ) if success: logger.info("Memory updated successfully for thread %s", context.thread_id) diff --git a/backend/packages/harness/deerflow/agents/memory/updater.py b/backend/packages/harness/deerflow/agents/memory/updater.py index e8c8b5898..c59749d7b 100644 --- a/backend/packages/harness/deerflow/agents/memory/updater.py +++ b/backend/packages/harness/deerflow/agents/memory/updater.py @@ -266,13 +266,20 @@ class MemoryUpdater: model_name = self._model_name or config.model_name return create_chat_model(name=model_name, thinking_enabled=False) - def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool: + def update_memory( + self, + messages: list[Any], + thread_id: str | None = None, + agent_name: str | None = None, + correction_detected: bool = False, + ) -> bool: """Update memory based on conversation messages. Args: messages: List of conversation messages. thread_id: Optional thread ID for tracking source. agent_name: If provided, updates per-agent memory. If None, updates global memory. + correction_detected: Whether recent turns include an explicit correction signal. Returns: True if update was successful, False otherwise. @@ -295,9 +302,19 @@ class MemoryUpdater: return False # Build prompt + correction_hint = "" + if correction_detected: + correction_hint = ( + "IMPORTANT: Explicit correction signals were detected in this conversation. " + "Pay special attention to what the agent got wrong, what the user corrected, " + "and record the correct approach as a fact with category " + '"correction" and confidence >= 0.95 when appropriate.' + ) + prompt = MEMORY_UPDATE_PROMPT.format( current_memory=json.dumps(current_memory, indent=2), conversation=conversation_text, + correction_hint=correction_hint, ) # Call LLM @@ -383,6 +400,8 @@ class MemoryUpdater: confidence = fact.get("confidence", 0.5) if confidence >= config.fact_confidence_threshold: raw_content = fact.get("content", "") + if not isinstance(raw_content, str): + continue normalized_content = raw_content.strip() fact_key = _fact_content_key(normalized_content) if fact_key is not None and fact_key in existing_fact_keys: @@ -396,6 +415,11 @@ class MemoryUpdater: "createdAt": now, "source": thread_id or "unknown", } + source_error = fact.get("sourceError") + if isinstance(source_error, str): + normalized_source_error = source_error.strip() + if normalized_source_error: + fact_entry["sourceError"] = normalized_source_error current_memory["facts"].append(fact_entry) if fact_key is not None: existing_fact_keys.add(fact_key) @@ -412,16 +436,22 @@ class MemoryUpdater: return current_memory -def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool: +def update_memory_from_conversation( + messages: list[Any], + thread_id: str | None = None, + agent_name: str | None = None, + correction_detected: bool = False, +) -> bool: """Convenience function to update memory from a conversation. Args: messages: List of conversation messages. thread_id: Optional thread ID. agent_name: If provided, updates per-agent memory. If None, updates global memory. + correction_detected: Whether recent turns include an explicit correction signal. Returns: True if successful, False otherwise. """ updater = MemoryUpdater() - return updater.update_memory(messages, thread_id, agent_name) + return updater.update_memory(messages, thread_id, agent_name, correction_detected) diff --git a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py index 90907b6e9..6215a2957 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py @@ -14,6 +14,21 @@ from deerflow.config.memory_config import get_memory_config logger = logging.getLogger(__name__) +_UPLOAD_BLOCK_RE = re.compile(r"[\s\S]*?\n*", re.IGNORECASE) +_CORRECTION_PATTERNS = ( + re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE), + re.compile(r"\byou misunderstood\b", re.IGNORECASE), + re.compile(r"\btry again\b", re.IGNORECASE), + re.compile(r"\bredo\b", re.IGNORECASE), + re.compile(r"不对"), + re.compile(r"你理解错了"), + re.compile(r"你理解有误"), + re.compile(r"重试"), + re.compile(r"重新来"), + re.compile(r"换一种"), + re.compile(r"改用"), +) + class MemoryMiddlewareState(AgentState): """Compatible with the `ThreadState` schema.""" @@ -21,6 +36,22 @@ class MemoryMiddlewareState(AgentState): pass +def _extract_message_text(message: Any) -> str: + """Extract plain text from message content for filtering and signal detection.""" + content = getattr(message, "content", "") + if isinstance(content, list): + text_parts: list[str] = [] + for part in content: + if isinstance(part, str): + text_parts.append(part) + elif isinstance(part, dict): + text_val = part.get("text") + if isinstance(text_val, str): + text_parts.append(text_val) + return " ".join(text_parts) + return str(content) + + def _filter_messages_for_memory(messages: list[Any]) -> list[Any]: """Filter messages to keep only user inputs and final assistant responses. @@ -44,18 +75,13 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]: Returns: Filtered list containing only user inputs and final assistant responses. """ - _UPLOAD_BLOCK_RE = re.compile(r"[\s\S]*?\n*", re.IGNORECASE) - filtered = [] skip_next_ai = False for msg in messages: msg_type = getattr(msg, "type", None) if msg_type == "human": - content = getattr(msg, "content", "") - if isinstance(content, list): - content = " ".join(p.get("text", "") for p in content if isinstance(p, dict)) - content_str = str(content) + content_str = _extract_message_text(msg) if "" in content_str: # Strip the ephemeral upload block; keep the user's real question. stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip() @@ -87,6 +113,25 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]: return filtered +def detect_correction(messages: list[Any]) -> bool: + """Detect explicit user corrections in recent conversation turns. + + The queue keeps only one pending context per thread, so callers pass the + latest filtered message list. Checking only recent user turns keeps signal + detection conservative while avoiding stale corrections from long histories. + """ + recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"] + + for msg in recent_user_msgs: + content = _extract_message_text(msg).strip() + if not content: + continue + if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS): + return True + + return False + + class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): """Middleware that queues conversation for memory update after agent execution. @@ -150,7 +195,13 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): return None # Queue the filtered conversation for memory update + correction_detected = detect_correction(filtered_messages) queue = get_memory_queue() - queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name) + queue.add( + thread_id=thread_id, + messages=filtered_messages, + agent_name=self._agent_name, + correction_detected=correction_detected, + ) return None diff --git a/backend/tests/test_memory_prompt_injection.py b/backend/tests/test_memory_prompt_injection.py index ab1f0a783..d33b69a92 100644 --- a/backend/tests/test_memory_prompt_injection.py +++ b/backend/tests/test_memory_prompt_injection.py @@ -119,3 +119,38 @@ def test_format_memory_skips_non_string_content_facts() -> None: # The formatted line for a list content would be "- [knowledge | 0.85] ['list']". assert "| 0.85]" not in result assert "Valid fact" in result + + +def test_format_memory_renders_correction_source_error() -> None: + memory_data = { + "facts": [ + { + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + "sourceError": "The agent previously suggested npm start.", + } + ] + } + + result = format_memory_for_injection(memory_data, max_tokens=2000) + + assert "Use make dev for local development." in result + assert "avoid: The agent previously suggested npm start." in result + + +def test_format_memory_renders_correction_without_source_error_normally() -> None: + memory_data = { + "facts": [ + { + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + } + ] + } + + result = format_memory_for_injection(memory_data, max_tokens=2000) + + assert "Use make dev for local development." in result + assert "avoid:" not in result diff --git a/backend/tests/test_memory_queue.py b/backend/tests/test_memory_queue.py new file mode 100644 index 000000000..6ef91a142 --- /dev/null +++ b/backend/tests/test_memory_queue.py @@ -0,0 +1,50 @@ +from unittest.mock import MagicMock, patch + +from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue +from deerflow.config.memory_config import MemoryConfig + + +def _memory_config(**overrides: object) -> MemoryConfig: + config = MemoryConfig() + for key, value in overrides.items(): + setattr(config, key, value) + return config + + +def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add(thread_id="thread-1", messages=["first"], correction_detected=True) + queue.add(thread_id="thread-1", messages=["second"], correction_detected=False) + + assert len(queue._queue) == 1 + assert queue._queue[0].messages == ["second"] + assert queue._queue[0].correction_detected is True + + +def test_process_queue_forwards_correction_flag_to_updater() -> None: + queue = MemoryUpdateQueue() + queue._queue = [ + ConversationContext( + thread_id="thread-1", + messages=["conversation"], + agent_name="lead_agent", + correction_detected=True, + ) + ] + mock_updater = MagicMock() + mock_updater.update_memory.return_value = True + + with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater): + queue._process_queue() + + mock_updater.update_memory.assert_called_once_with( + messages=["conversation"], + thread_id="thread-1", + agent_name="lead_agent", + correction_detected=True, + ) diff --git a/backend/tests/test_memory_router.py b/backend/tests/test_memory_router.py index 39134c61d..23a4f30fe 100644 --- a/backend/tests/test_memory_router.py +++ b/backend/tests/test_memory_router.py @@ -72,6 +72,56 @@ def test_import_memory_route_returns_imported_memory() -> None: assert response.json()["facts"] == imported_memory["facts"] +def test_export_memory_route_preserves_source_error() -> None: + app = FastAPI() + app.include_router(memory.router) + exported_memory = _sample_memory( + facts=[ + { + "id": "fact_correction", + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + "createdAt": "2026-03-20T00:00:00Z", + "source": "thread-1", + "sourceError": "The agent previously suggested npm start.", + } + ] + ) + + with patch("app.gateway.routers.memory.get_memory_data", return_value=exported_memory): + with TestClient(app) as client: + response = client.get("/api/memory/export") + + assert response.status_code == 200 + assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start." + + +def test_import_memory_route_preserves_source_error() -> None: + app = FastAPI() + app.include_router(memory.router) + imported_memory = _sample_memory( + facts=[ + { + "id": "fact_correction", + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + "createdAt": "2026-03-20T00:00:00Z", + "source": "thread-1", + "sourceError": "The agent previously suggested npm start.", + } + ] + ) + + with patch("app.gateway.routers.memory.import_memory_data", return_value=imported_memory): + with TestClient(app) as client: + response = client.post("/api/memory/import", json=imported_memory) + + assert response.status_code == 200 + assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start." + + def test_clear_memory_route_returns_cleared_memory() -> None: app = FastAPI() app.include_router(memory.router) diff --git a/backend/tests/test_memory_updater.py b/backend/tests/test_memory_updater.py index f7b48228a..6309cf9f6 100644 --- a/backend/tests/test_memory_updater.py +++ b/backend/tests/test_memory_updater.py @@ -146,6 +146,53 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None: assert result["facts"][1]["source"] == "thread-9" +def test_apply_updates_preserves_source_error() -> None: + updater = MemoryUpdater() + current_memory = _make_memory() + update_data = { + "newFacts": [ + { + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + "sourceError": "The agent previously suggested npm start.", + } + ] + } + + with patch( + "deerflow.agents.memory.updater.get_memory_config", + return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7), + ): + result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction") + + assert result["facts"][0]["sourceError"] == "The agent previously suggested npm start." + assert result["facts"][0]["category"] == "correction" + + +def test_apply_updates_ignores_empty_source_error() -> None: + updater = MemoryUpdater() + current_memory = _make_memory() + update_data = { + "newFacts": [ + { + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + "sourceError": " ", + } + ] + } + + with patch( + "deerflow.agents.memory.updater.get_memory_config", + return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7), + ): + result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction") + + assert "sourceError" not in result["facts"][0] + + def test_clear_memory_data_resets_all_sections() -> None: with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True): result = clear_memory_data() @@ -522,3 +569,53 @@ class TestUpdateMemoryStructuredResponse: result = updater.update_memory([msg, ai_msg]) assert result is True + + def test_correction_hint_injected_when_detected(self): + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) + + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "No, that's wrong." + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Understood" + ai_msg.tool_calls = [] + + result = updater.update_memory([msg, ai_msg], correction_detected=True) + + assert result is True + prompt = model.invoke.call_args[0][0] + assert "Explicit correction signals were detected" in prompt + + def test_correction_hint_empty_when_not_detected(self): + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) + + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "Let's talk about memory." + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Sure" + ai_msg.tool_calls = [] + + result = updater.update_memory([msg, ai_msg], correction_detected=False) + + assert result is True + prompt = model.invoke.call_args[0][0] + assert "Explicit correction signals were detected" not in prompt diff --git a/backend/tests/test_memory_upload_filtering.py b/backend/tests/test_memory_upload_filtering.py index 45d0dbf4e..1ff0aa3b6 100644 --- a/backend/tests/test_memory_upload_filtering.py +++ b/backend/tests/test_memory_upload_filtering.py @@ -10,7 +10,7 @@ persisting in long-term memory: from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory -from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory +from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction # --------------------------------------------------------------------------- # Helpers @@ -134,6 +134,64 @@ class TestFilterMessagesForMemory: assert "" not in all_content +# =========================================================================== +# detect_correction +# =========================================================================== + + +class TestDetectCorrection: + def test_detects_english_correction_signal(self): + msgs = [ + _human("Please help me run the project."), + _ai("Use npm start."), + _human("That's wrong, use make dev instead."), + _ai("Understood."), + ] + + assert detect_correction(msgs) is True + + def test_detects_chinese_correction_signal(self): + msgs = [ + _human("帮我启动项目"), + _ai("用 npm start"), + _human("不对,改用 make dev"), + _ai("明白了"), + ] + + assert detect_correction(msgs) is True + + def test_returns_false_without_signal(self): + msgs = [ + _human("Please explain the build setup."), + _ai("Here is the build setup."), + _human("Thanks, that makes sense."), + ] + + assert detect_correction(msgs) is False + + def test_only_checks_recent_messages(self): + msgs = [ + _human("That is wrong, use make dev instead."), + _ai("Noted."), + _human("Let's discuss tests."), + _ai("Sure."), + _human("What about linting?"), + _ai("Use ruff."), + _human("And formatting?"), + _ai("Use make format."), + ] + + assert detect_correction(msgs) is False + + def test_handles_list_content(self): + msgs = [ + HumanMessage(content=["That is wrong,", {"type": "text", "text": "use make dev instead."}]), + _ai("Updated."), + ] + + assert detect_correction(msgs) is True + + # =========================================================================== # _strip_upload_mentions_from_memory # ===========================================================================