From 07fc25d2857ea25b46dc635c9059eba8f8ce6dfe Mon Sep 17 00:00:00 2001 From: luo jiyin Date: Tue, 14 Apr 2026 11:10:42 +0800 Subject: [PATCH] feat: switch memory updater to async LLM calls (#2138) * docs: mark memory updater async migration as completed - Update TODO.md to mark the replacement of sync model.invoke() with async model.ainvoke() in title_middleware and memory updater as completed using [x] format Addresses #2131 * feat: switch memory updater to async LLM calls - Add async aupdate_memory() method using await model.ainvoke() - Convert sync update_memory() to use async wrapper - Add _run_async_update_sync() for nested loop context handling - Maintain backward compatibility with existing sync API - Add ThreadPoolExecutor for async execution from sync contexts Addresses #2131 * test: add tests for async memory updater - Add test_async_update_memory_uses_ainvoke() to verify async path - Convert existing tests to use AsyncMock and ainvoke assertions - Add test_sync_update_memory_wrapper_works_in_running_loop() - Update all model mocks to use async await patterns Addresses #2131 * fix: apply ruff formatting to memory updater - Format multi-line expressions to single line - Ensure code style consistency with project standards - Fix lint issues caught by GitHub Actions * test: add comprehensive tests for async memory updater - Add test_async_update_memory_uses_ainvoke() to verify async path - Convert existing tests to use AsyncMock and ainvoke assertions - Add test_sync_update_memory_wrapper_works_in_running_loop() - Update all model mocks to use async await patterns - Ensure backward compatibility with sync API * fix: satisfy ruff formatting in memory updater test --------- Co-authored-by: Willem Jiang --- backend/docs/TODO.md | 2 +- .../harness/deerflow/agents/memory/updater.py | 231 ++++++++++++------ backend/tests/test_memory_updater.py | 127 +++++++++- 3 files changed, 278 insertions(+), 82 deletions(-) diff --git a/backend/docs/TODO.md b/backend/docs/TODO.md index 52dc7867d..421e02ec7 100644 --- a/backend/docs/TODO.md +++ b/backend/docs/TODO.md @@ -24,7 +24,7 @@ - [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario) - [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py` - Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search) - - Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater + - [x] Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater - Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O - For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker) diff --git a/backend/packages/harness/deerflow/agents/memory/updater.py b/backend/packages/harness/deerflow/agents/memory/updater.py index d1f124d4c..1ef32fb60 100644 --- a/backend/packages/harness/deerflow/agents/memory/updater.py +++ b/backend/packages/harness/deerflow/agents/memory/updater.py @@ -1,10 +1,14 @@ """Memory updater for reading, writing, and updating memory data.""" +import asyncio +import atexit +import concurrent.futures import json import logging import math import re import uuid +from collections.abc import Awaitable from typing import Any from deerflow.agents.memory.prompt import ( @@ -21,6 +25,12 @@ from deerflow.models import create_chat_model logger = logging.getLogger(__name__) +_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor( + max_workers=4, + thread_name_prefix="memory-updater-sync", +) +atexit.register(lambda: _SYNC_MEMORY_UPDATER_EXECUTOR.shutdown(wait=False)) + def _create_empty_memory() -> dict[str, Any]: """Backward-compatible wrapper around the storage-layer empty-memory factory.""" @@ -206,6 +216,39 @@ def _extract_text(content: Any) -> str: return str(content) +def _run_async_update_sync(coro: Awaitable[bool]) -> bool: + """Run an async memory update from sync code, including nested-loop contexts.""" + handed_off = False + + try: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None and loop.is_running(): + future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro) + handed_off = True + return future.result() + + handed_off = True + return asyncio.run(coro) + except Exception: + if not handed_off: + close = getattr(coro, "close", None) + if callable(close): + try: + close() + except Exception: + logger.debug( + "Failed to close un-awaited memory update coroutine", + exc_info=True, + ) + + logger.exception("Failed to run async memory update from sync context") + return False + + # Matches sentences that describe a file-upload *event* rather than general # file-related work. Deliberately narrow to avoid removing legitimate facts # such as "User works with CSV files" or "prefers PDF export". @@ -269,6 +312,113 @@ class MemoryUpdater: model_name = self._model_name or config.model_name return create_chat_model(name=model_name, thinking_enabled=False) + def _build_correction_hint( + self, + correction_detected: bool, + reinforcement_detected: bool, + ) -> str: + """Build optional prompt hints for correction and reinforcement signals.""" + correction_hint = "" + if correction_detected: + correction_hint = ( + "IMPORTANT: Explicit correction signals were detected in this conversation. " + "Pay special attention to what the agent got wrong, what the user corrected, " + "and record the correct approach as a fact with category " + '"correction" and confidence >= 0.95 when appropriate.' + ) + if reinforcement_detected: + reinforcement_hint = ( + "IMPORTANT: Positive reinforcement signals were detected in this conversation. " + "The user explicitly confirmed the agent's approach was correct or helpful. " + "Record the confirmed approach, style, or preference as a fact with category " + '"preference" or "behavior" and confidence >= 0.9 when appropriate.' + ) + correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint + + return correction_hint + + def _prepare_update_prompt( + self, + messages: list[Any], + agent_name: str | None, + correction_detected: bool, + reinforcement_detected: bool, + ) -> tuple[dict[str, Any], str] | None: + """Load memory and build the update prompt for a conversation.""" + config = get_memory_config() + if not config.enabled or not messages: + return None + + current_memory = get_memory_data(agent_name) + conversation_text = format_conversation_for_update(messages) + if not conversation_text.strip(): + return None + + correction_hint = self._build_correction_hint( + correction_detected=correction_detected, + reinforcement_detected=reinforcement_detected, + ) + prompt = MEMORY_UPDATE_PROMPT.format( + current_memory=json.dumps(current_memory, indent=2), + conversation=conversation_text, + correction_hint=correction_hint, + ) + return current_memory, prompt + + def _finalize_update( + self, + current_memory: dict[str, Any], + response_content: Any, + thread_id: str | None, + agent_name: str | None, + ) -> bool: + """Parse the model response, apply updates, and persist memory.""" + response_text = _extract_text(response_content).strip() + + if response_text.startswith("```"): + lines = response_text.split("\n") + response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:]) + + update_data = json.loads(response_text) + updated_memory = self._apply_updates(current_memory, update_data, thread_id) + updated_memory = _strip_upload_mentions_from_memory(updated_memory) + return get_memory_storage().save(updated_memory, agent_name) + + async def aupdate_memory( + self, + messages: list[Any], + thread_id: str | None = None, + agent_name: str | None = None, + correction_detected: bool = False, + reinforcement_detected: bool = False, + ) -> bool: + """Update memory asynchronously based on conversation messages.""" + try: + prepared = self._prepare_update_prompt( + messages=messages, + agent_name=agent_name, + correction_detected=correction_detected, + reinforcement_detected=reinforcement_detected, + ) + if prepared is None: + return False + + current_memory, prompt = prepared + model = self._get_model() + response = await model.ainvoke(prompt) + return self._finalize_update( + current_memory=current_memory, + response_content=response.content, + thread_id=thread_id, + agent_name=agent_name, + ) + except json.JSONDecodeError as e: + logger.warning("Failed to parse LLM response for memory update: %s", e) + return False + except Exception as e: + logger.exception("Memory update failed: %s", e) + return False + def update_memory( self, messages: list[Any], @@ -277,7 +427,7 @@ class MemoryUpdater: correction_detected: bool = False, reinforcement_detected: bool = False, ) -> bool: - """Update memory based on conversation messages. + """Synchronously update memory via the async updater path. Args: messages: List of conversation messages. @@ -289,78 +439,15 @@ class MemoryUpdater: Returns: True if update was successful, False otherwise. """ - config = get_memory_config() - if not config.enabled: - return False - - if not messages: - return False - - try: - # Get current memory - current_memory = get_memory_data(agent_name) - - # Format conversation for prompt - conversation_text = format_conversation_for_update(messages) - - if not conversation_text.strip(): - return False - - # Build prompt - correction_hint = "" - if correction_detected: - correction_hint = ( - "IMPORTANT: Explicit correction signals were detected in this conversation. " - "Pay special attention to what the agent got wrong, what the user corrected, " - "and record the correct approach as a fact with category " - '"correction" and confidence >= 0.95 when appropriate.' - ) - if reinforcement_detected: - reinforcement_hint = ( - "IMPORTANT: Positive reinforcement signals were detected in this conversation. " - "The user explicitly confirmed the agent's approach was correct or helpful. " - "Record the confirmed approach, style, or preference as a fact with category " - '"preference" or "behavior" and confidence >= 0.9 when appropriate.' - ) - correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint - - prompt = MEMORY_UPDATE_PROMPT.format( - current_memory=json.dumps(current_memory, indent=2), - conversation=conversation_text, - correction_hint=correction_hint, + return _run_async_update_sync( + self.aupdate_memory( + messages=messages, + thread_id=thread_id, + agent_name=agent_name, + correction_detected=correction_detected, + reinforcement_detected=reinforcement_detected, ) - - # Call LLM - model = self._get_model() - response = model.invoke(prompt) - response_text = _extract_text(response.content).strip() - - # Parse response - # Remove markdown code blocks if present - if response_text.startswith("```"): - lines = response_text.split("\n") - response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:]) - - update_data = json.loads(response_text) - - # Apply updates - updated_memory = self._apply_updates(current_memory, update_data, thread_id) - - # Strip file-upload mentions from all summaries before saving. - # Uploaded files are session-scoped and won't exist in future sessions, - # so recording upload events in long-term memory causes the agent to - # try (and fail) to locate those files in subsequent conversations. - updated_memory = _strip_upload_mentions_from_memory(updated_memory) - - # Save - return get_memory_storage().save(updated_memory, agent_name) - - except json.JSONDecodeError as e: - logger.warning("Failed to parse LLM response for memory update: %s", e) - return False - except Exception as e: - logger.exception("Memory update failed: %s", e) - return False + ) def _apply_updates( self, diff --git a/backend/tests/test_memory_updater.py b/backend/tests/test_memory_updater.py index 48fdfd89e..9011f3ea9 100644 --- a/backend/tests/test_memory_updater.py +++ b/backend/tests/test_memory_updater.py @@ -1,9 +1,13 @@ -from unittest.mock import MagicMock, patch +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest from deerflow.agents.memory.prompt import format_conversation_for_update from deerflow.agents.memory.updater import ( MemoryUpdater, _extract_text, + _run_async_update_sync, clear_memory_data, create_memory_fact, delete_memory_fact, @@ -523,15 +527,16 @@ class TestUpdateMemoryStructuredResponse: model = MagicMock() response = MagicMock() response.content = content - model.invoke.return_value = response + model.ainvoke = AsyncMock(return_value=response) return model def test_string_response_parses(self): updater = MemoryUpdater() valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) with ( - patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)), + patch.object(updater, "_get_model", return_value=model), patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), @@ -546,6 +551,7 @@ class TestUpdateMemoryStructuredResponse: result = updater.update_memory([msg, ai_msg]) assert result is True + model.ainvoke.assert_awaited_once() def test_list_content_response_parses(self): """LLM response as list-of-blocks should be extracted, not repr'd.""" @@ -570,6 +576,29 @@ class TestUpdateMemoryStructuredResponse: assert result is True + def test_async_update_memory_uses_ainvoke(self): + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) + + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "Hello" + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Hi there" + ai_msg.tool_calls = [] + result = asyncio.run(updater.aupdate_memory([msg, ai_msg])) + + assert result is True + model.ainvoke.assert_awaited_once() + def test_correction_hint_injected_when_detected(self): updater = MemoryUpdater() valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' @@ -592,7 +621,7 @@ class TestUpdateMemoryStructuredResponse: result = updater.update_memory([msg, ai_msg], correction_detected=True) assert result is True - prompt = model.invoke.call_args[0][0] + prompt = model.ainvoke.await_args.args[0] assert "Explicit correction signals were detected" in prompt def test_correction_hint_empty_when_not_detected(self): @@ -617,9 +646,89 @@ class TestUpdateMemoryStructuredResponse: result = updater.update_memory([msg, ai_msg], correction_detected=False) assert result is True - prompt = model.invoke.call_args[0][0] + prompt = model.ainvoke.await_args.args[0] assert "Explicit correction signals were detected" not in prompt + def test_sync_update_memory_wrapper_works_in_running_loop(self): + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) + + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "Hello from loop" + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Hi" + ai_msg.tool_calls = [] + + async def run_in_loop(): + return updater.update_memory([msg, ai_msg]) + + result = asyncio.run(run_in_loop()) + + assert result is True + model.ainvoke.assert_awaited_once() + + def test_sync_update_memory_returns_false_when_bridge_submit_fails(self): + updater = MemoryUpdater() + + with ( + patch( + "deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit", + side_effect=RuntimeError("executor down"), + ), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "Hello from loop" + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Hi" + ai_msg.tool_calls = [] + + async def run_in_loop(): + return updater.update_memory([msg, ai_msg]) + + result = asyncio.run(run_in_loop()) + + assert result is False + + +class TestRunAsyncUpdateSync: + def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self): + class CloseableAwaitable: + def __init__(self): + self.closed = False + + def __await__(self): + pytest.fail("awaitable should not have been awaited") + yield + + def close(self): + self.closed = True + + awaitable = CloseableAwaitable() + + with patch( + "deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit", + side_effect=RuntimeError("executor down"), + ): + + async def run_in_loop(): + return _run_async_update_sync(awaitable) + + result = asyncio.run(run_in_loop()) + + assert result is False + assert awaitable.closed is True + class TestFactDeduplicationCaseInsensitive: """Tests that fact deduplication is case-insensitive.""" @@ -694,7 +803,7 @@ class TestReinforcementHint: model = MagicMock() response = MagicMock() response.content = f"```json\n{json_response}\n```" - model.invoke.return_value = response + model.ainvoke = AsyncMock(return_value=response) return model def test_reinforcement_hint_injected_when_detected(self): @@ -719,7 +828,7 @@ class TestReinforcementHint: result = updater.update_memory([msg, ai_msg], reinforcement_detected=True) assert result is True - prompt = model.invoke.call_args[0][0] + prompt = model.ainvoke.await_args.args[0] assert "Positive reinforcement signals were detected" in prompt def test_reinforcement_hint_absent_when_not_detected(self): @@ -744,7 +853,7 @@ class TestReinforcementHint: result = updater.update_memory([msg, ai_msg], reinforcement_detected=False) assert result is True - prompt = model.invoke.call_args[0][0] + prompt = model.ainvoke.await_args.args[0] assert "Positive reinforcement signals were detected" not in prompt def test_both_hints_present_when_both_detected(self): @@ -769,6 +878,6 @@ class TestReinforcementHint: result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True) assert result is True - prompt = model.invoke.call_args[0][0] + prompt = model.ainvoke.await_args.args[0] assert "Explicit correction signals were detected" in prompt assert "Positive reinforcement signals were detected" in prompt