diff --git a/backend/packages/harness/deerflow/runtime/checkpointer/async_provider.py b/backend/packages/harness/deerflow/runtime/checkpointer/async_provider.py
index ac2d1da51..be4543492 100644
--- a/backend/packages/harness/deerflow/runtime/checkpointer/async_provider.py
+++ b/backend/packages/harness/deerflow/runtime/checkpointer/async_provider.py
@@ -47,6 +47,41 @@ def _prepare_database_sqlite_checkpointer_path(db_config) -> str:
     return conn_str
 
 
+def _build_postgres_pool(conn_string: str):
+    """Build an AsyncConnectionPool with TCP keepalive and connection checking."""
+    from psycopg.rows import dict_row
+    from psycopg_pool import AsyncConnectionPool
+
+    return AsyncConnectionPool(
+        conn_string,
+        kwargs={
+            "autocommit": True,
+            "prepare_threshold": 0,
+            "row_factory": dict_row,
+            "keepalives": 1,
+            "keepalives_idle": 60,
+            "keepalives_interval": 10,
+            "keepalives_count": 6,
+        },
+        check=AsyncConnectionPool.check_connection,
+    )
+
+
+def _ensure_postgres_imports():
+    """Import and return (AsyncPostgresSaver, AsyncConnectionPool), raising ImportError on failure."""
+    try:
+        from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
+    except ImportError as exc:
+        raise ImportError(POSTGRES_INSTALL) from exc
+
+    try:
+        from psycopg_pool import AsyncConnectionPool
+    except ImportError as exc:
+        raise ImportError(POSTGRES_INSTALL) from exc
+
+    return AsyncPostgresSaver, AsyncConnectionPool
+
+
 # ---------------------------------------------------------------------------
 # Async factory
 # ---------------------------------------------------------------------------
@@ -74,15 +109,13 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
         return
 
     if config.type == "postgres":
-        try:
-            from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
-        except ImportError as exc:
-            raise ImportError(POSTGRES_INSTALL) from exc
-
         if not config.connection_string:
             raise ValueError(POSTGRES_CONN_REQUIRED)
 
-        async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
+        AsyncPostgresSaver, _ = _ensure_postgres_imports()
+        pool = _build_postgres_pool(config.connection_string)
+        async with pool:
+            saver = AsyncPostgresSaver(conn=pool)
             await saver.setup()
             yield saver
         return
@@ -117,15 +150,13 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi
         return
 
     if db_config.backend == "postgres":
-        try:
-            from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
-        except ImportError as exc:
-            raise ImportError(POSTGRES_INSTALL) from exc
-
         if not db_config.postgres_url:
             raise ValueError("database.postgres_url is required for the postgres backend")
 
-        async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
+        AsyncPostgresSaver, _ = _ensure_postgres_imports()
+        pool = _build_postgres_pool(db_config.postgres_url)
+        async with pool:
+            saver = AsyncPostgresSaver(conn=pool)
             await saver.setup()
             yield saver
         return
diff --git a/backend/tests/test_checkpointer.py b/backend/tests/test_checkpointer.py
index 166282928..751d3a74e 100644
--- a/backend/tests/test_checkpointer.py
+++ b/backend/tests/test_checkpointer.py
@@ -326,6 +326,99 @@ class TestAsyncCheckpointer:
         mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db")
         mock_saver.setup.assert_awaited_once()
 
