feat: switch memory updater to async LLM calls (#2138)

* docs: mark memory updater async migration as completed

- Update TODO.md to mark the replacement of sync model.invoke()
  with async model.ainvoke() in title_middleware and memory updater
  as completed using [x] format

Addresses #2131

* feat: switch memory updater to async LLM calls

- Add async aupdate_memory() method using await model.ainvoke()
- Convert sync update_memory() to use async wrapper
- Add _run_async_update_sync() for nested loop context handling
- Maintain backward compatibility with existing sync API
- Add ThreadPoolExecutor for async execution from sync contexts

Addresses #2131

* test: add tests for async memory updater

- Add test_async_update_memory_uses_ainvoke() to verify async path
- Convert existing tests to use AsyncMock and ainvoke assertions
- Add test_sync_update_memory_wrapper_works_in_running_loop()
- Update all model mocks to use async await patterns

Addresses #2131

* fix: apply ruff formatting to memory updater

- Format multi-line expressions to single line
- Ensure code style consistency with project standards
- Fix lint issues caught by GitHub Actions

* test: add comprehensive tests for async memory updater

- Add test_async_update_memory_uses_ainvoke() to verify async path
- Convert existing tests to use AsyncMock and ainvoke assertions
- Add test_sync_update_memory_wrapper_works_in_running_loop()
- Update all model mocks to use async await patterns
- Ensure backward compatibility with sync API

* fix: satisfy ruff formatting in memory updater test

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
luo jiyin 2026-04-14 11:10:42 +08:00 committed by GitHub
parent 55bc09ac33
commit 07fc25d285
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 278 additions and 82 deletions

View File

@ -24,7 +24,7 @@
- [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario)
- [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
- Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
- [x] Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)

View File

@ -1,10 +1,14 @@
"""Memory updater for reading, writing, and updating memory data."""
import asyncio
import atexit
import concurrent.futures
import json
import logging
import math
import re
import uuid
from collections.abc import Awaitable
from typing import Any
from deerflow.agents.memory.prompt import (
@ -21,6 +25,12 @@ from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
max_workers=4,
thread_name_prefix="memory-updater-sync",
)
atexit.register(lambda: _SYNC_MEMORY_UPDATER_EXECUTOR.shutdown(wait=False))
def _create_empty_memory() -> dict[str, Any]:
"""Backward-compatible wrapper around the storage-layer empty-memory factory."""
@ -206,6 +216,39 @@ def _extract_text(content: Any) -> str:
return str(content)
def _run_async_update_sync(coro: Awaitable[bool]) -> bool:
"""Run an async memory update from sync code, including nested-loop contexts."""
handed_off = False
try:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro)
handed_off = True
return future.result()
handed_off = True
return asyncio.run(coro)
except Exception:
if not handed_off:
close = getattr(coro, "close", None)
if callable(close):
try:
close()
except Exception:
logger.debug(
"Failed to close un-awaited memory update coroutine",
exc_info=True,
)
logger.exception("Failed to run async memory update from sync context")
return False
# Matches sentences that describe a file-upload *event* rather than general
# file-related work. Deliberately narrow to avoid removing legitimate facts
# such as "User works with CSV files" or "prefers PDF export".
@ -269,6 +312,113 @@ class MemoryUpdater:
model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False)
def _build_correction_hint(
self,
correction_detected: bool,
reinforcement_detected: bool,
) -> str:
"""Build optional prompt hints for correction and reinforcement signals."""
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
return correction_hint
def _prepare_update_prompt(
self,
messages: list[Any],
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
) -> tuple[dict[str, Any], str] | None:
"""Load memory and build the update prompt for a conversation."""
config = get_memory_config()
if not config.enabled or not messages:
return None
current_memory = get_memory_data(agent_name)
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return None
correction_hint = self._build_correction_hint(
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
)
return current_memory, prompt
def _finalize_update(
self,
current_memory: dict[str, Any],
response_content: Any,
thread_id: str | None,
agent_name: str | None,
) -> bool:
"""Parse the model response, apply updates, and persist memory."""
response_text = _extract_text(response_content).strip()
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
return get_memory_storage().save(updated_memory, agent_name)
async def aupdate_memory(
self,
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Update memory asynchronously based on conversation messages."""
try:
prepared = self._prepare_update_prompt(
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
if prepared is None:
return False
current_memory, prompt = prepared
model = self._get_model()
response = await model.ainvoke(prompt)
return self._finalize_update(
current_memory=current_memory,
response_content=response.content,
thread_id=thread_id,
agent_name=agent_name,
)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def update_memory(
self,
messages: list[Any],
@ -277,7 +427,7 @@ class MemoryUpdater:
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Update memory based on conversation messages.
"""Synchronously update memory via the async updater path.
Args:
messages: List of conversation messages.
@ -289,78 +439,15 @@ class MemoryUpdater:
Returns:
True if update was successful, False otherwise.
"""
config = get_memory_config()
if not config.enabled:
return False
if not messages:
return False
try:
# Get current memory
current_memory = get_memory_data(agent_name)
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return False
# Build prompt
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
return _run_async_update_sync(
self.aupdate_memory(
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
# Call LLM
model = self._get_model()
response = model.invoke(prompt)
response_text = _extract_text(response.content).strip()
# Parse response
# Remove markdown code blocks if present
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Apply updates
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
# Strip file-upload mentions from all summaries before saving.
# Uploaded files are session-scoped and won't exist in future sessions,
# so recording upload events in long-term memory causes the agent to
# try (and fail) to locate those files in subsequent conversations.
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save
return get_memory_storage().save(updated_memory, agent_name)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
)
def _apply_updates(
self,

View File

@ -1,9 +1,13 @@
from unittest.mock import MagicMock, patch
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from deerflow.agents.memory.prompt import format_conversation_for_update
from deerflow.agents.memory.updater import (
MemoryUpdater,
_extract_text,
_run_async_update_sync,
clear_memory_data,
create_memory_fact,
delete_memory_fact,
@ -523,15 +527,16 @@ class TestUpdateMemoryStructuredResponse:
model = MagicMock()
response = MagicMock()
response.content = content
model.invoke.return_value = response
model.ainvoke = AsyncMock(return_value=response)
return model
def test_string_response_parses(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)),
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
@ -546,6 +551,7 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg])
assert result is True
model.ainvoke.assert_awaited_once()
def test_list_content_response_parses(self):
"""LLM response as list-of-blocks should be extracted, not repr'd."""
@ -570,6 +576,29 @@ class TestUpdateMemoryStructuredResponse:
assert result is True
def test_async_update_memory_uses_ainvoke(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi there"
ai_msg.tool_calls = []
result = asyncio.run(updater.aupdate_memory([msg, ai_msg]))
assert result is True
model.ainvoke.assert_awaited_once()
def test_correction_hint_injected_when_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
@ -592,7 +621,7 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg], correction_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
prompt = model.ainvoke.await_args.args[0]
assert "Explicit correction signals were detected" in prompt
def test_correction_hint_empty_when_not_detected(self):
@ -617,9 +646,89 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg], correction_detected=False)
assert result is True
prompt = model.invoke.call_args[0][0]
prompt = model.ainvoke.await_args.args[0]
assert "Explicit correction signals were detected" not in prompt
def test_sync_update_memory_wrapper_works_in_running_loop(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello from loop"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi"
ai_msg.tool_calls = []
async def run_in_loop():
return updater.update_memory([msg, ai_msg])
result = asyncio.run(run_in_loop())
assert result is True
model.ainvoke.assert_awaited_once()
def test_sync_update_memory_returns_false_when_bridge_submit_fails(self):
updater = MemoryUpdater()
with (
patch(
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
side_effect=RuntimeError("executor down"),
),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello from loop"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi"
ai_msg.tool_calls = []
async def run_in_loop():
return updater.update_memory([msg, ai_msg])
result = asyncio.run(run_in_loop())
assert result is False
class TestRunAsyncUpdateSync:
def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self):
class CloseableAwaitable:
def __init__(self):
self.closed = False
def __await__(self):
pytest.fail("awaitable should not have been awaited")
yield
def close(self):
self.closed = True
awaitable = CloseableAwaitable()
with patch(
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
side_effect=RuntimeError("executor down"),
):
async def run_in_loop():
return _run_async_update_sync(awaitable)
result = asyncio.run(run_in_loop())
assert result is False
assert awaitable.closed is True
class TestFactDeduplicationCaseInsensitive:
"""Tests that fact deduplication is case-insensitive."""
@ -694,7 +803,7 @@ class TestReinforcementHint:
model = MagicMock()
response = MagicMock()
response.content = f"```json\n{json_response}\n```"
model.invoke.return_value = response
model.ainvoke = AsyncMock(return_value=response)
return model
def test_reinforcement_hint_injected_when_detected(self):
@ -719,7 +828,7 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
prompt = model.ainvoke.await_args.args[0]
assert "Positive reinforcement signals were detected" in prompt
def test_reinforcement_hint_absent_when_not_detected(self):
@ -744,7 +853,7 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
assert result is True
prompt = model.invoke.call_args[0][0]
prompt = model.ainvoke.await_args.args[0]
assert "Positive reinforcement signals were detected" not in prompt
def test_both_hints_present_when_both_detected(self):
@ -769,6 +878,6 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
prompt = model.ainvoke.await_args.args[0]
assert "Explicit correction signals were detected" in prompt
assert "Positive reinforcement signals were detected" in prompt