mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-09 17:12:01 +00:00
fix(runtime): protect sync singleton init and reset (#3413)
* fix(runtime): protect sync singleton init/reset with threading.Lock * fix(runtime): serialize sync singleton init and reset * make format * test(runtime): assert store reset creates new singleton * Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix(runtime): load config outside singleton locks * fix(runtime): share checkpointer config loading helper --------- Co-authored-by: GODDiao <diaoshengjia@gmail.com> Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
3b4c9ff733
commit
f725a963d5
@ -41,6 +41,20 @@ def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
|
||||
_checkpointer_config = config
|
||||
|
||||
|
||||
def ensure_config_loaded() -> None:
    """Trigger a lazy app-config load when nothing has been initialized yet.

    Returns immediately if a checkpointer config is already set or if the
    app-config singleton already exists. Otherwise calls ``get_app_config()``
    and ignores a ``FileNotFoundError`` raised by it.
    """
    # Resolved at call time, not at module import time.
    from deerflow.config.app_config import _app_config, get_app_config

    if get_checkpointer_config() is not None:
        # A checkpointer config is already in place; nothing to load.
        return
    if _app_config is not None:
        # The app config singleton already exists; do not reload it.
        return

    try:
        get_app_config()
    except FileNotFoundError:
        # Best effort: the lookup found no file, so the config stays unset.
        return
|
||||
|
||||
|
||||
def load_checkpointer_config_from_dict(config_dict: dict | None) -> None:
|
||||
"""Load checkpointer configuration from a dictionary."""
|
||||
global _checkpointer_config
|
||||
|
||||
@ -21,12 +21,13 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig, ensure_config_loaded
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -100,6 +101,7 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
|
||||
|
||||
_checkpointer: Checkpointer | None = None
|
||||
_checkpointer_ctx = None # open context manager keeping the connection alive
|
||||
_checkpointer_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_checkpointer() -> Checkpointer:
|
||||
@ -116,34 +118,29 @@ def get_checkpointer() -> Checkpointer:
|
||||
if _checkpointer is not None:
|
||||
return _checkpointer
|
||||
|
||||
# Ensure app config is loaded before checking checkpointer config
|
||||
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
|
||||
# but hasn't been loaded yet
|
||||
from deerflow.config.app_config import _app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
# Config loading can reset both persistence singletons. Keep it outside
|
||||
# this provider lock to avoid cross-provider lock-order inversion.
|
||||
ensure_config_loaded()
|
||||
|
||||
config = get_checkpointer_config()
|
||||
with _checkpointer_lock:
|
||||
if _checkpointer is not None:
|
||||
return _checkpointer
|
||||
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
if config is None and _app_config is None:
|
||||
# Only load app config lazily when neither the app config nor an explicit
|
||||
# checkpointer config has been initialized yet. This keeps tests that
|
||||
# intentionally set the global checkpointer config isolated from any
|
||||
# ambient config.yaml on disk.
|
||||
try:
|
||||
get_app_config()
|
||||
except FileNotFoundError:
|
||||
# In test environments without config.yaml, this is expected.
|
||||
pass
|
||||
config = get_checkpointer_config()
|
||||
if config is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
_checkpointer = InMemorySaver()
|
||||
return _checkpointer
|
||||
if config is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
_checkpointer_ctx = _sync_checkpointer_cm(config)
|
||||
_checkpointer = _checkpointer_ctx.__enter__()
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
_checkpointer = InMemorySaver()
|
||||
return _checkpointer
|
||||
|
||||
checkpointer_ctx = _sync_checkpointer_cm(config)
|
||||
checkpointer = checkpointer_ctx.__enter__()
|
||||
_checkpointer_ctx = checkpointer_ctx
|
||||
_checkpointer = checkpointer
|
||||
|
||||
return _checkpointer
|
||||
|
||||
@ -155,13 +152,14 @@ def reset_checkpointer() -> None:
|
||||
Useful in tests or after a configuration change.
|
||||
"""
|
||||
global _checkpointer, _checkpointer_ctx
|
||||
if _checkpointer_ctx is not None:
|
||||
try:
|
||||
_checkpointer_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during checkpointer cleanup", exc_info=True)
|
||||
_checkpointer_ctx = None
|
||||
_checkpointer = None
|
||||
with _checkpointer_lock:
|
||||
if _checkpointer_ctx is not None:
|
||||
try:
|
||||
_checkpointer_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during checkpointer cleanup", exc_info=True)
|
||||
_checkpointer_ctx = None
|
||||
_checkpointer = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -22,11 +22,13 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.checkpointer_config import ensure_config_loaded
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -100,6 +102,7 @@ def _sync_store_cm(config) -> Iterator[BaseStore]:
|
||||
|
||||
_store: BaseStore | None = None
|
||||
_store_ctx = None # open context manager keeping the connection alive
|
||||
_store_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_store() -> BaseStore:
|
||||
@ -117,29 +120,29 @@ def get_store() -> BaseStore:
|
||||
if _store is not None:
|
||||
return _store
|
||||
|
||||
# Lazily load app config, mirroring the checkpointer singleton pattern so
|
||||
# that tests that set the global checkpointer config explicitly remain isolated.
|
||||
from deerflow.config.app_config import _app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
# Config loading can reset both persistence singletons. Keep it outside
|
||||
# this provider lock to avoid cross-provider lock-order inversion.
|
||||
ensure_config_loaded()
|
||||
|
||||
config = get_checkpointer_config()
|
||||
with _store_lock:
|
||||
if _store is not None:
|
||||
return _store
|
||||
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
if config is None and _app_config is None:
|
||||
try:
|
||||
get_app_config()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
if config is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
_store = InMemoryStore()
|
||||
return _store
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
_store = InMemoryStore()
|
||||
return _store
|
||||
|
||||
_store_ctx = _sync_store_cm(config)
|
||||
_store = _store_ctx.__enter__()
|
||||
store_ctx = _sync_store_cm(config)
|
||||
store = store_ctx.__enter__()
|
||||
_store_ctx = store_ctx
|
||||
_store = store
|
||||
return _store
|
||||
|
||||
|
||||
@ -150,13 +153,14 @@ def reset_store() -> None:
|
||||
Useful in tests or after a configuration change.
|
||||
"""
|
||||
global _store, _store_ctx
|
||||
if _store_ctx is not None:
|
||||
try:
|
||||
_store_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during store cleanup", exc_info=True)
|
||||
_store_ctx = None
|
||||
_store = None
|
||||
with _store_lock:
|
||||
if _store_ctx is not None:
|
||||
try:
|
||||
_store_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during store cleanup", exc_info=True)
|
||||
_store_ctx = None
|
||||
_store = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -2,7 +2,9 @@
|
||||
|
||||
import sys
|
||||
import tomllib
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from threading import Barrier, Event, Lock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -10,12 +12,14 @@ import pytest
|
||||
import deerflow.config.app_config as app_config_module
|
||||
from deerflow.config.checkpointer_config import (
|
||||
CheckpointerConfig,
|
||||
ensure_config_loaded,
|
||||
get_checkpointer_config,
|
||||
load_checkpointer_config_from_dict,
|
||||
set_checkpointer_config,
|
||||
)
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL
|
||||
from deerflow.runtime.store import get_store, reset_store
|
||||
from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL
|
||||
|
||||
|
||||
@ -25,10 +29,90 @@ def reset_state():
|
||||
app_config_module._app_config = None
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
reset_store()
|
||||
yield
|
||||
app_config_module._app_config = None
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
reset_store()
|
||||
|
||||
|
||||
class _BlockingSingletonContext:
    """Context manager test double whose ``__enter__`` blocks until released.

    On entry it increments ``stats["enters"]`` while holding ``stats["lock"]``,
    sets the ``entered`` event, then waits up to three seconds for the
    ``release`` event before returning ``value``. On exit it increments
    ``stats["exits"]`` under the same lock.
    """

    def __init__(self, value: object, entered: Event, release: Event, stats: dict[str, object]):
        self._value = value
        self._entered = entered
        self._release = release
        self._stats = stats

    def __enter__(self):
        """Record the entry, signal it, and block until ``release`` is set.

        Raises:
            AssertionError: if ``release`` is not set within three seconds.
        """
        with self._stats["lock"]:
            self._stats["enters"] += 1
        self._entered.set()

        # Keep the wait outside an ``assert`` statement: asserts are stripped
        # under ``python -O``, which would remove the blocking wait itself.
        released = self._release.wait(timeout=3)
        if not released:
            raise AssertionError("timed out waiting to release singleton initialization")
        return self._value

    def __exit__(self, exc_type, exc, tb):
        """Record the exit; returns False so exceptions are not suppressed."""
        with self._stats["lock"]:
            self._stats["exits"] += 1
        return False
|
||||
|
||||
|
||||
class _BlockingSingletonFactory:
    """Hands out blocking context managers that share one value and one set of counters.

    Every context manager produced by :meth:`context_manager` receives the
    same ``value``, the same ``entered`` / ``release`` events, and the same
    ``stats`` mapping, so the counts cover all of them together.
    """

    def __init__(self):
        self.value = object()
        self.entered = Event()
        self.release = Event()
        # The counters and the lock guarding them live in the same mapping.
        self.stats = dict(enters=0, exits=0, lock=Lock())

    def _read_stat(self, key: str) -> int:
        """Return one counter from ``stats`` while holding its lock."""
        with self.stats["lock"]:
            return self.stats[key]

    def context_manager(self, _config):
        """Build a new blocking context manager; the config argument is ignored."""
        return _BlockingSingletonContext(
            value=self.value,
            entered=self.entered,
            release=self.release,
            stats=self.stats,
        )

    def enter_count(self) -> int:
        """Return how many times ``__enter__`` has been recorded."""
        return self._read_stat("enters")

    def exit_count(self) -> int:
        """Return how many times ``__exit__`` has been recorded."""
        return self._read_stat("exits")
|
||||
|
||||
|
||||
class _TrackingLock:
    """Lock wrapper that remembers, via an event, that it was acquired at least once.

    Delegates to an internal ``threading.Lock`` and sets ``acquired`` after
    every successful acquisition. The event is never cleared here.
    """

    def __init__(self):
        self._lock = Lock()
        self.acquired = Event()

    def acquire(self, *args, **kwargs):
        """Acquire the wrapped lock, forwarding all arguments; return its result."""
        got_lock = self._lock.acquire(*args, **kwargs)
        if not got_lock:
            # A failed (non-blocking or timed-out) attempt leaves the event untouched.
            return got_lock
        self.acquired.set()
        return got_lock

    def release(self):
        """Release the wrapped lock."""
        self._lock.release()

    def locked(self) -> bool:
        """Report whether the wrapped lock is currently held."""
        return self._lock.locked()

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.release()
        # Returning False lets exceptions from the ``with`` body propagate.
        return False
|
||||
|
||||
|
||||
def _call_getter_concurrently(getter, workers: int = 8) -> list[object]:
    """Call ``getter`` from ``workers`` threads released together and collect the results.

    Each worker thread waits on a shared barrier before calling ``getter``;
    the submitting thread is the final party, so no worker proceeds until all
    have been submitted. Barrier waits and result retrieval each use a
    three-second timeout.

    Returns:
        The getter results, in submission order.
    """
    # One extra party for the thread that submits the work.
    start_gate = Barrier(workers + 1)

    def wait_then_get():
        start_gate.wait(timeout=3)
        return getter()

    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = []
        for _ in range(workers):
            pending.append(pool.submit(wait_then_get))

        start_gate.wait(timeout=3)

        collected = []
        for task in pending:
            collected.append(task.result(timeout=3))
        return collected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -67,6 +151,26 @@ class TestCheckpointerConfig:
|
||||
set_checkpointer_config(None)
|
||||
assert get_checkpointer_config() is None
|
||||
|
||||
def test_ensure_config_loaded_loads_app_config_when_uninitialized(self):
    """ensure_config_loaded should call get_app_config once when nothing is set."""

    def populate_checkpointer_config():
        # Stand-in for the real loader: it only installs a memory checkpointer config.
        load_checkpointer_config_from_dict({"type": "memory"})

    with patch(
        "deerflow.config.app_config.get_app_config",
        side_effect=populate_checkpointer_config,
    ) as loader_mock:
        ensure_config_loaded()

    loader_mock.assert_called_once()

    loaded = get_checkpointer_config()
    assert loaded is not None
    assert loaded.type == "memory"
|
||||
|
||||
def test_ensure_config_loaded_skips_explicit_config(self):
    """ensure_config_loaded should not call get_app_config when a config is already set."""
    # Install an explicit checkpointer config before the call under test.
    load_checkpointer_config_from_dict({"type": "memory"})

    loader_patch = patch("deerflow.config.app_config.get_app_config")
    with loader_patch as loader_mock:
        ensure_config_loaded()

    loader_mock.assert_not_called()
|
||||
|
||||
def test_invalid_type_raises(self):
|
||||
with pytest.raises(Exception):
|
||||
load_checkpointer_config_from_dict({"type": "unknown"})
|
||||
@ -118,7 +222,7 @@ class TestGetCheckpointer:
|
||||
"""get_checkpointer should return InMemorySaver when not configured."""
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||
with patch("deerflow.config.app_config.get_app_config", side_effect=FileNotFoundError):
|
||||
cp = get_checkpointer()
|
||||
assert cp is not None
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
@ -287,6 +391,143 @@ class TestGetCheckpointer:
|
||||
mock_saver_instance.setup.assert_called_once()
|
||||
|
||||
|
||||
class TestSyncSingletonThreadSafety:
    """Thread-safety checks for the synchronous checkpointer and store singletons.

    The checkpointer and store scenarios are identical apart from the patch
    targets and the getter/reset functions, so each scenario lives in one
    private helper that both public tests call.
    """

    @staticmethod
    def _assert_single_init_under_contention(cm_target: str, getter) -> None:
        """Race several ``getter`` calls and check the patched factory is entered once."""
        load_checkpointer_config_from_dict({"type": "memory"})
        factory = _BlockingSingletonFactory()

        with (
            patch(cm_target, side_effect=factory.context_manager),
            ThreadPoolExecutor(max_workers=1) as runner,
        ):
            pending = runner.submit(_call_getter_concurrently, getter)
            assert factory.entered.wait(timeout=3)
            # ``release`` has not been set yet, so this wait simply pauses
            # briefly while the first caller is still inside ``__enter__``.
            factory.release.wait(timeout=0.05)
            factory.release.set()
            results = pending.result(timeout=3)

        assert all(result is factory.value for result in results)
        assert factory.enter_count() == 1

    @staticmethod
    def _assert_config_loaded_before_locking(lock_target: str, loader_target: str, getter) -> None:
        """Check ``ensure_config_loaded`` runs while the singleton lock is not held."""
        tracking_lock = _TrackingLock()

        def fake_ensure_config_loaded():
            assert not tracking_lock.locked()
            load_checkpointer_config_from_dict({"type": "memory"})

        with (
            patch(lock_target, tracking_lock),
            patch(loader_target, side_effect=fake_ensure_config_loaded),
        ):
            instance = getter()

        assert instance is not None
        # The getter must still have taken the lock at some point.
        assert tracking_lock.acquired.is_set()

    @staticmethod
    def _assert_reset_waits_for_initialization(cm_target: str, getter, resetter) -> None:
        """Check a reset started mid-initialization finishes only after the getter does."""
        load_checkpointer_config_from_dict({"type": "memory"})
        factory = _BlockingSingletonFactory()

        with (
            patch(cm_target, side_effect=factory.context_manager),
            ThreadPoolExecutor(max_workers=2) as executor,
        ):
            get_future = executor.submit(getter)
            assert factory.entered.wait(timeout=3)

            reset_started = Event()

            def reset_worker():
                reset_started.set()
                resetter()

            reset_future = executor.submit(reset_worker)
            assert reset_started.wait(timeout=3)
            # Brief pause (``release`` is still unset) to give the reset call
            # time to run if it were not being held back.
            factory.release.wait(timeout=0.05)

            assert not reset_future.done()
            assert factory.exit_count() == 0

            factory.release.set()
            assert get_future.result(timeout=3) is factory.value
            reset_future.result(timeout=3)

        assert factory.exit_count() == 1

    def test_store_reset_clears_singleton(self):
        """A get after reset_store should return a different store object."""
        load_checkpointer_config_from_dict({"type": "memory"})
        before_reset = get_store()
        reset_store()
        after_reset = get_store()
        assert before_reset is not after_reset

    def test_concurrent_checkpointer_getter_creates_one_instance(self):
        self._assert_single_init_under_contention(
            "deerflow.runtime.checkpointer.provider._sync_checkpointer_cm",
            get_checkpointer,
        )

    def test_concurrent_store_getter_creates_one_instance(self):
        self._assert_single_init_under_contention(
            "deerflow.runtime.store.provider._sync_store_cm",
            get_store,
        )

    def test_checkpointer_loads_config_outside_singleton_lock(self):
        self._assert_config_loaded_before_locking(
            "deerflow.runtime.checkpointer.provider._checkpointer_lock",
            "deerflow.runtime.checkpointer.provider.ensure_config_loaded",
            get_checkpointer,
        )

    def test_store_loads_config_outside_singleton_lock(self):
        self._assert_config_loaded_before_locking(
            "deerflow.runtime.store.provider._store_lock",
            "deerflow.runtime.store.provider.ensure_config_loaded",
            get_store,
        )

    def test_checkpointer_reset_waits_for_initialization(self):
        self._assert_reset_waits_for_initialization(
            "deerflow.runtime.checkpointer.provider._sync_checkpointer_cm",
            get_checkpointer,
            reset_checkpointer,
        )

    def test_store_reset_waits_for_initialization(self):
        self._assert_reset_waits_for_initialization(
            "deerflow.runtime.store.provider._sync_store_cm",
            get_store,
            reset_store,
        )
|
||||
|
||||
|
||||
class TestAsyncCheckpointer:
|
||||
@pytest.mark.anyio
|
||||
async def test_sqlite_creates_parent_dir_via_to_thread(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user