+    @pytest.mark.anyio
+    async def test_postgres_uses_connection_pool(self):
+        """Async postgres checkpointer should use AsyncConnectionPool, not a single connection."""
+        from deerflow.runtime.checkpointer.async_provider import make_checkpointer
+
+        mock_config = MagicMock()
+        mock_config.checkpointer = CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db")
+
+        mock_saver = AsyncMock()
+
+        mock_saver_cls = MagicMock(return_value=mock_saver)
+
+        mock_pool_instance = AsyncMock()
+        mock_pool_instance.__aenter__.return_value = mock_pool_instance
+        mock_pool_instance.__aexit__.return_value = False
+
+        mock_pool_cls = MagicMock(return_value=mock_pool_instance)
+        mock_pool_cls.check_connection = AsyncMock()
+        mock_dict_row = MagicMock()
+
+        mock_pg_module = MagicMock()
+        mock_pg_module.AsyncPostgresSaver = mock_saver_cls
+
+        mock_psycopg_rows = MagicMock()
+        mock_psycopg_rows.dict_row = mock_dict_row
+
+        with (
+            patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
+            patch.dict(sys.modules, {"langgraph.checkpoint.postgres.aio": mock_pg_module}),
+            patch.dict(sys.modules, {"psycopg.rows": mock_psycopg_rows}),
+            patch.dict(sys.modules, {"psycopg_pool": MagicMock(AsyncConnectionPool=mock_pool_cls)}),
+        ):
+            # The mocked AsyncConnectionPool(...) call returns mock_pool_instance;
+            # that instance (not the class) has __aenter__/__aexit__ configured above.
+            async with make_checkpointer() as saver:
+                assert saver is mock_saver
+
+        # Verify the pool was constructed with the conn string and check=check_connection
+        mock_pool_cls.assert_called_once()
+        call_kwargs = mock_pool_cls.call_args
+        assert call_kwargs[0][0] == "postgresql://localhost/db"
+        assert call_kwargs[1]["check"] is mock_pool_cls.check_connection
+
+        # Verify saver was constructed with the pool (not via from_conn_string)
+        mock_saver_cls.assert_called_once_with(conn=mock_pool_instance)
+        mock_saver.setup.assert_awaited_once()
+
+    @pytest.mark.anyio
+    async def test_database_postgres_uses_connection_pool(self):
+        """Unified database postgres path should use AsyncConnectionPool with keepalive."""
+        from deerflow.config.database_config import DatabaseConfig
+        from deerflow.runtime.checkpointer.async_provider import make_checkpointer
+
+        db_config = DatabaseConfig(backend="postgres", postgres_url="postgresql://localhost/db")
+        mock_config = MagicMock()
+        mock_config.checkpointer = None
+        mock_config.database = db_config
+
+        mock_saver = AsyncMock()
+
+        mock_saver_cls = MagicMock(return_value=mock_saver)
+
+        mock_pool_instance = AsyncMock()
+        mock_pool_instance.__aenter__.return_value = mock_pool_instance
+        mock_pool_instance.__aexit__.return_value = False
+
+        mock_pool_cls = MagicMock(return_value=mock_pool_instance)
+        mock_pool_cls.check_connection = AsyncMock()
+        mock_dict_row = MagicMock()
+
+        mock_pg_module = MagicMock()
+        mock_pg_module.AsyncPostgresSaver = mock_saver_cls
+
+        mock_psycopg_rows = MagicMock()
+        mock_psycopg_rows.dict_row = mock_dict_row
+
+        with (
+            patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
+            patch.dict(sys.modules, {"langgraph.checkpoint.postgres.aio": mock_pg_module}),
+            patch.dict(sys.modules, {"psycopg.rows": mock_psycopg_rows}),
+            patch.dict(sys.modules, {"psycopg_pool": MagicMock(AsyncConnectionPool=mock_pool_cls)}),
+        ):
+            async with make_checkpointer() as saver:
+                assert saver is mock_saver
+
+        mock_pool_cls.assert_called_once()
+        call_kwargs = mock_pool_cls.call_args
+        assert call_kwargs[0][0] == "postgresql://localhost/db"
+        assert call_kwargs[1]["check"] is mock_pool_cls.check_connection
+
+        mock_saver_cls.assert_called_once_with(conn=mock_pool_instance)
+        mock_saver.setup.assert_awaited_once()
+
     @pytest.mark.anyio
     async def test_database_sqlite_creates_parent_dir_via_to_thread(self):
         """Unified database SQLite setup should also move path IO off the event loop."""