Mirror of https://github.com/bytedance/deer-flow.git (synced 2026-05-01 14:28:28 +00:00)
fix(memory): replace short-lived asyncio.run() with persistent event loop (#2627)
* fix(memory): replace short-lived asyncio.run() with persistent event loop to prevent zombie httpx connections

  The memory updater used asyncio.run() inside daemon threads, creating and destroying a short-lived event loop on every update. Langchain providers (e.g. langchain-anthropic) cache httpx AsyncClient instances globally via @lru_cache, so SSL connections created on a loop that is subsequently destroyed become zombie connections in the shared pool. When the main agent's lead run later reuses one of these connections, httpx/anyio raises RuntimeError: Event loop is closed during connection cleanup.

  Replace the ThreadPoolExecutor + asyncio.run() pattern with a _MemoryLoopRunner that maintains a single persistent event loop in a daemon thread for the process lifetime. Since the loop never closes, connections bound to it never become invalid. The _run_async_update_sync function now submits coroutines to this persistent loop via run_coroutine_threadsafe instead of creating throwaway loops.

* Update the code to address the review comments

* Fix the review comments of #2615

  P1 — user_id forwarded through sync path: added a user_id parameter to _prepare_update_prompt, _finalize_update, and _do_update_memory_sync, and forwarded it to get_memory_data(agent_name, user_id=user_id) and save(..., user_id=user_id). The update_memory() entry point now passes user_id through both the executor.submit path and the direct-call path. Added TestUserIdForwarding with two regression tests (sync + async) verifying that get_memory_data and save receive the correct user_id.

  P2 — aupdate_memory() delegates to sync: replaced the model.ainvoke() call with asyncio.to_thread(self._do_update_memory_sync, ...). This eliminates the unsafe async provider-client path entirely — all memory updater entry points now use the isolated sync model.invoke() path. Updated the test from asserting that ainvoke is awaited to asserting that invoke is called and ainvoke is not.

  Nit — duplicate comment removed: removed the duplicated "# Matches sentences..." comment on line 230.

* chore(test): update the code of test_memory_updater

---------

Co-authored-by: rayhpeng <rayhpeng@gmail.com>
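For context, the failure mechanism described above is the classic pairing of a module-cached httpx AsyncClient with repeated asyncio.run() calls: the client binds its pooled connections to whichever event loop first drives it, and those connections die with the loop. A minimal sketch of the mechanism (illustrative only; `get_client` and `fetch` are hypothetical stand-ins, not code from this repo):

```python
import asyncio

import httpx

_client: httpx.AsyncClient | None = None


def get_client() -> httpx.AsyncClient:
    # Stands in for a provider client cached via @lru_cache: built once,
    # then shared by every caller for the life of the process.
    global _client
    if _client is None:
        _client = httpx.AsyncClient()
    return _client


async def fetch(url: str) -> int:
    response = await get_client().get(url)
    return response.status_code


# The first asyncio.run() creates loop A; the client pools a connection bound
# to loop A; then asyncio.run() closes loop A. The second asyncio.run() creates
# loop B, which reuses the pooled connection, and cleanup against the dead
# loop A raises "RuntimeError: Event loop is closed".
# asyncio.run(fetch("https://example.com"))
# asyncio.run(fetch("https://example.com"))
```

This is exactly the shape of the bug: each memory update ran asyncio.run() on a daemon thread, so every update created and destroyed a loop while the cached client, and its connection pool, lived on.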
This commit is contained in:
parent 7dea1666ce
commit c0da278269
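Note that the commit title reflects the first revision, which introduced a persistent-loop runner; the review rounds recorded above then moved the final code (the diff below) to a pure-sync model.invoke() path, so no _MemoryLoopRunner appears in the diff. For reference, the pattern the first bullet describes looks roughly like this (a sketch reconstructed from the message's own description, not the shipped code):

```python
import asyncio
import threading
from collections.abc import Coroutine
from typing import Any


class _MemoryLoopRunner:
    """A single event loop on a daemon thread, alive for the process lifetime.

    Sketch of the pattern the commit message describes: because the loop is
    never closed, httpx connections bound to it never become zombies.
    """

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(
            target=self._loop.run_forever,
            name="memory-updater-loop",
            daemon=True,
        )
        self._thread.start()

    def run(self, coro: Coroutine[Any, Any, bool]) -> bool:
        # Safe to call from any thread: schedules the coroutine on the
        # persistent loop and blocks until it completes.
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result()
```

The final revision sidesteps event loops entirely: update_memory() and aupdate_memory() both funnel into _do_update_memory_sync(), which uses the sync model.invoke() and therefore a connection pool separate from the async client shared with the lead agent.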
@@ -9,7 +9,6 @@ import logging
 import math
 import re
 import uuid
-from collections.abc import Awaitable
 from typing import Any

 from deerflow.agents.memory.prompt import (
@@ -26,6 +25,12 @@ from deerflow.models import create_chat_model

 logger = logging.getLogger(__name__)

+
+# Thread pool for offloading sync memory updates when called from an async
+# context. Unlike the previous asyncio.run() approach, this runs *sync*
+# model.invoke() calls — no event loop is created, so the langchain async
+# httpx client pool (globally cached via @lru_cache) is never touched and
+# cross-loop connection reuse is impossible.
 _SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
     max_workers=4,
     thread_name_prefix="memory-updater-sync",
@@ -222,39 +227,6 @@ def _extract_text(content: Any) -> str:
     return str(content)


-def _run_async_update_sync(coro: Awaitable[bool]) -> bool:
-    """Run an async memory update from sync code, including nested-loop contexts."""
-    handed_off = False
-
-    try:
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = None
-
-        if loop is not None and loop.is_running():
-            future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro)
-            handed_off = True
-            return future.result()
-
-        handed_off = True
-        return asyncio.run(coro)
-    except Exception:
-        if not handed_off:
-            close = getattr(coro, "close", None)
-            if callable(close):
-                try:
-                    close()
-                except Exception:
-                    logger.debug(
-                        "Failed to close un-awaited memory update coroutine",
-                        exc_info=True,
-                    )
-
-        logger.exception("Failed to run async memory update from sync context")
-        return False
-
-
 # Matches sentences that describe a file-upload *event* rather than general
 # file-related work. Deliberately narrow to avoid removing legitimate facts
 # such as "User works with CSV files" or "prefers PDF export".
@@ -349,13 +321,14 @@ class MemoryUpdater:
         agent_name: str | None,
         correction_detected: bool,
         reinforcement_detected: bool,
+        user_id: str | None = None,
     ) -> tuple[dict[str, Any], str] | None:
         """Load memory and build the update prompt for a conversation."""
         config = get_memory_config()
         if not config.enabled or not messages:
             return None

-        current_memory = get_memory_data(agent_name)
+        current_memory = get_memory_data(agent_name, user_id=user_id)
         conversation_text = format_conversation_for_update(messages)
         if not conversation_text.strip():
             return None
@@ -377,6 +350,7 @@ class MemoryUpdater:
         response_content: Any,
         thread_id: str | None,
         agent_name: str | None,
+        user_id: str | None = None,
     ) -> bool:
         """Parse the model response, apply updates, and persist memory."""
         response_text = _extract_text(response_content).strip()
@@ -390,7 +364,7 @@ class MemoryUpdater:
         # cannot corrupt the still-cached original object reference.
         updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
         updated_memory = _strip_upload_mentions_from_memory(updated_memory)
-        return get_memory_storage().save(updated_memory, agent_name)
+        return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)

     async def aupdate_memory(
         self,
@@ -399,28 +373,63 @@ class MemoryUpdater:
         agent_name: str | None = None,
         correction_detected: bool = False,
         reinforcement_detected: bool = False,
         user_id: str | None = None,
     ) -> bool:
-        """Update memory asynchronously based on conversation messages."""
+        """Update memory asynchronously by delegating to the sync path.
+
+        Uses ``asyncio.to_thread`` to run the *sync* ``model.invoke()`` path
+        in a worker thread so no second event loop is created and the
+        langchain async httpx client pool (shared with the lead agent) is
+        never touched. This eliminates the cross-loop connection-reuse bug
+        described in issue #2615.
+        """
+        return await asyncio.to_thread(
+            self._do_update_memory_sync,
+            messages=messages,
+            thread_id=thread_id,
+            agent_name=agent_name,
+            correction_detected=correction_detected,
+            reinforcement_detected=reinforcement_detected,
+            user_id=user_id,
+        )
+
+    def _do_update_memory_sync(
+        self,
+        messages: list[Any],
+        thread_id: str | None = None,
+        agent_name: str | None = None,
+        correction_detected: bool = False,
+        reinforcement_detected: bool = False,
+        user_id: str | None = None,
+    ) -> bool:
+        """Pure-sync memory update using ``model.invoke()``.
+
+        Uses the *sync* LLM call path so no event loop is created. This
+        guarantees that the langchain provider's globally cached async
+        httpx ``AsyncClient`` / connection pool (the one shared with the
+        lead agent) is never touched — no cross-loop connection reuse is
+        possible.
+        """
         try:
-            prepared = await asyncio.to_thread(
-                self._prepare_update_prompt,
+            prepared = self._prepare_update_prompt(
                 messages=messages,
                 agent_name=agent_name,
                 correction_detected=correction_detected,
                 reinforcement_detected=reinforcement_detected,
+                user_id=user_id,
             )
             if prepared is None:
                 return False

             current_memory, prompt = prepared
             model = self._get_model()
-            response = await model.ainvoke(prompt, config={"run_name": "memory_agent"})
-            return await asyncio.to_thread(
-                self._finalize_update,
+            response = model.invoke(prompt, config={"run_name": "memory_agent"})
+            return self._finalize_update(
                 current_memory=current_memory,
                 response_content=response.content,
                 thread_id=thread_id,
                 agent_name=agent_name,
+                user_id=user_id,
             )
         except json.JSONDecodeError as e:
             logger.warning("Failed to parse LLM response for memory update: %s", e)
@@ -438,7 +447,16 @@ class MemoryUpdater:
         reinforcement_detected: bool = False,
         user_id: str | None = None,
     ) -> bool:
-        """Synchronously update memory via the async updater path.
+        """Synchronously update memory using the sync LLM path.
+
+        Uses ``model.invoke()`` (sync HTTP) which operates on a completely
+        separate connection pool from the async ``AsyncClient`` shared by
+        the lead agent. This eliminates the cross-loop connection-reuse
+        bug described in issue #2615.
+
+        When called from within a running event loop (e.g. from a LangGraph
+        node), the blocking sync call is offloaded to a thread pool so the
+        caller's loop is not blocked.

         Args:
             messages: List of conversation messages.
@@ -451,14 +469,34 @@
         Returns:
             True if update was successful, False otherwise.
         """
-        return _run_async_update_sync(
-            self.aupdate_memory(
-                messages=messages,
-                thread_id=thread_id,
-                agent_name=agent_name,
-                correction_detected=correction_detected,
-                reinforcement_detected=reinforcement_detected,
-            )
-        )
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+
+        if loop is not None and loop.is_running():
+            try:
+                future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(
+                    self._do_update_memory_sync,
+                    messages=messages,
+                    thread_id=thread_id,
+                    agent_name=agent_name,
+                    correction_detected=correction_detected,
+                    reinforcement_detected=reinforcement_detected,
+                    user_id=user_id,
+                )
+                return future.result()
+            except Exception:
+                logger.exception("Failed to offload memory update to executor")
+                return False
+
+        return self._do_update_memory_sync(
+            messages=messages,
+            thread_id=thread_id,
+            agent_name=agent_name,
+            correction_detected=correction_detected,
+            reinforcement_detected=reinforcement_detected,
+            user_id=user_id,
+        )

     def _apply_updates(
@@ -1,13 +1,10 @@
 import asyncio
 from unittest.mock import AsyncMock, MagicMock, patch

-import pytest
-
 from deerflow.agents.memory.prompt import format_conversation_for_update
 from deerflow.agents.memory.updater import (
     MemoryUpdater,
     _extract_text,
-    _run_async_update_sync,
     clear_memory_data,
     create_memory_fact,
     delete_memory_fact,
@@ -528,6 +525,7 @@ class TestUpdateMemoryStructuredResponse:
         response = MagicMock()
         response.content = content
         model.ainvoke = AsyncMock(return_value=response)
+        model.invoke = MagicMock(return_value=response)
         return model

     def test_string_response_parses(self):
@@ -551,7 +549,7 @@ class TestUpdateMemoryStructuredResponse:
             result = updater.update_memory([msg, ai_msg])

         assert result is True
-        model.ainvoke.assert_awaited_once()
+        model.invoke.assert_called_once()

     def test_list_content_response_parses(self):
         """LLM response as list-of-blocks should be extracted, not repr'd."""
@@ -576,7 +574,8 @@ class TestUpdateMemoryStructuredResponse:

         assert result is True

-    def test_async_update_memory_uses_ainvoke(self):
+    def test_async_update_memory_delegates_to_sync(self):
+        """aupdate_memory should delegate to sync _do_update_memory_sync via to_thread."""
         updater = MemoryUpdater()
         valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
         model = self._make_mock_model(valid_json)
@@ -597,8 +596,9 @@ class TestUpdateMemoryStructuredResponse:
             result = asyncio.run(updater.aupdate_memory([msg, ai_msg]))

         assert result is True
-        model.ainvoke.assert_awaited_once()
-        assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "memory_agent"}
+        # aupdate_memory delegates to sync path — model.invoke, not ainvoke
+        model.invoke.assert_called_once()
+        model.ainvoke.assert_not_called()

     def test_correction_hint_injected_when_detected(self):
         updater = MemoryUpdater()
@@ -622,7 +622,7 @@ class TestUpdateMemoryStructuredResponse:
             result = updater.update_memory([msg, ai_msg], correction_detected=True)

         assert result is True
-        prompt = model.ainvoke.await_args.args[0]
+        prompt = model.invoke.call_args.args[0]
         assert "Explicit correction signals were detected" in prompt

     def test_correction_hint_empty_when_not_detected(self):
@@ -647,7 +647,7 @@ class TestUpdateMemoryStructuredResponse:
             result = updater.update_memory([msg, ai_msg], correction_detected=False)

         assert result is True
-        prompt = model.ainvoke.await_args.args[0]
+        prompt = model.invoke.call_args.args[0]
         assert "Explicit correction signals were detected" not in prompt

     def test_sync_update_memory_wrapper_works_in_running_loop(self):
@@ -675,9 +675,9 @@ class TestUpdateMemoryStructuredResponse:
         result = asyncio.run(run_in_loop())

         assert result is True
-        model.ainvoke.assert_awaited_once()
+        model.invoke.assert_called_once()

-    def test_sync_update_memory_returns_false_when_bridge_submit_fails(self):
+    def test_sync_update_memory_returns_false_when_executor_down(self):
         updater = MemoryUpdater()

         with (
@@ -702,33 +702,67 @@ class TestUpdateMemoryStructuredResponse:
         assert result is False


-class TestRunAsyncUpdateSync:
-    def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self):
-        class CloseableAwaitable:
-            def __init__(self):
-                self.closed = False
+class TestSyncUpdateIsolatesProviderClientPool:
+    """Regression tests for issue #2615.
+
-            def __await__(self):
-                pytest.fail("awaitable should not have been awaited")
-                yield
+    The sync ``update_memory`` path must use ``model.invoke()`` (sync HTTP)
+    and never touch the async provider client pool shared with the lead agent.
+    """
+
-            def close(self):
-                self.closed = True
+    def test_sync_update_uses_invoke_not_ainvoke(self):
+        updater = MemoryUpdater()
+        valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
+        model = MagicMock()
+        response = MagicMock()
+        response.content = valid_json
+        model.invoke = MagicMock(return_value=response)
+        model.ainvoke = AsyncMock(return_value=response)
+
-        awaitable = CloseableAwaitable()
-
-        with patch(
-            "deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
-            side_effect=RuntimeError("executor down"),
+        with (
+            patch.object(updater, "_get_model", return_value=model),
+            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
+            patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
+            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
         ):
+            msg = MagicMock()
+            msg.type = "human"
+            msg.content = "Hello"
+            ai_msg = MagicMock()
+            ai_msg.type = "ai"
+            ai_msg.content = "Hi"
+            ai_msg.tool_calls = []
+            result = updater.update_memory([msg, ai_msg])
+
-            async def run_in_loop():
-                return _run_async_update_sync(awaitable)
+        assert result is True
+        model.invoke.assert_called_once()
+        model.ainvoke.assert_not_called()
+
-            result = asyncio.run(run_in_loop())
+    def test_no_event_loop_created_during_sync_update(self):
+        """Sync update must not create or destroy any event loop."""
+        updater = MemoryUpdater()
+        valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
+        model = MagicMock()
+        response = MagicMock()
+        response.content = valid_json
+        model.invoke = MagicMock(return_value=response)
+
-        assert result is False
-        assert awaitable.closed is True
+        with (
+            patch.object(updater, "_get_model", return_value=model),
+            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
+            patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
+            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
+            patch("asyncio.run", side_effect=AssertionError("asyncio.run must not be called from sync update path")),
+        ):
+            msg = MagicMock()
+            msg.type = "human"
+            msg.content = "Hello"
+            ai_msg = MagicMock()
+            ai_msg.type = "ai"
+            ai_msg.content = "Hi"
+            ai_msg.tool_calls = []
+            result = updater.update_memory([msg, ai_msg])
+
+        assert result is True


 class TestFactDeduplicationCaseInsensitive:
@@ -805,6 +839,7 @@ class TestReinforcementHint:
         response = MagicMock()
         response.content = f"```json\n{json_response}\n```"
         model.ainvoke = AsyncMock(return_value=response)
+        model.invoke = MagicMock(return_value=response)
         return model

     def test_reinforcement_hint_injected_when_detected(self):
@@ -829,7 +864,7 @@ class TestReinforcementHint:
             result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)

         assert result is True
-        prompt = model.ainvoke.await_args.args[0]
+        prompt = model.invoke.call_args.args[0]
         assert "Positive reinforcement signals were detected" in prompt

     def test_reinforcement_hint_absent_when_not_detected(self):
@@ -854,7 +889,7 @@ class TestReinforcementHint:
             result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)

         assert result is True
-        prompt = model.ainvoke.await_args.args[0]
+        prompt = model.invoke.call_args.args[0]
         assert "Positive reinforcement signals were detected" not in prompt

     def test_both_hints_present_when_both_detected(self):
@@ -879,7 +914,7 @@ class TestReinforcementHint:
             result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)

         assert result is True
-        prompt = model.ainvoke.await_args.args[0]
+        prompt = model.invoke.call_args.args[0]
         assert "Explicit correction signals were detected" in prompt
         assert "Positive reinforcement signals were detected" in prompt

@@ -908,11 +943,11 @@ class TestFinalizeCacheIsolation:
         )
         mock_response = MagicMock()
         mock_response.content = new_fact_json
-        mock_model = AsyncMock()
-        mock_model.ainvoke = AsyncMock(return_value=mock_response)
+        mock_model = MagicMock()
+        mock_model.invoke = MagicMock(return_value=mock_response)

         saved_objects: list[dict] = []
-        save_mock = MagicMock(side_effect=lambda m, a=None: saved_objects.append(m) or False)  # always fails
+        save_mock = MagicMock(side_effect=lambda m, a=None, **_: saved_objects.append(m) or False)  # always fails

         with (
             patch.object(updater, "_get_model", return_value=mock_model),
@@ -929,6 +964,85 @@ class TestFinalizeCacheIsolation:
             ai_msg.tool_calls = []
             updater.update_memory([msg, ai_msg], thread_id="t1")

+        # save_mock must have been exercised — otherwise the deepcopy-on-save-failure path isn't covered
+        save_mock.assert_called_once()
+        assert len(saved_objects) == 1, "save must have been called with the updated memory object"
+
+        # original_memory must not have been mutated — deepcopy isolates the mutation
         assert len(original_memory["facts"]) == 1, "original_memory must not be mutated by _apply_updates"
         assert original_memory["facts"][0]["content"] == "original"
+
+
+class TestUserIdForwarding:
+    """Regression: user_id must flow through the entire sync update path.
+
+    When MemoryUpdateQueue captures context.user_id and passes it into
+    update_memory(..., user_id=context.user_id), the sync path must forward
+    it into _prepare_update_prompt → get_memory_data() and
+    _finalize_update → save(), so per-user memory isolation is maintained.
+    """
+
+    @staticmethod
+    def _make_mock_model(content):
+        model = MagicMock()
+        response = MagicMock()
+        response.content = content
+        model.invoke = MagicMock(return_value=response)
+        return model
+
+    def test_sync_update_forwards_user_id_to_load_and_save(self):
+        """update_memory must pass user_id to get_memory_data and storage.save."""
+        updater = MemoryUpdater()
+        valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
+        model = self._make_mock_model(valid_json)
+        mock_storage = MagicMock()
+        mock_storage.save = MagicMock(return_value=True)
+
+        with (
+            patch.object(updater, "_get_model", return_value=model),
+            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
+            patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()) as mock_load,
+            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage),
+        ):
+            msg = MagicMock()
+            msg.type = "human"
+            msg.content = "Hello"
+            ai_msg = MagicMock()
+            ai_msg.type = "ai"
+            ai_msg.content = "Hi"
+            ai_msg.tool_calls = []
+            result = updater.update_memory([msg, ai_msg], user_id="user-42")
+
+        assert result is True
+        mock_load.assert_called_once_with(None, user_id="user-42")
+        mock_storage.save.assert_called_once()
+        save_call = mock_storage.save.call_args
+        assert save_call.kwargs.get("user_id") == "user-42" or (len(save_call.args) > 2 and save_call.args[2] == "user-42")
+
+    def test_async_update_forwards_user_id_to_load_and_save(self):
+        """aupdate_memory must pass user_id through to the sync delegate."""
+        updater = MemoryUpdater()
+        valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
+        model = self._make_mock_model(valid_json)
+        mock_storage = MagicMock()
+        mock_storage.save = MagicMock(return_value=True)
+
+        with (
+            patch.object(updater, "_get_model", return_value=model),
+            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
+            patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()) as mock_load,
+            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage),
+        ):
+            msg = MagicMock()
+            msg.type = "human"
+            msg.content = "Hello"
+            ai_msg = MagicMock()
+            ai_msg.type = "ai"
+            ai_msg.content = "Hi"
+            ai_msg.tool_calls = []
+            result = asyncio.run(updater.aupdate_memory([msg, ai_msg], user_id="user-99"))
+
+        assert result is True
+        mock_load.assert_called_once_with(None, user_id="user-99")
+        save_call = mock_storage.save.call_args
+        assert save_call.kwargs.get("user_id") == "user-99" or (len(save_call.args) > 2 and save_call.args[2] == "user-99")