fix(runtime): protect sync singleton init and reset (#3413)

* fix(runtime): protect sync singleton init/reset with threading.Lock

* fix(runtime): serialize sync singleton init and reset

* make format

* test(runtime): assert store reset creates new singleton

* Apply suggestions from code review

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix(runtime): load config outside singleton locks

* fix(runtime): share checkpointer config loading helper

---------

Co-authored-by: GODDiao <diaoshengjia@gmail.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Nan Gao 2026-06-08 02:38:36 +02:00 committed by GitHub
parent 3b4c9ff733
commit f725a963d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 313 additions and 56 deletions

View File

@ -41,6 +41,20 @@ def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
_checkpointer_config = config
def ensure_config_loaded() -> None:
    """Trigger a lazy app-config load when nothing has been configured yet.

    Returns without doing anything when a checkpointer config is already set
    or when the app-config module already holds a loaded config. Otherwise
    ``get_app_config()`` is called once; a ``FileNotFoundError`` raised by that
    call is swallowed so callers fall through with the config still unset.
    """
    # Imported at call time so the current module attributes are read on each call.
    from deerflow.config.app_config import _app_config, get_app_config

    already_initialized = get_checkpointer_config() is not None or _app_config is not None
    if not already_initialized:
        try:
            get_app_config()
        except FileNotFoundError:
            # No config file available: leave the checkpointer config untouched.
            pass
def load_checkpointer_config_from_dict(config_dict: dict | None) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config

View File

@ -21,12 +21,13 @@ from __future__ import annotations
import contextlib
import logging
import threading
from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.config.checkpointer_config import CheckpointerConfig, ensure_config_loaded
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@ -100,6 +101,7 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
_checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive
_checkpointer_lock = threading.Lock()
def get_checkpointer() -> Checkpointer:
@ -116,34 +118,29 @@ def get_checkpointer() -> Checkpointer:
if _checkpointer is not None:
return _checkpointer
# Ensure app config is loaded before checking checkpointer config
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
# but hasn't been loaded yet
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
# Config loading can reset both persistence singletons. Keep it outside
# this provider lock to avoid cross-provider lock-order inversion.
ensure_config_loaded()
config = get_checkpointer_config()
with _checkpointer_lock:
if _checkpointer is not None:
return _checkpointer
from deerflow.config.checkpointer_config import get_checkpointer_config
if config is None and _app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
_checkpointer = InMemorySaver()
return _checkpointer
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
_checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
_checkpointer = InMemorySaver()
return _checkpointer
checkpointer_ctx = _sync_checkpointer_cm(config)
checkpointer = checkpointer_ctx.__enter__()
_checkpointer_ctx = checkpointer_ctx
_checkpointer = checkpointer
return _checkpointer
@ -155,13 +152,14 @@ def reset_checkpointer() -> None:
Useful in tests or after a configuration change.
"""
global _checkpointer, _checkpointer_ctx
if _checkpointer_ctx is not None:
try:
_checkpointer_ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during checkpointer cleanup", exc_info=True)
_checkpointer_ctx = None
_checkpointer = None
with _checkpointer_lock:
if _checkpointer_ctx is not None:
try:
_checkpointer_ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during checkpointer cleanup", exc_info=True)
_checkpointer_ctx = None
_checkpointer = None
# ---------------------------------------------------------------------------

View File

@ -22,11 +22,13 @@ from __future__ import annotations
import contextlib
import logging
import threading
from collections.abc import Iterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.checkpointer_config import ensure_config_loaded
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@ -100,6 +102,7 @@ def _sync_store_cm(config) -> Iterator[BaseStore]:
_store: BaseStore | None = None
_store_ctx = None # open context manager keeping the connection alive
_store_lock = threading.Lock()
def get_store() -> BaseStore:
@ -117,29 +120,29 @@ def get_store() -> BaseStore:
if _store is not None:
return _store
# Lazily load app config, mirroring the checkpointer singleton pattern so
# that tests that set the global checkpointer config explicitly remain isolated.
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
# Config loading can reset both persistence singletons. Keep it outside
# this provider lock to avoid cross-provider lock-order inversion.
ensure_config_loaded()
config = get_checkpointer_config()
with _store_lock:
if _store is not None:
return _store
from deerflow.config.checkpointer_config import get_checkpointer_config
if config is None and _app_config is None:
try:
get_app_config()
except FileNotFoundError:
pass
config = get_checkpointer_config()
if config is None:
from langgraph.store.memory import InMemoryStore
if config is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
_store = InMemoryStore()
return _store
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
_store = InMemoryStore()
return _store
_store_ctx = _sync_store_cm(config)
_store = _store_ctx.__enter__()
store_ctx = _sync_store_cm(config)
store = store_ctx.__enter__()
_store_ctx = store_ctx
_store = store
return _store
@ -150,13 +153,14 @@ def reset_store() -> None:
Useful in tests or after a configuration change.
"""
global _store, _store_ctx
if _store_ctx is not None:
try:
_store_ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during store cleanup", exc_info=True)
_store_ctx = None
_store = None
with _store_lock:
if _store_ctx is not None:
try:
_store_ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during store cleanup", exc_info=True)
_store_ctx = None
_store = None
# ---------------------------------------------------------------------------

View File

@ -2,7 +2,9 @@
import sys
import tomllib
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from threading import Barrier, Event, Lock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -10,12 +12,14 @@ import pytest
import deerflow.config.app_config as app_config_module
from deerflow.config.checkpointer_config import (
CheckpointerConfig,
ensure_config_loaded,
get_checkpointer_config,
load_checkpointer_config_from_dict,
set_checkpointer_config,
)
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL
from deerflow.runtime.store import get_store, reset_store
from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL
@ -25,10 +29,90 @@ def reset_state():
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer()
reset_store()
yield
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer()
reset_store()
class _BlockingSingletonContext:
def __init__(self, value: object, entered: Event, release: Event, stats: dict[str, object]):
self._value = value
self._entered = entered
self._release = release
self._stats = stats
def __enter__(self):
with self._stats["lock"]:
self._stats["enters"] += 1
self._entered.set()
assert self._release.wait(timeout=3), "timed out waiting to release singleton initialization"
return self._value
def __exit__(self, exc_type, exc, tb):
with self._stats["lock"]:
self._stats["exits"] += 1
return False
class _BlockingSingletonFactory:
def __init__(self):
self.value = object()
self.entered = Event()
self.release = Event()
self.stats = {"enters": 0, "exits": 0, "lock": Lock()}
def context_manager(self, _config):
return _BlockingSingletonContext(self.value, self.entered, self.release, self.stats)
def enter_count(self) -> int:
with self.stats["lock"]:
return self.stats["enters"]
def exit_count(self) -> int:
with self.stats["lock"]:
return self.stats["exits"]
class _TrackingLock:
def __init__(self):
self._lock = Lock()
self.acquired = Event()
def acquire(self, *args, **kwargs):
acquired = self._lock.acquire(*args, **kwargs)
if acquired:
self.acquired.set()
return acquired
def release(self):
self._lock.release()
def __enter__(self):
self.acquire()
return self
def __exit__(self, exc_type, exc, tb):
self.release()
return False
def locked(self) -> bool:
return self._lock.locked()
def _call_getter_concurrently(getter, workers: int = 8) -> list[object]:
ready = Barrier(workers + 1)
def worker():
ready.wait(timeout=3)
return getter()
with ThreadPoolExecutor(max_workers=workers) as executor:
futures = [executor.submit(worker) for _ in range(workers)]
ready.wait(timeout=3)
return [future.result(timeout=3) for future in futures]
# ---------------------------------------------------------------------------
@ -67,6 +151,26 @@ class TestCheckpointerConfig:
set_checkpointer_config(None)
assert get_checkpointer_config() is None
def test_ensure_config_loaded_loads_app_config_when_uninitialized(self):
    """``ensure_config_loaded`` calls ``get_app_config`` once when nothing is configured.

    Assumes the surrounding fixture has cleared the app config and the
    checkpointer config before this test runs.
    """

    def install_memory_config():
        # Stand-in loader: its only effect is to register a memory checkpointer config.
        load_checkpointer_config_from_dict({"type": "memory"})

    with patch("deerflow.config.app_config.get_app_config", side_effect=install_memory_config) as app_config_loader:
        ensure_config_loaded()
        app_config_loader.assert_called_once()

    loaded = get_checkpointer_config()
    assert loaded is not None
    assert loaded.type == "memory"
def test_ensure_config_loaded_skips_explicit_config(self):
    """``ensure_config_loaded`` must not call ``get_app_config`` when a checkpointer config is already set."""
    explicit_config = {"type": "memory"}
    load_checkpointer_config_from_dict(explicit_config)

    with patch("deerflow.config.app_config.get_app_config") as app_config_loader:
        ensure_config_loaded()
        app_config_loader.assert_not_called()
def test_invalid_type_raises(self):
with pytest.raises(Exception):
load_checkpointer_config_from_dict({"type": "unknown"})
@ -118,7 +222,7 @@ class TestGetCheckpointer:
"""get_checkpointer should return InMemorySaver when not configured."""
from langgraph.checkpoint.memory import InMemorySaver
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
with patch("deerflow.config.app_config.get_app_config", side_effect=FileNotFoundError):
cp = get_checkpointer()
assert cp is not None
assert isinstance(cp, InMemorySaver)
@ -287,6 +391,143 @@ class TestGetCheckpointer:
mock_saver_instance.setup.assert_called_once()
class TestSyncSingletonThreadSafety:
    """Concurrency tests for the sync ``get_checkpointer`` / ``get_store`` singletons and their resets."""

    def test_store_reset_clears_singleton(self):
        """``get_store()`` returns a different object after ``reset_store()``."""
        load_checkpointer_config_from_dict({"type": "memory"})
        store1 = get_store()
        reset_store()
        store2 = get_store()
        assert store1 is not store2

    def test_concurrent_checkpointer_getter_creates_one_instance(self):
        """Concurrent ``get_checkpointer()`` calls all get one value and enter the context once."""
        load_checkpointer_config_from_dict({"type": "memory"})
        factory = _BlockingSingletonFactory()
        with patch("deerflow.runtime.checkpointer.provider._sync_checkpointer_cm", side_effect=factory.context_manager):
            # Run the concurrent callers on a side thread so this thread can
            # control when the blocked initialization is released.
            futures_started = ThreadPoolExecutor(max_workers=1)
            try:
                result_future = futures_started.submit(_call_getter_concurrently, get_checkpointer)
                # One caller has reached the blocking __enter__.
                assert factory.entered.wait(timeout=3)
                # release is not set yet, so this wait simply times out after
                # 50 ms; presumably it gives the other callers time to arrive.
                factory.release.wait(timeout=0.05)
                factory.release.set()
                results = result_future.result(timeout=3)
            finally:
                futures_started.shutdown(wait=True)
        assert all(result is factory.value for result in results)
        assert factory.enter_count() == 1

    def test_concurrent_store_getter_creates_one_instance(self):
        """Concurrent ``get_store()`` calls all get one value and enter the context once."""
        load_checkpointer_config_from_dict({"type": "memory"})
        factory = _BlockingSingletonFactory()
        with patch("deerflow.runtime.store.provider._sync_store_cm", side_effect=factory.context_manager):
            # Run the concurrent callers on a side thread so this thread can
            # control when the blocked initialization is released.
            futures_started = ThreadPoolExecutor(max_workers=1)
            try:
                result_future = futures_started.submit(_call_getter_concurrently, get_store)
                # One caller has reached the blocking __enter__.
                assert factory.entered.wait(timeout=3)
                # release is not set yet, so this wait simply times out after
                # 50 ms; presumably it gives the other callers time to arrive.
                factory.release.wait(timeout=0.05)
                factory.release.set()
                results = result_future.result(timeout=3)
            finally:
                futures_started.shutdown(wait=True)
        assert all(result is factory.value for result in results)
        assert factory.enter_count() == 1

    def test_checkpointer_loads_config_outside_singleton_lock(self):
        """``ensure_config_loaded`` is invoked while the checkpointer lock is not held."""
        tracking_lock = _TrackingLock()

        def fake_ensure_config_loaded():
            # Fails the test if config loading happens while the lock is held.
            assert not tracking_lock.locked()
            load_checkpointer_config_from_dict({"type": "memory"})

        with (
            patch("deerflow.runtime.checkpointer.provider._checkpointer_lock", tracking_lock),
            patch("deerflow.runtime.checkpointer.provider.ensure_config_loaded", side_effect=fake_ensure_config_loaded),
        ):
            checkpointer = get_checkpointer()
        assert checkpointer is not None
        # The getter did go through the (patched) lock at some point.
        assert tracking_lock.acquired.is_set()

    def test_store_loads_config_outside_singleton_lock(self):
        """``ensure_config_loaded`` is invoked while the store lock is not held."""
        tracking_lock = _TrackingLock()

        def fake_ensure_config_loaded():
            # Fails the test if config loading happens while the lock is held.
            assert not tracking_lock.locked()
            load_checkpointer_config_from_dict({"type": "memory"})

        with (
            patch("deerflow.runtime.store.provider._store_lock", tracking_lock),
            patch("deerflow.runtime.store.provider.ensure_config_loaded", side_effect=fake_ensure_config_loaded),
        ):
            store = get_store()
        assert store is not None
        # The getter did go through the (patched) lock at some point.
        assert tracking_lock.acquired.is_set()

    def test_checkpointer_reset_waits_for_initialization(self):
        """``reset_checkpointer()`` does not finish while ``get_checkpointer()`` is still initializing."""
        load_checkpointer_config_from_dict({"type": "memory"})
        factory = _BlockingSingletonFactory()
        with (
            patch("deerflow.runtime.checkpointer.provider._sync_checkpointer_cm", side_effect=factory.context_manager),
            ThreadPoolExecutor(max_workers=2) as executor,
        ):
            get_future = executor.submit(get_checkpointer)
            # The getter is now parked inside the blocking __enter__.
            assert factory.entered.wait(timeout=3)
            reset_started = Event()

            def reset_worker():
                reset_started.set()
                reset_checkpointer()

            reset_future = executor.submit(reset_worker)
            assert reset_started.wait(timeout=3)
            # release is still unset, so this is a 50 ms pause before checking
            # that the reset has neither returned nor exited the context.
            factory.release.wait(timeout=0.05)
            assert not reset_future.done()
            assert factory.exit_count() == 0
            factory.release.set()
            assert get_future.result(timeout=3) is factory.value
            reset_future.result(timeout=3)
            # After initialization completed, the reset exited the context once.
            assert factory.exit_count() == 1

    def test_store_reset_waits_for_initialization(self):
        """``reset_store()`` does not finish while ``get_store()`` is still initializing."""
        load_checkpointer_config_from_dict({"type": "memory"})
        factory = _BlockingSingletonFactory()
        with (
            patch("deerflow.runtime.store.provider._sync_store_cm", side_effect=factory.context_manager),
            ThreadPoolExecutor(max_workers=2) as executor,
        ):
            get_future = executor.submit(get_store)
            # The getter is now parked inside the blocking __enter__.
            assert factory.entered.wait(timeout=3)
            reset_started = Event()

            def reset_worker():
                reset_started.set()
                reset_store()

            reset_future = executor.submit(reset_worker)
            assert reset_started.wait(timeout=3)
            # release is still unset, so this is a 50 ms pause before checking
            # that the reset has neither returned nor exited the context.
            factory.release.wait(timeout=0.05)
            assert not reset_future.done()
            assert factory.exit_count() == 0
            factory.release.set()
            assert get_future.result(timeout=3) is factory.value
            reset_future.result(timeout=3)
            # After initialization completed, the reset exited the context once.
            assert factory.exit_count() == 1
class TestAsyncCheckpointer:
@pytest.mark.anyio
async def test_sqlite_creates_parent_dir_via_to_thread(self):