mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
feat: switch memory updater to async LLM calls (#2138)
* docs: mark memory updater async migration as completed - Update TODO.md to mark the replacement of sync model.invoke() with async model.ainvoke() in title_middleware and memory updater as completed using [x] format Addresses #2131 * feat: switch memory updater to async LLM calls - Add async aupdate_memory() method using await model.ainvoke() - Convert sync update_memory() to use async wrapper - Add _run_async_update_sync() for nested loop context handling - Maintain backward compatibility with existing sync API - Add ThreadPoolExecutor for async execution from sync contexts Addresses #2131 * test: add tests for async memory updater - Add test_async_update_memory_uses_ainvoke() to verify async path - Convert existing tests to use AsyncMock and ainvoke assertions - Add test_sync_update_memory_wrapper_works_in_running_loop() - Update all model mocks to use async await patterns Addresses #2131 * fix: apply ruff formatting to memory updater - Format multi-line expressions to single line - Ensure code style consistency with project standards - Fix lint issues caught by GitHub Actions * test: add comprehensive tests for async memory updater - Add test_async_update_memory_uses_ainvoke() to verify async path - Convert existing tests to use AsyncMock and ainvoke assertions - Add test_sync_update_memory_wrapper_works_in_running_loop() - Update all model mocks to use async await patterns - Ensure backward compatibility with sync API * fix: satisfy ruff formatting in memory updater test --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
parent
55bc09ac33
commit
07fc25d285
@ -24,7 +24,7 @@
|
||||
- [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario)
|
||||
- [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
|
||||
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
|
||||
- Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
|
||||
- [x] Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
|
||||
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
|
||||
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)
|
||||
|
||||
|
||||
@ -1,10 +1,14 @@
|
||||
"""Memory updater for reading, writing, and updating memory data."""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
|
||||
from deerflow.agents.memory.prompt import (
|
||||
@ -21,6 +25,12 @@ from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Dedicated worker pool used to bridge synchronous callers into the async
# update path when the calling thread already has a running event loop
# (see _run_async_update_sync); bounded at 4 workers so concurrent sync
# callers cannot spawn unbounded threads.
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
    max_workers=4,
    thread_name_prefix="memory-updater-sync",
)
# Best-effort cleanup: shut the pool down at interpreter exit without
# waiting, so pending background updates cannot block process teardown.
atexit.register(lambda: _SYNC_MEMORY_UPDATER_EXECUTOR.shutdown(wait=False))
|
||||
|
||||
|
||||
def _create_empty_memory() -> dict[str, Any]:
|
||||
"""Backward-compatible wrapper around the storage-layer empty-memory factory."""
|
||||
@ -206,6 +216,39 @@ def _extract_text(content: Any) -> str:
|
||||
return str(content)
|
||||
|
||||
|
||||
def _run_async_update_sync(coro: Awaitable[bool]) -> bool:
    """Run an async memory update from sync code, including nested-loop contexts.

    If the calling thread already has a running event loop, the coroutine is
    handed to a dedicated worker thread (each worker runs its own loop via
    ``asyncio.run``); otherwise it runs directly on a fresh loop.  Returns
    False on any failure instead of raising.
    """
    ownership_transferred = False

    try:
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is None or not running_loop.is_running():
            # No live loop in this thread: run the coroutine directly.
            ownership_transferred = True
            return asyncio.run(coro)

        # Nested-loop case: push the work onto the bridge executor so the
        # caller's loop is not blocked by a second asyncio.run in-thread.
        pending = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro)
        ownership_transferred = True
        return pending.result()
    except Exception:
        if not ownership_transferred:
            # The coroutine was never handed to a runner; close it so Python
            # does not warn about a never-awaited coroutine.
            closer = getattr(coro, "close", None)
            if callable(closer):
                try:
                    closer()
                except Exception:
                    logger.debug(
                        "Failed to close un-awaited memory update coroutine",
                        exc_info=True,
                    )

        logger.exception("Failed to run async memory update from sync context")
        return False
