diff --git a/backend/packages/harness/deerflow/agents/memory/updater.py b/backend/packages/harness/deerflow/agents/memory/updater.py index 7d563fb87..6e55330a1 100644 --- a/backend/packages/harness/deerflow/agents/memory/updater.py +++ b/backend/packages/harness/deerflow/agents/memory/updater.py @@ -9,7 +9,6 @@ import logging import math import re import uuid -from collections.abc import Awaitable from typing import Any from deerflow.agents.memory.prompt import ( @@ -26,6 +25,12 @@ from deerflow.models import create_chat_model logger = logging.getLogger(__name__) + +# Thread pool for offloading sync memory updates when called from an async +# context. Unlike the previous asyncio.run() approach, this runs *sync* +# model.invoke() calls — no event loop is created, so the langchain async +# httpx client pool (globally cached via @lru_cache) is never touched and +# cross-loop connection reuse is impossible. _SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor( max_workers=4, thread_name_prefix="memory-updater-sync", @@ -222,39 +227,6 @@ def _extract_text(content: Any) -> str: return str(content) -def _run_async_update_sync(coro: Awaitable[bool]) -> bool: - """Run an async memory update from sync code, including nested-loop contexts.""" - handed_off = False - - try: - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop is not None and loop.is_running(): - future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro) - handed_off = True - return future.result() - - handed_off = True - return asyncio.run(coro) - except Exception: - if not handed_off: - close = getattr(coro, "close", None) - if callable(close): - try: - close() - except Exception: - logger.debug( - "Failed to close un-awaited memory update coroutine", - exc_info=True, - ) - - logger.exception("Failed to run async memory update from sync context") - return False - - # Matches sentences that describe a file-upload *event* rather than general # file-related work. Deliberately narrow to avoid removing legitimate facts # such as "User works with CSV files" or "prefers PDF export". @@ -349,13 +321,14 @@ class MemoryUpdater: agent_name: str | None, correction_detected: bool, reinforcement_detected: bool, + user_id: str | None = None, ) -> tuple[dict[str, Any], str] | None: """Load memory and build the update prompt for a conversation.""" config = get_memory_config() if not config.enabled or not messages: return None - current_memory = get_memory_data(agent_name) + current_memory = get_memory_data(agent_name, user_id=user_id) conversation_text = format_conversation_for_update(messages) if not conversation_text.strip(): return None @@ -377,6 +350,7 @@ class MemoryUpdater: response_content: Any, thread_id: str | None, agent_name: str | None, + user_id: str | None = None, ) -> bool: """Parse the model response, apply updates, and persist memory.""" response_text = _extract_text(response_content).strip() @@ -390,7 +364,7 @@ class MemoryUpdater: # cannot corrupt the still-cached original object reference. updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id) updated_memory = _strip_upload_mentions_from_memory(updated_memory) - return get_memory_storage().save(updated_memory, agent_name) + return get_memory_storage().save(updated_memory, agent_name, user_id=user_id) async def aupdate_memory( self, @@ -399,28 +373,63 @@ class MemoryUpdater: agent_name: str | None = None, correction_detected: bool = False, reinforcement_detected: bool = False, + user_id: str | None = None, ) -> bool: - """Update memory asynchronously based on conversation messages.""" + """Update memory asynchronously by delegating to the sync path. + + Uses ``asyncio.to_thread`` to run the *sync* ``model.invoke()`` path + in a worker thread so no second event loop is created and the + langchain async httpx client pool (shared with the lead agent) is + never touched. This eliminates the cross-loop connection-reuse bug + described in issue #2615. + """ + return await asyncio.to_thread( + self._do_update_memory_sync, + messages=messages, + thread_id=thread_id, + agent_name=agent_name, + correction_detected=correction_detected, + reinforcement_detected=reinforcement_detected, + user_id=user_id, + ) + + def _do_update_memory_sync( + self, + messages: list[Any], + thread_id: str | None = None, + agent_name: str | None = None, + correction_detected: bool = False, + reinforcement_detected: bool = False, + user_id: str | None = None, + ) -> bool: + """Pure-sync memory update using ``model.invoke()``. + + Uses the *sync* LLM call path so no event loop is created. This + guarantees that the langchain provider's globally cached async + httpx ``AsyncClient`` / connection pool (the one shared with the + lead agent) is never touched — no cross-loop connection reuse is + possible. + """ try: - prepared = await asyncio.to_thread( - self._prepare_update_prompt, + prepared = self._prepare_update_prompt( messages=messages, agent_name=agent_name, correction_detected=correction_detected, reinforcement_detected=reinforcement_detected, + user_id=user_id, ) if prepared is None: return False current_memory, prompt = prepared model = self._get_model() - response = await model.ainvoke(prompt, config={"run_name": "memory_agent"}) - return await asyncio.to_thread( - self._finalize_update, + response = model.invoke(prompt, config={"run_name": "memory_agent"}) + return self._finalize_update( current_memory=current_memory, response_content=response.content, thread_id=thread_id, agent_name=agent_name, + user_id=user_id, ) except json.JSONDecodeError as e: logger.warning("Failed to parse LLM response for memory update: %s", e) @@ -438,7 +447,16 @@ class MemoryUpdater: reinforcement_detected: bool = False, user_id: str | None = None, ) -> bool: - """Synchronously update memory via the async updater path. + """Synchronously update memory using the sync LLM path. + + Uses ``model.invoke()`` (sync HTTP) which operates on a completely + separate connection pool from the async ``AsyncClient`` shared by + the lead agent. This eliminates the cross-loop connection-reuse + bug described in issue #2615. + + When called from within a running event loop (e.g. from a LangGraph + node), the blocking sync call is offloaded to a thread pool so the + caller's loop is not blocked. Args: messages: List of conversation messages. @@ -451,14 +469,34 @@ class MemoryUpdater: Returns: True if update was successful, False otherwise. """ - return _run_async_update_sync( - self.aupdate_memory( - messages=messages, - thread_id=thread_id, - agent_name=agent_name, - correction_detected=correction_detected, - reinforcement_detected=reinforcement_detected, - ) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None and loop.is_running(): + try: + future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit( + self._do_update_memory_sync, + messages=messages, + thread_id=thread_id, + agent_name=agent_name, + correction_detected=correction_detected, + reinforcement_detected=reinforcement_detected, + user_id=user_id, + ) + return future.result() + except Exception: + logger.exception("Failed to offload memory update to executor") + return False + + return self._do_update_memory_sync( + messages=messages, + thread_id=thread_id, + agent_name=agent_name, + correction_detected=correction_detected, + reinforcement_detected=reinforcement_detected, + user_id=user_id, ) def _apply_updates( diff --git a/backend/tests/test_memory_updater.py b/backend/tests/test_memory_updater.py index b4fb87a52..03d135564 100644 --- a/backend/tests/test_memory_updater.py +++ b/backend/tests/test_memory_updater.py @@ -1,13 +1,10 @@ import asyncio from unittest.mock import AsyncMock, MagicMock, patch -import pytest - from deerflow.agents.memory.prompt import format_conversation_for_update from deerflow.agents.memory.updater import ( MemoryUpdater, _extract_text, - _run_async_update_sync, clear_memory_data, create_memory_fact, delete_memory_fact, @@ -528,6 +525,7 @@ class TestUpdateMemoryStructuredResponse: response = MagicMock() response.content = content model.ainvoke = AsyncMock(return_value=response) + model.invoke = MagicMock(return_value=response) return model def test_string_response_parses(self): @@ -551,7 +549,7 @@ class TestUpdateMemoryStructuredResponse: result = updater.update_memory([msg, ai_msg]) assert result is True - model.ainvoke.assert_awaited_once() + model.invoke.assert_called_once() def test_list_content_response_parses(self): """LLM response as list-of-blocks should be extracted, not repr'd.""" @@ -576,7 +574,8 @@ class TestUpdateMemoryStructuredResponse: assert result is True - def test_async_update_memory_uses_ainvoke(self): + def test_async_update_memory_delegates_to_sync(self): + """aupdate_memory should delegate to sync _do_update_memory_sync via to_thread.""" updater = MemoryUpdater() valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' model = self._make_mock_model(valid_json) @@ -597,8 +596,9 @@ class TestUpdateMemoryStructuredResponse: result = asyncio.run(updater.aupdate_memory([msg, ai_msg])) assert result is True - model.ainvoke.assert_awaited_once() - assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "memory_agent"} + # aupdate_memory delegates to sync path — model.invoke, not ainvoke + model.invoke.assert_called_once() + model.ainvoke.assert_not_called() def test_correction_hint_injected_when_detected(self): updater = MemoryUpdater() @@ -622,7 +622,7 @@ class TestUpdateMemoryStructuredResponse: result = updater.update_memory([msg, ai_msg], correction_detected=True) assert result is True - prompt = model.ainvoke.await_args.args[0] + prompt = model.invoke.call_args.args[0] assert "Explicit correction signals were detected" in prompt def test_correction_hint_empty_when_not_detected(self): @@ -647,7 +647,7 @@ class TestUpdateMemoryStructuredResponse: result = updater.update_memory([msg, ai_msg], correction_detected=False) assert result is True - prompt = model.ainvoke.await_args.args[0] + prompt = model.invoke.call_args.args[0] assert "Explicit correction signals were detected" not in prompt def test_sync_update_memory_wrapper_works_in_running_loop(self): @@ -675,9 +675,9 @@ class TestUpdateMemoryStructuredResponse: result = asyncio.run(run_in_loop()) assert result is True - model.ainvoke.assert_awaited_once() + model.invoke.assert_called_once() - def test_sync_update_memory_returns_false_when_bridge_submit_fails(self): + def test_sync_update_memory_returns_false_when_executor_down(self): updater = MemoryUpdater() with ( @@ -702,33 +702,67 @@ class TestUpdateMemoryStructuredResponse: assert result is False -class TestRunAsyncUpdateSync: - def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self): - class CloseableAwaitable: - def __init__(self): - self.closed = False +class TestSyncUpdateIsolatesProviderClientPool: + """Regression tests for issue #2615. - def __await__(self): - pytest.fail("awaitable should not have been awaited") - yield + The sync ``update_memory`` path must use ``model.invoke()`` (sync HTTP) + and never touch the async provider client pool shared with the lead agent. + """ - def close(self): - self.closed = True + def test_sync_update_uses_invoke_not_ainvoke(self): + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = MagicMock() + response = MagicMock() + response.content = valid_json + model.invoke = MagicMock(return_value=response) + model.ainvoke = AsyncMock(return_value=response) - awaitable = CloseableAwaitable() - - with patch( - "deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit", - side_effect=RuntimeError("executor down"), + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), ): + msg = MagicMock() + msg.type = "human" + msg.content = "Hello" + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Hi" + ai_msg.tool_calls = [] + result = updater.update_memory([msg, ai_msg]) - async def run_in_loop(): - return _run_async_update_sync(awaitable) + assert result is True + model.invoke.assert_called_once() + model.ainvoke.assert_not_called() - result = asyncio.run(run_in_loop()) + def test_no_event_loop_created_during_sync_update(self): + """Sync update must not create or destroy any event loop.""" + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = MagicMock() + response = MagicMock() + response.content = valid_json + model.invoke = MagicMock(return_value=response) - assert result is False - assert awaitable.closed is True + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), + patch("asyncio.run", side_effect=AssertionError("asyncio.run must not be called from sync update path")), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "Hello" + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Hi" + ai_msg.tool_calls = [] + result = updater.update_memory([msg, ai_msg]) + + assert result is True class TestFactDeduplicationCaseInsensitive: @@ -805,6 +839,7 @@ class TestReinforcementHint: response = MagicMock() response.content = f"```json\n{json_response}\n```" model.ainvoke = AsyncMock(return_value=response) + model.invoke = MagicMock(return_value=response) return model def test_reinforcement_hint_injected_when_detected(self): @@ -829,7 +864,7 @@ class TestReinforcementHint: result = updater.update_memory([msg, ai_msg], reinforcement_detected=True) assert result is True - prompt = model.ainvoke.await_args.args[0] + prompt = model.invoke.call_args.args[0] assert "Positive reinforcement signals were detected" in prompt def test_reinforcement_hint_absent_when_not_detected(self): @@ -854,7 +889,7 @@ class TestReinforcementHint: result = updater.update_memory([msg, ai_msg], reinforcement_detected=False) assert result is True - prompt = model.ainvoke.await_args.args[0] + prompt = model.invoke.call_args.args[0] assert "Positive reinforcement signals were detected" not in prompt def test_both_hints_present_when_both_detected(self): @@ -879,7 +914,7 @@ class TestReinforcementHint: result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True) assert result is True - prompt = model.ainvoke.await_args.args[0] + prompt = model.invoke.call_args.args[0] assert "Explicit correction signals were detected" in prompt assert "Positive reinforcement signals were detected" in prompt @@ -908,11 +943,11 @@ class TestFinalizeCacheIsolation: ) mock_response = MagicMock() mock_response.content = new_fact_json - mock_model = AsyncMock() - mock_model.ainvoke = AsyncMock(return_value=mock_response) + mock_model = MagicMock() + mock_model.invoke = MagicMock(return_value=mock_response) saved_objects: list[dict] = [] - save_mock = MagicMock(side_effect=lambda m, a=None: saved_objects.append(m) or False) # always fails + save_mock = MagicMock(side_effect=lambda m, a=None, **_: saved_objects.append(m) or False) # always fails with ( patch.object(updater, "_get_model", return_value=mock_model), @@ -929,6 +964,85 @@ class TestFinalizeCacheIsolation: ai_msg.tool_calls = [] updater.update_memory([msg, ai_msg], thread_id="t1") + # save_mock must have been exercised — otherwise the deepcopy-on-save-failure path isn't covered + save_mock.assert_called_once() + assert len(saved_objects) == 1, "save must have been called with the updated memory object" + # original_memory must not have been mutated — deepcopy isolates the mutation assert len(original_memory["facts"]) == 1, "original_memory must not be mutated by _apply_updates" assert original_memory["facts"][0]["content"] == "original" + + +class TestUserIdForwarding: + """Regression: user_id must flow through the entire sync update path. + + When MemoryUpdateQueue captures context.user_id and passes it into + update_memory(..., user_id=context.user_id), the sync path must forward + it into _prepare_update_prompt → get_memory_data() and + _finalize_update → save(), so per-user memory isolation is maintained. + """ + + @staticmethod + def _make_mock_model(content): + model = MagicMock() + response = MagicMock() + response.content = content + model.invoke = MagicMock(return_value=response) + return model + + def test_sync_update_forwards_user_id_to_load_and_save(self): + """update_memory must pass user_id to get_memory_data and storage.save.""" + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) + mock_storage = MagicMock() + mock_storage.save = MagicMock(return_value=True) + + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()) as mock_load, + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "Hello" + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Hi" + ai_msg.tool_calls = [] + result = updater.update_memory([msg, ai_msg], user_id="user-42") + + assert result is True + mock_load.assert_called_once_with(None, user_id="user-42") + mock_storage.save.assert_called_once() + save_call = mock_storage.save.call_args + assert save_call.kwargs.get("user_id") == "user-42" or (len(save_call.args) > 2 and save_call.args[2] == "user-42") + + def test_async_update_forwards_user_id_to_load_and_save(self): + """aupdate_memory must pass user_id through to the sync delegate.""" + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) + mock_storage = MagicMock() + mock_storage.save = MagicMock(return_value=True) + + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()) as mock_load, + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "Hello" + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Hi" + ai_msg.tool_calls = [] + result = asyncio.run(updater.aupdate_memory([msg, ai_msg], user_id="user-99")) + + assert result is True + mock_load.assert_called_once_with(None, user_id="user-99") + save_call = mock_storage.save.call_args + assert save_call.kwargs.get("user_id") == "user-99" or (len(save_call.args) > 2 and save_call.args[2] == "user-99")