Willem Jiang 031d6fbcbe
fix(checkpointer): use AsyncConnectionPool for postgres to prevent stale connection errors (#3223) (#3226)
* fix(checkpointer): use AsyncConnectionPool for postgres to prevent stale connection errors (#3223)

  Replace AsyncPostgresSaver.from_conn_string() with an explicit
  AsyncConnectionPool that has check_connection enabled, so dead idle
  connections are detected and replaced on checkout instead of raising
  OperationalError.

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* Fixed the unit test error and lint error

* fix(checkpointer): add TCP keepalive to postgres connection pool (#3254)

  Enable TCP keepalive probes on the AsyncConnectionPool to prevent
  idle postgres connections from being dropped by the server or network
  middleware. Combined with the existing check_connection callback, this
  provides defense-in-depth against stale connection errors.

  Fixes #3254

* Changed the code as review suggestion

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-06-01 09:05:11 +08:00

203 lines
6.6 KiB
Python

"""Async checkpointer factory.
Provides an **async context manager** for long-running async servers that need
proper resource cleanup.
Supported backends: memory, sqlite, postgres.
Usage (e.g. FastAPI lifespan)::
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer # InMemorySaver if not configured
For sync usage see :mod:`deerflow.runtime.checkpointer.provider`.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import AsyncIterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.runtime.checkpointer.provider import (
POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL,
SQLITE_INSTALL,
)
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
def _prepare_sqlite_checkpointer_path(raw: str) -> str:
conn_str = resolve_sqlite_conn_str(raw)
ensure_sqlite_parent_dir(conn_str)
return conn_str
def _prepare_database_sqlite_checkpointer_path(db_config) -> str:
conn_str = db_config.checkpointer_sqlite_path
ensure_sqlite_parent_dir(conn_str)
return conn_str
def _build_postgres_pool(conn_string: str):
"""Build an AsyncConnectionPool with TCP keepalive and connection checking."""
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
return AsyncConnectionPool(
conn_string,
kwargs={
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 6,
},
check=AsyncConnectionPool.check_connection,
)
def _ensure_postgres_imports():
"""Import and return (AsyncPostgresSaver, AsyncConnectionPool), raising ImportError on failure."""
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
try:
from psycopg_pool import AsyncConnectionPool
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
return AsyncPostgresSaver, AsyncConnectionPool
# ---------------------------------------------------------------------------
# Async factory
# ---------------------------------------------------------------------------
@contextlib.asynccontextmanager
async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
"""Async context manager that constructs and tears down a checkpointer."""
if config.type == "memory":
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
if config.type == "sqlite":
try:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = await asyncio.to_thread(_prepare_sqlite_checkpointer_path, config.connection_string or "store.db")
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup()
yield saver
return
if config.type == "postgres":
if not config.connection_string:
raise ValueError(POSTGRES_CONN_REQUIRED)
AsyncPostgresSaver, _ = _ensure_postgres_imports()
pool = _build_postgres_pool(config.connection_string)
async with pool:
saver = AsyncPostgresSaver(conn=pool)
await saver.setup()
yield saver
return
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
# ---------------------------------------------------------------------------
# Public async context manager
# ---------------------------------------------------------------------------
@contextlib.asynccontextmanager
async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpointer]:
"""Async context manager that constructs a checkpointer from unified DatabaseConfig."""
if db_config.backend == "memory":
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
if db_config.backend == "sqlite":
try:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = await asyncio.to_thread(_prepare_database_sqlite_checkpointer_path, db_config)
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup()
yield saver
return
if db_config.backend == "postgres":
if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend")
AsyncPostgresSaver, _ = _ensure_postgres_imports()
pool = _build_postgres_pool(db_config.postgres_url)
async with pool:
saver = AsyncPostgresSaver(conn=pool)
await saver.setup()
yield saver
return
raise ValueError(f"Unknown database backend: {db_config.backend!r}")
@contextlib.asynccontextmanager
async def make_checkpointer(app_config: AppConfig | None = None) -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit -- no global state::
async with make_checkpointer(app_config) as checkpointer:
app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
Priority:
1. Legacy ``checkpointer:`` config section (backward compatible)
2. Unified ``database:`` config section
3. Default InMemorySaver
"""
if app_config is None:
app_config = get_app_config()
# Legacy: standalone checkpointer config takes precedence
if app_config.checkpointer is not None:
async with _async_checkpointer(app_config.checkpointer) as saver:
yield saver
return
# Unified database config
db_config = getattr(app_config, "database", None)
if db_config is not None and db_config.backend != "memory":
async with _async_checkpointer_from_database(db_config) as saver:
yield saver
return
# Default: in-memory
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()