|
||||
|
||||
|
||||
# Matches sentences that describe a file-upload *event* rather than general
|
||||
# file-related work. Deliberately narrow to avoid removing legitimate facts
|
||||
# such as "User works with CSV files" or "prefers PDF export".
|
||||
@ -269,6 +312,113 @@ class MemoryUpdater:
|
||||
model_name = self._model_name or config.model_name
|
||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||
|
||||
def _build_correction_hint(
|
||||
self,
|
||||
correction_detected: bool,
|
||||
reinforcement_detected: bool,
|
||||
) -> str:
|
||||
"""Build optional prompt hints for correction and reinforcement signals."""
|
||||
correction_hint = ""
|
||||
if correction_detected:
|
||||
correction_hint = (
|
||||
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
||||
"Pay special attention to what the agent got wrong, what the user corrected, "
|
||||
"and record the correct approach as a fact with category "
|
||||
'"correction" and confidence >= 0.95 when appropriate.'
|
||||
)
|
||||
if reinforcement_detected:
|
||||
reinforcement_hint = (
|
||||
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
|
||||
"The user explicitly confirmed the agent's approach was correct or helpful. "
|
||||
"Record the confirmed approach, style, or preference as a fact with category "
|
||||
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
|
||||
)
|
||||
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
|
||||
|
||||
return correction_hint
|
||||
|
||||
def _prepare_update_prompt(
    self,
    messages: list[Any],
    agent_name: str | None,
    correction_detected: bool,
    reinforcement_detected: bool,
) -> tuple[dict[str, Any], str] | None:
    """Load memory and build the update prompt for a conversation.

    Returns ``(current_memory, prompt)``, or ``None`` when memory is
    disabled, no messages were given, or the conversation renders empty.
    """
    config = get_memory_config()
    if not config.enabled or not messages:
        return None

    memory_snapshot = get_memory_data(agent_name)
    conversation = format_conversation_for_update(messages)
    if not conversation.strip():
        # Nothing usable to learn from (e.g. whitespace-only transcript).
        return None

    hint = self._build_correction_hint(
        correction_detected=correction_detected,
        reinforcement_detected=reinforcement_detected,
    )
    rendered_prompt = MEMORY_UPDATE_PROMPT.format(
        current_memory=json.dumps(memory_snapshot, indent=2),
        conversation=conversation,
        correction_hint=hint,
    )
    return memory_snapshot, rendered_prompt
|
||||
|
||||
def _finalize_update(
    self,
    current_memory: dict[str, Any],
    response_content: Any,
    thread_id: str | None,
    agent_name: str | None,
) -> bool:
    """Parse the model response, apply updates, and persist memory.

    Raises:
        json.JSONDecodeError: If the (de-fenced) response is not valid JSON;
            callers are expected to handle this.
    """
    text = _extract_text(response_content).strip()

    # Strip a surrounding markdown code fence, if the model emitted one.
    if text.startswith("```"):
        fence_lines = text.split("\n")
        body = fence_lines[1:-1] if fence_lines[-1] == "```" else fence_lines[1:]
        text = "\n".join(body)

    update_data = json.loads(text)
    merged = self._apply_updates(current_memory, update_data, thread_id)
    merged = _strip_upload_mentions_from_memory(merged)
    return get_memory_storage().save(merged, agent_name)
|
||||
|
||||
async def aupdate_memory(
    self,
    messages: list[Any],
    thread_id: str | None = None,
    agent_name: str | None = None,
    correction_detected: bool = False,
    reinforcement_detected: bool = False,
) -> bool:
    """Update memory asynchronously based on conversation messages.

    Returns True when the memory store was updated and saved, False when
    the update was skipped or failed (errors are logged, never raised).
    """
    try:
        prepared = self._prepare_update_prompt(
            messages=messages,
            agent_name=agent_name,
            correction_detected=correction_detected,
            reinforcement_detected=reinforcement_detected,
        )
        if prepared is None:
            # Memory disabled, no messages, or empty conversation.
            return False

        current_memory, prompt = prepared
        llm_response = await self._get_model().ainvoke(prompt)
        return self._finalize_update(
            current_memory=current_memory,
            response_content=llm_response.content,
            thread_id=thread_id,
            agent_name=agent_name,
        )
    except json.JSONDecodeError as e:
        logger.warning("Failed to parse LLM response for memory update: %s", e)
        return False
    except Exception as e:
        logger.exception("Memory update failed: %s", e)
        return False
|
||||
|
||||
def update_memory(
    self,
    messages: list[Any],
    thread_id: str | None = None,
    agent_name: str | None = None,
    correction_detected: bool = False,
    reinforcement_detected: bool = False,
) -> bool:
    """Synchronously update memory via the async updater path.

    Args:
        messages: List of conversation messages.
        thread_id: Optional thread identifier passed through to the update.
        agent_name: Optional agent whose memory store should be updated.
        correction_detected: Whether explicit correction signals were found.
        reinforcement_detected: Whether positive reinforcement signals were found.

    Returns:
        True if update was successful, False otherwise.
    """
    # The span previously mixed the old sync model.invoke() body with the
    # new delegating implementation; the sync API is now a thin wrapper.
    # _run_async_update_sync handles both plain sync callers and callers
    # already inside a running event loop, and returns False on failure.
    return _run_async_update_sync(
        self.aupdate_memory(
            messages=messages,
            thread_id=thread_id,
            agent_name=agent_name,
            correction_detected=correction_detected,
            reinforcement_detected=reinforcement_detected,
        )
    )
