deer-flow/backend/tests/test_checkpointer.py
KiteEater 7caf03e97c
fix(packaging): add postgres extra for store/checkpointer supportFix postgres extra install guidance (#2584)
* Fix postgres extra install guidance

* Fix postgres install message lint

* Format postgres install messages

* Fix postgres install guidance and config docs
2026-05-09 09:49:08 +08:00

418 lines
17 KiB
Python

"""Unit tests for checkpointer config, packaging metadata, and factories."""
import sys
import tomllib
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import deerflow.config.app_config as app_config_module
from deerflow.config.checkpointer_config import (
CheckpointerConfig,
get_checkpointer_config,
load_checkpointer_config_from_dict,
set_checkpointer_config,
)
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL
from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL
@pytest.fixture(autouse=True)
def reset_state():
"""Reset singleton state before each test."""
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer()
yield
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer()
# ---------------------------------------------------------------------------
# Config tests
# ---------------------------------------------------------------------------
class TestCheckpointerConfig:
def test_load_memory_config(self):
load_checkpointer_config_from_dict({"type": "memory"})
config = get_checkpointer_config()
assert config is not None
assert config.type == "memory"
assert config.connection_string is None
def test_load_sqlite_config(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
config = get_checkpointer_config()
assert config is not None
assert config.type == "sqlite"
assert config.connection_string == "/tmp/test.db"
def test_load_postgres_config(self):
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
config = get_checkpointer_config()
assert config is not None
assert config.type == "postgres"
assert config.connection_string == "postgresql://localhost/db"
def test_default_connection_string_is_none(self):
config = CheckpointerConfig(type="memory")
assert config.connection_string is None
def test_set_config_to_none(self):
load_checkpointer_config_from_dict({"type": "memory"})
set_checkpointer_config(None)
assert get_checkpointer_config() is None
def test_invalid_type_raises(self):
with pytest.raises(Exception):
load_checkpointer_config_from_dict({"type": "unknown"})
def test_connection_string_description_matches_runtime_defaults(self):
description = CheckpointerConfig.model_fields["connection_string"].description
assert description is not None
assert "Optional for sqlite" in description
assert "defaults to 'store.db'" in description
assert "Required for postgres" in description
class TestHarnessPackaging:
def test_pyproject_declares_postgres_extra(self):
pyproject_path = Path(__file__).resolve().parents[1] / "packages" / "harness" / "pyproject.toml"
data = tomllib.loads(pyproject_path.read_text())
optional_dependencies = data["project"]["optional-dependencies"]
assert "postgres" in optional_dependencies
assert optional_dependencies["postgres"] == [
"asyncpg>=0.29",
"langgraph-checkpoint-postgres>=3.0.5",
"psycopg[binary]>=3.3.3",
"psycopg-pool>=3.3.0",
]
def test_workspace_pyproject_forwards_postgres_extra_to_harness(self):
pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
data = tomllib.loads(pyproject_path.read_text())
optional_dependencies = data["project"]["optional-dependencies"]
assert optional_dependencies["postgres"] == ["deerflow-harness[postgres]"]
def test_postgres_missing_dependency_messages_recommend_package_extra(self):
assert "deerflow-harness[postgres]" in POSTGRES_INSTALL
assert "deerflow-harness[postgres]" in POSTGRES_STORE_INSTALL
assert "uv sync --all-packages --extra postgres" in POSTGRES_INSTALL
assert "uv sync --all-packages --extra postgres" in POSTGRES_STORE_INSTALL
# ---------------------------------------------------------------------------
# Factory tests
# ---------------------------------------------------------------------------
class TestGetCheckpointer:
def test_returns_in_memory_saver_when_not_configured(self):
"""get_checkpointer should return InMemorySaver when not configured."""
from langgraph.checkpoint.memory import InMemorySaver
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
cp = get_checkpointer()
assert cp is not None
assert isinstance(cp, InMemorySaver)
def test_memory_returns_in_memory_saver(self):
load_checkpointer_config_from_dict({"type": "memory"})
from langgraph.checkpoint.memory import InMemorySaver
cp = get_checkpointer()
assert isinstance(cp, InMemorySaver)
def test_memory_singleton(self):
load_checkpointer_config_from_dict({"type": "memory"})
cp1 = get_checkpointer()
cp2 = get_checkpointer()
assert cp1 is cp2
def test_reset_clears_singleton(self):
load_checkpointer_config_from_dict({"type": "memory"})
cp1 = get_checkpointer()
reset_checkpointer()
cp2 = get_checkpointer()
assert cp1 is not cp2
def test_sqlite_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
get_checkpointer()
def test_postgres_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}):
reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
get_checkpointer()
def test_postgres_raises_when_connection_string_missing(self):
load_checkpointer_config_from_dict({"type": "postgres"})
mock_saver = MagicMock()
mock_module = MagicMock()
mock_module.PostgresSaver = mock_saver
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}):
reset_checkpointer()
with pytest.raises(ValueError, match="connection_string is required"):
get_checkpointer()
def test_sqlite_creates_saver(self):
"""SQLite checkpointer is created when package is available."""
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}):
reset_checkpointer()
cp = get_checkpointer()
assert cp is mock_saver_instance
mock_saver_cls.from_conn_string.assert_called_once()
mock_saver_instance.setup.assert_called_once()
def test_sqlite_creates_parent_dir(self):
"""Sync SQLite checkpointer should call ensure_sqlite_parent_dir before connecting.
This mirrors the async checkpointer's behaviour and prevents
'sqlite3.OperationalError: unable to open database file' when the
parent directory for the database file does not yet exist (e.g. when
using the harness package from an external virtualenv where the
.deer-flow directory has not been created).
"""
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "relative/test.db"})
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
with (
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
patch("deerflow.runtime.checkpointer.provider.ensure_sqlite_parent_dir") as mock_ensure,
patch(
"deerflow.runtime.checkpointer.provider.resolve_sqlite_conn_str",
return_value="/tmp/resolved/relative/test.db",
),
):
reset_checkpointer()
cp = get_checkpointer()
assert cp is mock_saver_instance
mock_ensure.assert_called_once_with("/tmp/resolved/relative/test.db")
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/relative/test.db")
def test_sqlite_ensure_parent_dir_before_connect(self):
"""ensure_sqlite_parent_dir must be called before from_conn_string."""
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "relative/test.db"})
call_order = []
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(side_effect=lambda *a, **kw: (call_order.append("connect"), mock_cm)[1])
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
def record_ensure(*a, **kw):
call_order.append("ensure")
with (
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
patch(
"deerflow.runtime.checkpointer.provider.ensure_sqlite_parent_dir",
side_effect=record_ensure,
),
patch(
"deerflow.runtime.checkpointer.provider.resolve_sqlite_conn_str",
return_value="/tmp/resolved/relative/test.db",
),
):
reset_checkpointer()
get_checkpointer()
assert call_order == ["ensure", "connect"]
def test_postgres_creates_saver(self):
"""Postgres checkpointer is created when packages are available."""
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
mock_pg_module = MagicMock()
mock_pg_module.PostgresSaver = mock_saver_cls
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}):
reset_checkpointer()
cp = get_checkpointer()
assert cp is mock_saver_instance
mock_saver_cls.from_conn_string.assert_called_once_with("postgresql://localhost/db")
mock_saver_instance.setup.assert_called_once()
class TestAsyncCheckpointer:
@pytest.mark.anyio
async def test_sqlite_creates_parent_dir_via_to_thread(self):
"""Async SQLite setup should move mkdir off the event loop."""
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
mock_config = MagicMock()
mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db")
mock_saver = AsyncMock()
mock_cm = AsyncMock()
mock_cm.__aenter__.return_value = mock_saver
mock_cm.__aexit__.return_value = False
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string.return_value = mock_cm
mock_module = MagicMock()
mock_module.AsyncSqliteSaver = mock_saver_cls
with (
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
patch(
"deerflow.runtime.checkpointer.async_provider.resolve_sqlite_conn_str",
return_value="/tmp/resolved/test.db",
),
):
async with make_checkpointer() as saver:
assert saver is mock_saver
mock_to_thread.assert_awaited_once()
called_fn, called_path = mock_to_thread.await_args.args
assert called_fn.__name__ == "ensure_sqlite_parent_dir"
assert called_path == "/tmp/resolved/test.db"
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db")
mock_saver.setup.assert_awaited_once()
# ---------------------------------------------------------------------------
# app_config.py integration
# ---------------------------------------------------------------------------
class TestAppConfigLoadsCheckpointer:
def test_load_checkpointer_section(self):
"""load_checkpointer_config_from_dict populates the global config."""
set_checkpointer_config(None)
load_checkpointer_config_from_dict({"type": "memory"})
cfg = get_checkpointer_config()
assert cfg is not None
assert cfg.type == "memory"
# ---------------------------------------------------------------------------
# DeerFlowClient falls back to config checkpointer
# ---------------------------------------------------------------------------
class TestClientCheckpointerFallback:
def test_client_uses_config_checkpointer_when_none_provided(self):
"""DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.client import DeerFlowClient
load_checkpointer_config_from_dict({"type": "memory"})
captured_kwargs = {}
def fake_create_agent(**kwargs):
captured_kwargs.update(kwargs)
return MagicMock()
model_mock = MagicMock()
config_mock = MagicMock()
config_mock.models = [model_mock]
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
config_mock.checkpointer = None
with (
patch("deerflow.client.get_app_config", return_value=config_mock),
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value=""),
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
):
client = DeerFlowClient(checkpointer=None)
config = client._get_runnable_config("test-thread")
client._ensure_agent(config)
assert "checkpointer" in captured_kwargs
assert isinstance(captured_kwargs["checkpointer"], InMemorySaver)
def test_client_explicit_checkpointer_takes_precedence(self):
"""An explicitly provided checkpointer is used even when config checkpointer is set."""
from deerflow.client import DeerFlowClient
load_checkpointer_config_from_dict({"type": "memory"})
explicit_cp = MagicMock()
captured_kwargs = {}
def fake_create_agent(**kwargs):
captured_kwargs.update(kwargs)
return MagicMock()
model_mock = MagicMock()
config_mock = MagicMock()
config_mock.models = [model_mock]
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
config_mock.checkpointer = None
with (
patch("deerflow.client.get_app_config", return_value=config_mock),
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value=""),
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
):
client = DeerFlowClient(checkpointer=explicit_cp)
config = client._get_runnable_config("test-thread")
client._ensure_agent(config)
assert captured_kwargs["checkpointer"] is explicit_cp