From f725a963d5b6a95a9e01a9a44c23009d8325c8ad Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Mon, 8 Jun 2026 02:38:36 +0200 Subject: [PATCH] fix(runtime): protect sync singleton init and reset (#3413) * fix(runtime): protect sync singleton init/reset with threading.Lock * fix(runtime): serialize sync singleton init and reset * make format * test(runtime): assert store reset creates new singleton * Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix(runtime): load config outside singleton locks * fix(runtime): share checkpointer config loading helper --------- Co-authored-by: GODDiao Co-authored-by: Willem Jiang Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../deerflow/config/checkpointer_config.py | 14 + .../deerflow/runtime/checkpointer/provider.py | 60 +++-- .../deerflow/runtime/store/provider.py | 52 ++-- backend/tests/test_checkpointer.py | 243 +++++++++++++++++- 4 files changed, 313 insertions(+), 56 deletions(-) diff --git a/backend/packages/harness/deerflow/config/checkpointer_config.py b/backend/packages/harness/deerflow/config/checkpointer_config.py index 963c439c1..ce10f3434 100644 --- a/backend/packages/harness/deerflow/config/checkpointer_config.py +++ b/backend/packages/harness/deerflow/config/checkpointer_config.py @@ -41,6 +41,20 @@ def set_checkpointer_config(config: CheckpointerConfig | None) -> None: _checkpointer_config = config +def ensure_config_loaded() -> None: + """Lazily load app config when checkpointer config has not been initialized.""" + from deerflow.config.app_config import _app_config, get_app_config + + config = get_checkpointer_config() + if config is not None or _app_config is not None: + return + + try: + get_app_config() + except FileNotFoundError: + pass + + def load_checkpointer_config_from_dict(config_dict: dict | None) -> None: """Load checkpointer configuration from a dictionary.""" global _checkpointer_config diff --git a/backend/packages/harness/deerflow/runtime/checkpointer/provider.py b/backend/packages/harness/deerflow/runtime/checkpointer/provider.py index 39a4f272e..226454eb3 100644 --- a/backend/packages/harness/deerflow/runtime/checkpointer/provider.py +++ b/backend/packages/harness/deerflow/runtime/checkpointer/provider.py @@ -21,12 +21,13 @@ from __future__ import annotations import contextlib import logging +import threading from collections.abc import Iterator from langgraph.types import Checkpointer from deerflow.config.app_config import get_app_config -from deerflow.config.checkpointer_config import CheckpointerConfig +from deerflow.config.checkpointer_config import CheckpointerConfig, ensure_config_loaded from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str logger = logging.getLogger(__name__) @@ -100,6 +101,7 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]: _checkpointer: Checkpointer | None = None _checkpointer_ctx = None # open context manager keeping the connection alive +_checkpointer_lock = threading.Lock() def get_checkpointer() -> Checkpointer: @@ -116,34 +118,29 @@ def get_checkpointer() -> Checkpointer: if _checkpointer is not None: return _checkpointer - # Ensure app config is loaded before checking checkpointer config - # This prevents returning InMemorySaver when config.yaml actually has a checkpointer section - # but hasn't been loaded yet - from deerflow.config.app_config import _app_config - from deerflow.config.checkpointer_config import get_checkpointer_config + # Config loading can reset both persistence singletons. Keep it outside + # this provider lock to avoid cross-provider lock-order inversion. + ensure_config_loaded() - config = get_checkpointer_config() + with _checkpointer_lock: + if _checkpointer is not None: + return _checkpointer + + from deerflow.config.checkpointer_config import get_checkpointer_config - if config is None and _app_config is None: - # Only load app config lazily when neither the app config nor an explicit - # checkpointer config has been initialized yet. This keeps tests that - # intentionally set the global checkpointer config isolated from any - # ambient config.yaml on disk. - try: - get_app_config() - except FileNotFoundError: - # In test environments without config.yaml, this is expected. - pass config = get_checkpointer_config() - if config is None: - from langgraph.checkpoint.memory import InMemorySaver - logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)") - _checkpointer = InMemorySaver() - return _checkpointer + if config is None: + from langgraph.checkpoint.memory import InMemorySaver - _checkpointer_ctx = _sync_checkpointer_cm(config) - _checkpointer = _checkpointer_ctx.__enter__() + logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)") + _checkpointer = InMemorySaver() + return _checkpointer + + checkpointer_ctx = _sync_checkpointer_cm(config) + checkpointer = checkpointer_ctx.__enter__() + _checkpointer_ctx = checkpointer_ctx + _checkpointer = checkpointer return _checkpointer @@ -155,13 +152,14 @@ def reset_checkpointer() -> None: Useful in tests or after a configuration change. """ global _checkpointer, _checkpointer_ctx - if _checkpointer_ctx is not None: - try: - _checkpointer_ctx.__exit__(None, None, None) - except Exception: - logger.warning("Error during checkpointer cleanup", exc_info=True) - _checkpointer_ctx = None - _checkpointer = None + with _checkpointer_lock: + if _checkpointer_ctx is not None: + try: + _checkpointer_ctx.__exit__(None, None, None) + except Exception: + logger.warning("Error during checkpointer cleanup", exc_info=True) + _checkpointer_ctx = None + _checkpointer = None # --------------------------------------------------------------------------- diff --git a/backend/packages/harness/deerflow/runtime/store/provider.py b/backend/packages/harness/deerflow/runtime/store/provider.py index ecf597fe3..7e2ae563d 100644 --- a/backend/packages/harness/deerflow/runtime/store/provider.py +++ b/backend/packages/harness/deerflow/runtime/store/provider.py @@ -22,11 +22,13 @@ from __future__ import annotations import contextlib import logging +import threading from collections.abc import Iterator from langgraph.store.base import BaseStore from deerflow.config.app_config import get_app_config +from deerflow.config.checkpointer_config import ensure_config_loaded from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str logger = logging.getLogger(__name__) @@ -100,6 +102,7 @@ def _sync_store_cm(config) -> Iterator[BaseStore]: _store: BaseStore | None = None _store_ctx = None # open context manager keeping the connection alive +_store_lock = threading.Lock() def get_store() -> BaseStore: @@ -117,29 +120,29 @@ def get_store() -> BaseStore: if _store is not None: return _store - # Lazily load app config, mirroring the checkpointer singleton pattern so - # that tests that set the global checkpointer config explicitly remain isolated. - from deerflow.config.app_config import _app_config - from deerflow.config.checkpointer_config import get_checkpointer_config + # Config loading can reset both persistence singletons. Keep it outside + # this provider lock to avoid cross-provider lock-order inversion. + ensure_config_loaded() - config = get_checkpointer_config() + with _store_lock: + if _store is not None: + return _store + + from deerflow.config.checkpointer_config import get_checkpointer_config - if config is None and _app_config is None: - try: - get_app_config() - except FileNotFoundError: - pass config = get_checkpointer_config() - if config is None: - from langgraph.store.memory import InMemoryStore + if config is None: + from langgraph.store.memory import InMemoryStore - logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.") - _store = InMemoryStore() - return _store + logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.") + _store = InMemoryStore() + return _store - _store_ctx = _sync_store_cm(config) - _store = _store_ctx.__enter__() + store_ctx = _sync_store_cm(config) + store = store_ctx.__enter__() + _store_ctx = store_ctx + _store = store return _store @@ -150,13 +153,14 @@ def reset_store() -> None: Useful in tests or after a configuration change. """ global _store, _store_ctx - if _store_ctx is not None: - try: - _store_ctx.__exit__(None, None, None) - except Exception: - logger.warning("Error during store cleanup", exc_info=True) - _store_ctx = None - _store = None + with _store_lock: + if _store_ctx is not None: + try: + _store_ctx.__exit__(None, None, None) + except Exception: + logger.warning("Error during store cleanup", exc_info=True) + _store_ctx = None + _store = None # --------------------------------------------------------------------------- diff --git a/backend/tests/test_checkpointer.py b/backend/tests/test_checkpointer.py index 751d3a74e..e86be8644 100644 --- a/backend/tests/test_checkpointer.py +++ b/backend/tests/test_checkpointer.py @@ -2,7 +2,9 @@ import sys import tomllib +from concurrent.futures import ThreadPoolExecutor from pathlib import Path +from threading import Barrier, Event, Lock from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -10,12 +12,14 @@ import pytest import deerflow.config.app_config as app_config_module from deerflow.config.checkpointer_config import ( CheckpointerConfig, + ensure_config_loaded, get_checkpointer_config, load_checkpointer_config_from_dict, set_checkpointer_config, ) from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL +from deerflow.runtime.store import get_store, reset_store from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL @@ -25,10 +29,90 @@ def reset_state(): app_config_module._app_config = None set_checkpointer_config(None) reset_checkpointer() + reset_store() yield app_config_module._app_config = None set_checkpointer_config(None) reset_checkpointer() + reset_store() + + +class _BlockingSingletonContext: + def __init__(self, value: object, entered: Event, release: Event, stats: dict[str, object]): + self._value = value + self._entered = entered + self._release = release + self._stats = stats + + def __enter__(self): + with self._stats["lock"]: + self._stats["enters"] += 1 + self._entered.set() + assert self._release.wait(timeout=3), "timed out waiting to release singleton initialization" + return self._value + + def __exit__(self, exc_type, exc, tb): + with self._stats["lock"]: + self._stats["exits"] += 1 + return False + + +class _BlockingSingletonFactory: + def __init__(self): + self.value = object() + self.entered = Event() + self.release = Event() + self.stats = {"enters": 0, "exits": 0, "lock": Lock()} + + def context_manager(self, _config): + return _BlockingSingletonContext(self.value, self.entered, self.release, self.stats) + + def enter_count(self) -> int: + with self.stats["lock"]: + return self.stats["enters"] + + def exit_count(self) -> int: + with self.stats["lock"]: + return self.stats["exits"] + + +class _TrackingLock: + def __init__(self): + self._lock = Lock() + self.acquired = Event() + + def acquire(self, *args, **kwargs): + acquired = self._lock.acquire(*args, **kwargs) + if acquired: + self.acquired.set() + return acquired + + def release(self): + self._lock.release() + + def __enter__(self): + self.acquire() + return self + + def __exit__(self, exc_type, exc, tb): + self.release() + return False + + def locked(self) -> bool: + return self._lock.locked() + + +def _call_getter_concurrently(getter, workers: int = 8) -> list[object]: + ready = Barrier(workers + 1) + + def worker(): + ready.wait(timeout=3) + return getter() + + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = [executor.submit(worker) for _ in range(workers)] + ready.wait(timeout=3) + return [future.result(timeout=3) for future in futures] # --------------------------------------------------------------------------- @@ -67,6 +151,26 @@ class TestCheckpointerConfig: set_checkpointer_config(None) assert get_checkpointer_config() is None + def test_ensure_config_loaded_loads_app_config_when_uninitialized(self): + def fake_get_app_config(): + load_checkpointer_config_from_dict({"type": "memory"}) + + with patch("deerflow.config.app_config.get_app_config", side_effect=fake_get_app_config) as mock_get_app_config: + ensure_config_loaded() + + mock_get_app_config.assert_called_once() + config = get_checkpointer_config() + assert config is not None + assert config.type == "memory" + + def test_ensure_config_loaded_skips_explicit_config(self): + load_checkpointer_config_from_dict({"type": "memory"}) + + with patch("deerflow.config.app_config.get_app_config") as mock_get_app_config: + ensure_config_loaded() + + mock_get_app_config.assert_not_called() + def test_invalid_type_raises(self): with pytest.raises(Exception): load_checkpointer_config_from_dict({"type": "unknown"}) @@ -118,7 +222,7 @@ class TestGetCheckpointer: """get_checkpointer should return InMemorySaver when not configured.""" from langgraph.checkpoint.memory import InMemorySaver - with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError): + with patch("deerflow.config.app_config.get_app_config", side_effect=FileNotFoundError): cp = get_checkpointer() assert cp is not None assert isinstance(cp, InMemorySaver) @@ -287,6 +391,143 @@ class TestGetCheckpointer: mock_saver_instance.setup.assert_called_once() +class TestSyncSingletonThreadSafety: + def test_store_reset_clears_singleton(self): + load_checkpointer_config_from_dict({"type": "memory"}) + store1 = get_store() + reset_store() + store2 = get_store() + assert store1 is not store2 + + def test_concurrent_checkpointer_getter_creates_one_instance(self): + load_checkpointer_config_from_dict({"type": "memory"}) + factory = _BlockingSingletonFactory() + + with patch("deerflow.runtime.checkpointer.provider._sync_checkpointer_cm", side_effect=factory.context_manager): + futures_started = ThreadPoolExecutor(max_workers=1) + try: + result_future = futures_started.submit(_call_getter_concurrently, get_checkpointer) + assert factory.entered.wait(timeout=3) + factory.release.wait(timeout=0.05) + factory.release.set() + results = result_future.result(timeout=3) + finally: + futures_started.shutdown(wait=True) + + assert all(result is factory.value for result in results) + assert factory.enter_count() == 1 + + def test_concurrent_store_getter_creates_one_instance(self): + load_checkpointer_config_from_dict({"type": "memory"}) + factory = _BlockingSingletonFactory() + + with patch("deerflow.runtime.store.provider._sync_store_cm", side_effect=factory.context_manager): + futures_started = ThreadPoolExecutor(max_workers=1) + try: + result_future = futures_started.submit(_call_getter_concurrently, get_store) + assert factory.entered.wait(timeout=3) + factory.release.wait(timeout=0.05) + factory.release.set() + results = result_future.result(timeout=3) + finally: + futures_started.shutdown(wait=True) + + assert all(result is factory.value for result in results) + assert factory.enter_count() == 1 + + def test_checkpointer_loads_config_outside_singleton_lock(self): + tracking_lock = _TrackingLock() + + def fake_ensure_config_loaded(): + assert not tracking_lock.locked() + load_checkpointer_config_from_dict({"type": "memory"}) + + with ( + patch("deerflow.runtime.checkpointer.provider._checkpointer_lock", tracking_lock), + patch("deerflow.runtime.checkpointer.provider.ensure_config_loaded", side_effect=fake_ensure_config_loaded), + ): + checkpointer = get_checkpointer() + + assert checkpointer is not None + assert tracking_lock.acquired.is_set() + + def test_store_loads_config_outside_singleton_lock(self): + tracking_lock = _TrackingLock() + + def fake_ensure_config_loaded(): + assert not tracking_lock.locked() + load_checkpointer_config_from_dict({"type": "memory"}) + + with ( + patch("deerflow.runtime.store.provider._store_lock", tracking_lock), + patch("deerflow.runtime.store.provider.ensure_config_loaded", side_effect=fake_ensure_config_loaded), + ): + store = get_store() + + assert store is not None + assert tracking_lock.acquired.is_set() + + def test_checkpointer_reset_waits_for_initialization(self): + load_checkpointer_config_from_dict({"type": "memory"}) + factory = _BlockingSingletonFactory() + + with ( + patch("deerflow.runtime.checkpointer.provider._sync_checkpointer_cm", side_effect=factory.context_manager), + ThreadPoolExecutor(max_workers=2) as executor, + ): + get_future = executor.submit(get_checkpointer) + assert factory.entered.wait(timeout=3) + + reset_started = Event() + + def reset_worker(): + reset_started.set() + reset_checkpointer() + + reset_future = executor.submit(reset_worker) + assert reset_started.wait(timeout=3) + factory.release.wait(timeout=0.05) + + assert not reset_future.done() + assert factory.exit_count() == 0 + + factory.release.set() + assert get_future.result(timeout=3) is factory.value + reset_future.result(timeout=3) + + assert factory.exit_count() == 1 + + def test_store_reset_waits_for_initialization(self): + load_checkpointer_config_from_dict({"type": "memory"}) + factory = _BlockingSingletonFactory() + + with ( + patch("deerflow.runtime.store.provider._sync_store_cm", side_effect=factory.context_manager), + ThreadPoolExecutor(max_workers=2) as executor, + ): + get_future = executor.submit(get_store) + assert factory.entered.wait(timeout=3) + + reset_started = Event() + + def reset_worker(): + reset_started.set() + reset_store() + + reset_future = executor.submit(reset_worker) + assert reset_started.wait(timeout=3) + factory.release.wait(timeout=0.05) + + assert not reset_future.done() + assert factory.exit_count() == 0 + + factory.release.set() + assert get_future.result(timeout=3) is factory.value + reset_future.result(timeout=3) + + assert factory.exit_count() == 1 + + class TestAsyncCheckpointer: @pytest.mark.anyio async def test_sqlite_creates_parent_dir_via_to_thread(self):