|
||||
|
||||
def _apply_updates(
|
||||
self,
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.memory.prompt import format_conversation_for_update
|
||||
from deerflow.agents.memory.updater import (
|
||||
MemoryUpdater,
|
||||
_extract_text,
|
||||
_run_async_update_sync,
|
||||
clear_memory_data,
|
||||
create_memory_fact,
|
||||
delete_memory_fact,
|
||||
@ -523,15 +527,16 @@ class TestUpdateMemoryStructuredResponse:
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = content
|
||||
model.invoke.return_value = response
|
||||
model.ainvoke = AsyncMock(return_value=response)
|
||||
return model
|
||||
|
||||
def test_string_response_parses(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)),
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
@ -546,6 +551,7 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg])
|
||||
|
||||
assert result is True
|
||||
model.ainvoke.assert_awaited_once()
|
||||
|
||||
def test_list_content_response_parses(self):
|
||||
"""LLM response as list-of-blocks should be extracted, not repr'd."""
|
||||
@ -570,6 +576,29 @@ class TestUpdateMemoryStructuredResponse:
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_async_update_memory_uses_ainvoke(self):
    """aupdate_memory must drive the model through the async ainvoke path."""
    updater = MemoryUpdater()
    valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
    model = self._make_mock_model(valid_json)

    with (
        patch.object(updater, "_get_model", return_value=model),
        patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
        patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
        patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
    ):
        human = MagicMock()
        human.type = "human"
        human.content = "Hello"
        assistant = MagicMock()
        assistant.type = "ai"
        assistant.content = "Hi there"
        assistant.tool_calls = []
        outcome = asyncio.run(updater.aupdate_memory([human, assistant]))

    assert outcome is True
    model.ainvoke.assert_awaited_once()
|
||||
|
||||
def test_correction_hint_injected_when_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
@ -592,7 +621,7 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
|
||||
def test_correction_hint_empty_when_not_detected(self):
|
||||
@ -617,9 +646,89 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
assert "Explicit correction signals were detected" not in prompt
|
||||
|
||||
def test_sync_update_memory_wrapper_works_in_running_loop(self):
    """The sync wrapper must succeed even when called inside a running loop."""
    updater = MemoryUpdater()
    valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
    model = self._make_mock_model(valid_json)

    with (
        patch.object(updater, "_get_model", return_value=model),
        patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
        patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
        patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
    ):
        human = MagicMock()
        human.type = "human"
        human.content = "Hello from loop"
        assistant = MagicMock()
        assistant.type = "ai"
        assistant.content = "Hi"
        assistant.tool_calls = []

        async def call_sync_from_loop():
            # Invoking the sync API from a coroutine exercises the
            # nested-loop executor handoff in _run_async_update_sync.
            return updater.update_memory([human, assistant])

        outcome = asyncio.run(call_sync_from_loop())

    assert outcome is True
    model.ainvoke.assert_awaited_once()
|
||||
|
||||
def test_sync_update_memory_returns_false_when_bridge_submit_fails(self):
    """A broken bridge executor must surface as False, not an exception."""
    updater = MemoryUpdater()

    with patch(
        "deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
        side_effect=RuntimeError("executor down"),
    ):
        human = MagicMock()
        human.type = "human"
        human.content = "Hello from loop"
        assistant = MagicMock()
        assistant.type = "ai"
        assistant.content = "Hi"
        assistant.tool_calls = []

        async def call_sync_from_loop():
            return updater.update_memory([human, assistant])

        outcome = asyncio.run(call_sync_from_loop())

    assert outcome is False
|
||||
|
||||
|
||||
class TestRunAsyncUpdateSync:
    def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self):
        """If the executor handoff fails, the un-awaited awaitable is closed."""

        class CloseableAwaitable:
            # Records close() calls; fails the test if anyone awaits it.
            def __init__(self):
                self.closed = False

            def __await__(self):
                pytest.fail("awaitable should not have been awaited")
                yield  # makes __await__ a generator function

            def close(self):
                self.closed = True

        pending = CloseableAwaitable()

        with patch(
            "deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
            side_effect=RuntimeError("executor down"),
        ):

            async def invoke_bridge():
                return _run_async_update_sync(pending)

            outcome = asyncio.run(invoke_bridge())

        assert outcome is False
        assert pending.closed is True
|
||||
|
||||
|
||||
class TestFactDeduplicationCaseInsensitive:
|
||||
"""Tests that fact deduplication is case-insensitive."""
|
||||
@ -694,7 +803,7 @@ class TestReinforcementHint:
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = f"```json\n{json_response}\n```"
|
||||
model.invoke.return_value = response
|
||||
model.ainvoke = AsyncMock(return_value=response)
|
||||
return model
|
||||
|
||||
def test_reinforcement_hint_injected_when_detected(self):
|
||||
@ -719,7 +828,7 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
def test_reinforcement_hint_absent_when_not_detected(self):
|
||||
@ -744,7 +853,7 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
assert "Positive reinforcement signals were detected" not in prompt
|
||||
|
||||
def test_both_hints_present_when_both_detected(self):
|
||||
@ -769,6 +878,6 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user