fix(gateway): drain in-flight runs before closing checkpointer on shutdown (#3381)

* fix(gateway): drain in-flight runs before closing checkpointer on shutdown

Chat runs execute in fire-and-forget background asyncio tasks that write
checkpoints through a shared checkpointer. On shutdown, langgraph_runtime's
AsyncExitStack tore down the checkpointer's postgres connection pool while
those run tasks were still mid-graph. langgraph's
AsyncPregelLoop._checkpointer_put_after_previous then ran its
`finally: await checkpointer.aput(...)` against the closed pool, raising
psycopg_pool.PoolClosed. Because that put runs in a langgraph-internal task
(not on run_agent's call stack), run_agent's try/except cannot catch it and it
surfaces as "unhandled exception during asyncio.run() shutdown".

Add RunManager.shutdown() to cancel and bounded-await all in-flight runs, and
call it from langgraph_runtime BEFORE the AsyncExitStack closes the
checkpointer, so the final checkpoint write lands while the pool is still open.
The drain is bounded by a timeout so a stuck run cannot hang worker shutdown,
and is shielded so a second shutdown signal cannot abandon it mid-drain and
reopen the race.

Closes #3373

* fix(gateway): address review — preserve completed-run status, bound drain persistence

Addresses Copilot review on #3381:

- RunManager.shutdown(): decide run status AFTER the drain. Under the lock it
  now only requests cancellation; after asyncio.wait it marks/persists
  `interrupted` only for runs still pending or ended cancelled. A run that
  completes (e.g. `success`) during the drain window keeps its real terminal
  status instead of being unconditionally overwritten.
- Bound the trailing status persistence within the timeout budget
  (deadline = loop.time()+timeout; gather wrapped in asyncio.wait_for) so a slow
  store backing off under DB pressure cannot push shutdown past the deadline.
- deps: use asyncio.create_task instead of asyncio.ensure_future.
- tests: wait deterministically for the run to be in-flight (poll the first
  checkpoint) instead of a fixed sleep; init shutdown_calls explicitly in the
  recovery test double; add regression test asserting a run completing during
  the drain keeps its status (in memory and in the store).
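
The "decide status after the drain" rule from the first bullet can be sketched
as follows. This is a simplified illustration under assumptions: the three
coroutines and the string verdicts are hypothetical, and the real code works on
RunRecord objects under a lock.

```python
import asyncio


async def finishes_anyway() -> str:
    try:
        await asyncio.sleep(3600)
    except asyncio.CancelledError:
        pass  # swallow: simulates a run that completed as it was cancelled
    return "success"


async def honours_cancel() -> str:
    await asyncio.sleep(3600)
    return "success"


async def stuck(release: asyncio.Event) -> str:
    # Ignores cancellation until released: simulates slow cleanup.
    while not release.is_set():
        try:
            await asyncio.sleep(3600)
        except asyncio.CancelledError:
            pass
    return "late"


async def classify(timeout: float) -> dict[str, str]:
    release = asyncio.Event()
    tasks = {
        "a": asyncio.create_task(finishes_anyway()),
        "b": asyncio.create_task(honours_cancel()),
        "c": asyncio.create_task(stuck(release)),
    }
    await asyncio.sleep(0.01)  # let every task reach its first await
    for task in tasks.values():
        task.cancel()  # only request cancellation; no status decided yet
    _, pending = await asyncio.wait(list(tasks.values()), timeout=timeout)

    verdict: dict[str, str] = {}
    for name, task in tasks.items():
        if task in pending or task.cancelled():
            verdict[name] = "interrupted"
        else:
            task.exception()  # mark any exception as retrieved
            verdict[name] = task.result()  # keep the run's own outcome

    # Clean up the deliberately stubborn task so the loop can close.
    release.set()
    tasks["c"].cancel()
    await asyncio.gather(*tasks.values(), return_exceptions=True)
    return verdict


VERDICT = asyncio.run(classify(0.1))
```

Only the task that ended cancelled and the task still pending after the timeout
are treated as interrupted; the one that returned normally keeps its result.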

* fix(gateway): address maintainer review — surface failed drain persists, clarify timeout constant

Addresses @WillemJiang review on #3381:

- shutdown(): inspect the gather result of the trailing interrupted-status
  persistence. _persist_status is best-effort (it catches + logs its own
  failure with exc_info and returns False, so it never raises out of the
  gather), but the aggregate result was never checked — a partial failure had
  no shutdown-level visibility. Now any escaped Exception is logged, and any
  False (a persist that did not confirm) is logged with the run_id. Added
  regression test test_shutdown_surfaces_failed_interrupted_persist.
- deps: clarify the _RUN_DRAIN_TIMEOUT_SECONDS comment — state the actual value
  of _SHUTDOWN_HOOK_TIMEOUT_SECONDS (5.0s) and that both count toward the
  lifespan shutdown window. Kept as two separate constants (independent teardown
  steps that may diverge) rather than one shared "must match" value.
- Verified no other test fake needs the shutdown stub: _FakeRunManager in
  test_worker_langfuse_metadata.py is a run_agent() argument (worker path),
  never injected into langgraph_runtime, so it never receives shutdown().
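
As a rough sketch of the gather inspection described in the first bullet
(assuming a hypothetical `persist` coroutine in place of _persist_status, and
collecting messages in a list instead of logging them):

```python
import asyncio


async def persist(run_id: str) -> bool:
    # Hypothetical best-effort persist: False means "not confirmed".
    if run_id == "r2":
        return False
    if run_id == "r3":
        raise RuntimeError("store unavailable")  # an escaped exception
    return True


async def persist_all(run_ids: list[str], remaining: float) -> list[str]:
    problems: list[str] = []
    try:
        results = await asyncio.wait_for(
            asyncio.gather(*(persist(r) for r in run_ids), return_exceptions=True),
            timeout=remaining,
        )
    except asyncio.TimeoutError:
        return [f"budget exhausted for {len(run_ids)} run(s)"]
    # gather() preserves input order, so results line up with run_ids.
    for run_id, result in zip(run_ids, results):
        if isinstance(result, Exception):
            problems.append(f"{run_id}: unexpected {result!r}")
        elif result is False:
            problems.append(f"{run_id}: not confirmed")
    return problems


PROBLEMS = asyncio.run(persist_all(["r1", "r2", "r3"], 1.0))
```

With return_exceptions=True the gather itself never raises for a failed
persist, which is why the per-item check is what makes a partial failure
visible.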
Xinmin Zeng 2026-06-07 11:24:30 +08:00 committed by GitHub
parent 9a5de8d6a5
commit 268fdd6968
4 changed files with 497 additions and 0 deletions

@@ -17,6 +17,7 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager
@@ -33,6 +34,43 @@ from deerflow.runtime.runs.store.base import RunStore
logger = logging.getLogger(__name__)
# Upper bound (seconds) for draining in-flight runs during shutdown, before the
# AsyncExitStack tears down the checkpointer (and its connection pool). Kept
# local to avoid an app -> deps -> app import cycle. This is a *separate* budget
# from ``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS`` (currently also 5.0s,
# which bounds channel-service stop): the two govern independent teardown steps
# and may diverge, but both count toward the lifespan shutdown window — revisit
# them together if their sum must stay within the server's graceful-shutdown
# timeout.
_RUN_DRAIN_TIMEOUT_SECONDS = 5.0
async def _drain_inflight_runs(run_manager: RunManager) -> None:
    """Drain in-flight runs before the checkpointer is torn down (issue #3373).

    Shields the (internally-bounded) drain so that even if the lifespan
    coroutine is itself cancelled mid-shutdown (by a second SIGINT or the
    server's graceful-shutdown timeout, i.e. the same signal storm behind
    #3373), the checkpointer pool is not closed while run tasks are still
    writing checkpoints. On such a cancellation we let the already-running
    drain finish (it is bounded by ``RunManager.shutdown``'s own timeout) and
    then propagate the cancellation.
    """
    drain = asyncio.create_task(run_manager.shutdown(timeout=_RUN_DRAIN_TIMEOUT_SECONDS))
    try:
        await asyncio.shield(drain)
    except asyncio.CancelledError:
        # Re-shield so this second wait does not abandon the in-flight drain;
        # it is bounded, so this cannot hang. Then re-raise to honour shutdown.
        try:
            await asyncio.shield(drain)
        except Exception:
            logger.exception("In-flight run drain failed after shutdown cancellation")
        raise
    except Exception:
        logger.exception("Failed to drain in-flight runs during shutdown")
if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
@@ -177,6 +215,14 @@ async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGen
try:
yield
finally:
# Drain in-flight run tasks BEFORE the AsyncExitStack tears down the
# checkpointer (and its connection pool). A run still mid-graph would
# otherwise leak into asyncio.run() shutdown, where langgraph's
# _checkpointer_put_after_previous aput races the closed pool and
# raises PoolClosed (issue #3373).
run_manager = getattr(app.state, "run_manager", None)
if run_manager is not None:
await _drain_inflight_runs(run_manager)
await close_engine()

@@ -645,6 +645,98 @@ class RunManager:
self._runs.pop(run_id, None)
logger.debug("Run record %s cleaned up", run_id)
    async def shutdown(self, *, timeout: float = 5.0) -> None:
        """Cancel and bounded-await all in-flight runs on process shutdown.

        Chat runs execute in fire-and-forget background ``asyncio`` tasks that
        write checkpoints through a shared checkpointer. On shutdown the
        checkpointer's resources (e.g. the postgres connection pool owned by the
        gateway's ``AsyncExitStack``) are torn down; if a run task is still
        mid-graph at that point, langgraph's
        ``AsyncPregelLoop._checkpointer_put_after_previous`` runs its
        ``finally: await checkpointer.aput(...)`` against the closed pool. Because
        that put runs in a langgraph-internal task (not on ``run_agent``'s call
        stack), the resulting ``psycopg_pool.PoolClosed`` is not catchable by the
        worker and surfaces as an unhandled exception during ``asyncio.run()``
        shutdown (bytedance/deer-flow issue #3373).

        Draining in-flight runs *before* the checkpointer is closed lets each
        run that settles within ``timeout`` flush its final checkpoint while
        resources are still open. Only runs that do **not** settle on their own
        are marked ``interrupted``; a run that completes (e.g. ``success``)
        during the drain keeps its real terminal status instead of being
        blanket-overwritten. The whole drain, including the trailing status
        persistence, is bounded by ``timeout`` so a run stuck in cleanup (or a
        slow store under DB pressure) cannot hang worker shutdown, which is the
        precondition for the signal-reentrancy deadlock guarded by
        ``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``. Runs still active
        after ``timeout`` are logged and may still race teardown.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        async with self._lock:
            inflight = [record for record in self._runs.values() if record.status in (RunStatus.pending, RunStatus.running) and record.task is not None and not record.task.done()]
            for record in inflight:
                record.abort_action = "interrupt"
                record.abort_event.set()
                record.task.cancel()  # type: ignore[union-attr]  # filtered above
                # Status is decided AFTER the drain (below), not here: a run that
                # completes on its own during the drain must keep its real status.

        if not inflight:
            return

        tasks = [record.task for record in inflight]
        _, pending = await asyncio.wait(tasks, timeout=timeout)

        # Only mark/persist ``interrupted`` for runs that did not settle on their
        # own (still pending after the timeout, or ended cancelled). A run that
        # finished normally during the drain keeps the status it set for itself.
        to_persist: list[RunRecord] = []
        async with self._lock:
            for record in inflight:
                task = record.task
                if task not in pending and not task.cancelled():
                    # Completed on its own: retrieve any surfaced exception so it
                    # is not reported as "never retrieved", and keep its status.
                    task.exception()  # type: ignore[union-attr]  # done & not cancelled
                    continue
                if record.status in (RunStatus.pending, RunStatus.running):
                    record.status = RunStatus.interrupted
                    record.updated_at = _now_iso()
                    to_persist.append(record)

        # Bound the trailing status persistence within the remaining budget so a
        # slow store (``_call_store_with_retry`` can back off under DB pressure)
        # cannot push shutdown past ``timeout``.
        if to_persist:
            remaining = deadline - loop.time()
            if remaining <= 0:
                logger.warning("Run drain budget exhausted before persisting %d interrupted run(s) on shutdown", len(to_persist))
            else:
                try:
                    results = await asyncio.wait_for(
                        asyncio.gather(*(self._persist_status(record, RunStatus.interrupted) for record in to_persist), return_exceptions=True),
                        timeout=remaining,
                    )
                except TimeoutError:
                    logger.warning("Run drain status persistence exceeded the %.1fs budget; %d record(s) may not be persisted", timeout, len(to_persist))
                else:
                    # ``_persist_status`` is best-effort: it catches and logs its
                    # own failures, returning ``False``. Inspect the aggregate so a
                    # partial failure is surfaced at shutdown level (with the
                    # run_id) instead of being silently swallowed by the gather.
                    for record, result in zip(to_persist, results):
                        if isinstance(result, Exception):
                            logger.warning("Unexpected error persisting interrupted status for run %s during shutdown: %r", record.run_id, result)
                        elif result is False:
                            logger.warning("Could not persist interrupted status for run %s during shutdown", record.run_id)

        if pending:
            logger.warning("Run drain exceeded %.1fs on shutdown; %d run task(s) still active and may race checkpointer teardown", timeout, len(pending))
        logger.info("Drained %d in-flight run(s) on shutdown (%d settled within %.1fs)", len(inflight), len(inflight) - len(pending), timeout)
class ConflictError(Exception):
"""Raised when multitask_strategy=reject and thread has inflight runs."""

@@ -0,0 +1,353 @@
"""Regression tests for graceful run-task drain on Gateway shutdown.
Guards bytedance/deer-flow issue #3373:
psycopg_pool.PoolClosed: the pool 'pool-1' is already closed
Root cause: chat runs are fire-and-forget background ``asyncio`` tasks
(``app/gateway/services.py`` -> ``asyncio.create_task(run_agent(...))``) owned
by nobody. On shutdown, ``langgraph_runtime``'s ``AsyncExitStack`` tore down the
checkpointer's postgres pool while those tasks were still mid-graph. langgraph's
``AsyncPregelLoop._checkpointer_put_after_previous`` then ran its
``finally: await checkpointer.aput(...)`` against the already-closed pool.
Fix: ``RunManager.shutdown()`` cancels and *bounded*-awaits every in-flight run,
and ``langgraph_runtime`` calls it BEFORE the ``AsyncExitStack`` closes the
checkpointer so the final checkpoint write lands while the pool is still open.
The drain must stay bounded (a stuck run must not hang the worker, the
precondition for the signal-reentrancy deadlock guarded by
``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``).
"""
from __future__ import annotations
import asyncio
import operator
from contextlib import asynccontextmanager, suppress
from types import SimpleNamespace
from typing import Annotated, TypedDict
import pytest
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.runtime import RunManager, RunStatus
# Module-level so langgraph's get_type_hints (which resolves annotations against
# module globals under `from __future__ import annotations`) can see Annotated.
class _CountState(TypedDict):
count: Annotated[int, operator.add]
class _CloseableSaver(InMemorySaver):
"""InMemorySaver that fails writes once closed, like a closed pool."""
def __init__(self) -> None:
super().__init__()
self._closed = False
self.writes_after_close: list[str] = []
def close(self) -> None:
self._closed = True
async def aput(self, *args, **kwargs):
if self._closed:
self.writes_after_close.append("aput")
raise RuntimeError("checkpointer is closed")
return await super().aput(*args, **kwargs)
async def aput_writes(self, *args, **kwargs):
if self._closed:
self.writes_after_close.append("aput_writes")
raise RuntimeError("checkpointer is closed")
return await super().aput_writes(*args, **kwargs)
@pytest.mark.asyncio
async def test_shutdown_cancels_and_awaits_inflight_run():
"""shutdown() cancels the in-flight task, waits for it, marks it interrupted."""
rm = RunManager()
record = await rm.create("t-drain")
await rm.set_status(record.run_id, RunStatus.running)
started = asyncio.Event()
cancelled = asyncio.Event()
async def worker() -> None:
try:
started.set()
await asyncio.Event().wait()
except asyncio.CancelledError:
cancelled.set()
raise
record.task = asyncio.create_task(worker())
try:
await asyncio.wait_for(started.wait(), timeout=1.0)
await rm.shutdown(timeout=5.0)
assert record.task.done()
assert cancelled.is_set()
assert record.status == RunStatus.interrupted
finally:
if not record.task.done():
record.task.cancel()
with suppress(asyncio.CancelledError):
await record.task
@pytest.mark.asyncio
async def test_shutdown_is_bounded_when_run_ignores_cancellation():
"""A run that swallows cancellation must not make shutdown() hang."""
rm = RunManager()
record = await rm.create("t-stubborn")
await rm.set_status(record.run_id, RunStatus.running)
started = asyncio.Event()
stop = asyncio.Event()
async def stubborn() -> None:
started.set()
while not stop.is_set():
try:
await asyncio.sleep(3600)
except asyncio.CancelledError:
if stop.is_set():
raise
# else: swallow — simulates a run stuck in slow cleanup
record.task = asyncio.create_task(stubborn())
try:
await asyncio.wait_for(started.wait(), timeout=1.0)
loop = asyncio.get_running_loop()
t0 = loop.time()
await rm.shutdown(timeout=0.3)
elapsed = loop.time() - t0
assert elapsed < 2.0, f"shutdown took {elapsed:.2f}s; drain is not bounded"
finally:
# cleanup the deliberately-stubborn task
stop.set()
record.task.cancel()
with suppress(asyncio.CancelledError):
await record.task
@pytest.mark.asyncio
async def test_shutdown_is_noop_without_inflight_runs():
"""shutdown() on an idle manager completes cleanly and is idempotent."""
rm = RunManager()
await rm.shutdown(timeout=1.0)
# already-finished runs must not be re-cancelled or error out
record = await rm.create("t-done")
await rm.set_status(record.run_id, RunStatus.success)
await rm.shutdown(timeout=1.0)
@pytest.mark.asyncio
async def test_langgraph_runtime_drains_runs_before_closing_checkpointer(monkeypatch):
"""The wiring order lock for #3373: drain in-flight runs, THEN close the pool.
Patches every ``langgraph_runtime`` collaborator down to trivial stand-ins so
only the bootstrap/teardown ordering runs. The checkpointer probe records when
its context manager exits (pool close); a ``RunManager.shutdown`` spy records
when the drain happens. The drain MUST come first.
"""
from fastapi import FastAPI
from app.gateway.deps import langgraph_runtime
events: list[str] = []
@asynccontextmanager
async def probe_checkpointer(_config):
try:
yield object()
finally:
events.append("checkpointer_closed")
@asynccontextmanager
async def fake_stream_bridge(_config):
yield object()
@asynccontextmanager
async def fake_store(_config):
yield object()
async def fake_init_engine(_db):
return None
async def fake_close_engine():
return None
async def spy_shutdown(self, *, timeout): # noqa: ANN001
events.append("runs_drained")
monkeypatch.setattr("deerflow.runtime.checkpointer.async_provider.make_checkpointer", probe_checkpointer)
monkeypatch.setattr("deerflow.runtime.make_stream_bridge", fake_stream_bridge)
monkeypatch.setattr("deerflow.runtime.make_store", fake_store)
monkeypatch.setattr("deerflow.persistence.engine.init_engine_from_config", fake_init_engine)
monkeypatch.setattr("deerflow.persistence.engine.close_engine", fake_close_engine)
monkeypatch.setattr("deerflow.persistence.engine.get_session_factory", lambda: None)
monkeypatch.setattr("deerflow.runtime.events.store.make_run_event_store", lambda _cfg: object())
monkeypatch.setattr("deerflow.persistence.thread_meta.make_thread_store", lambda _sf, _store: object())
monkeypatch.setattr(RunManager, "shutdown", spy_shutdown, raising=False)
app = FastAPI()
startup_config = SimpleNamespace(database=SimpleNamespace(backend="memory"), run_events=None)
async with langgraph_runtime(app, startup_config):
pass
assert "runs_drained" in events, "langgraph_runtime never drained in-flight runs on shutdown"
assert "checkpointer_closed" in events
assert events.index("runs_drained") < events.index("checkpointer_closed"), f"runs must be drained before the checkpointer pool is closed; got order {events}"
@pytest.mark.asyncio
async def test_drain_flushes_real_graph_checkpoint_before_close():
"""End-to-end #3373 guard with a REAL langgraph graph + checkpointer.
A real run is driven through ``graph.astream`` in a background task, then
``RunManager.shutdown()`` drains it. The checkpointer raises once closed
(mirroring ``psycopg_pool.PoolClosed``). Closing only happens AFTER the
drain, as the gateway's AsyncExitStack does. The drain must let langgraph
flush its final checkpoint while the checkpointer is still open, so no write
lands against a closed checkpointer.
Unlike the unit/spy tests above, this exercises the real langgraph
checkpoint-put machinery, so a future langgraph change that cancels (rather
than awaits) its checkpoint-put task on executor exit would fail this test
instead of silently regressing #3373.
"""
from langgraph.graph import END, START, StateGraph
async def slow(_state: _CountState) -> dict:
await asyncio.sleep(0.1)
return {"count": 1}
saver = _CloseableSaver()
builder = StateGraph(_CountState)
for name in ("a", "b", "c"):
builder.add_node(name, slow)
builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("b", "c")
builder.add_edge("c", END)
graph = builder.compile(checkpointer=saver)
rm = RunManager()
record = await rm.create("t-e2e")
await rm.set_status(record.run_id, RunStatus.running)
thread_cfg = {"configurable": {"thread_id": "t-e2e"}}
started = asyncio.Event()
async def run() -> None:
started.set()
async for _ in graph.astream({"count": 0}, config=thread_cfg):
pass
record.task = asyncio.create_task(run())
try:
await asyncio.wait_for(started.wait(), timeout=1.0)
# Deterministically wait until the run is genuinely in-flight — poll for
# the first persisted checkpoint instead of a fixed sleep (avoids CI
# flakiness on slow runners / under event-loop contention).
async def _await_first_checkpoint() -> None:
while (await saver.aget_tuple(thread_cfg)) is None:
await asyncio.sleep(0.01)
await asyncio.wait_for(_await_first_checkpoint(), timeout=5.0)
# The fix: drain while the checkpointer is still open ...
await rm.shutdown(timeout=5.0)
# ... and only then close it (mirrors langgraph_runtime's ExitStack).
saver.close()
assert saver.writes_after_close == [], f"a checkpoint write raced a closed checkpointer: {saver.writes_after_close}"
# The final checkpoint landed before close.
snapshot = await saver.aget_tuple(thread_cfg)
assert snapshot is not None
finally:
if not record.task.done():
record.task.cancel()
with suppress(asyncio.CancelledError):
await record.task
@pytest.mark.asyncio
async def test_shutdown_preserves_status_of_run_completed_during_drain():
"""A run that finishes (e.g. success) during the drain window must keep its
real terminal status; shutdown must not blanket-overwrite it to
``interrupted``, in memory or in the store (Copilot review on PR #3381)."""
from deerflow.runtime.runs.store.memory import MemoryRunStore
store = MemoryRunStore()
rm = RunManager(store=store)
record = await rm.create("t-complete")
await rm.set_status(record.run_id, RunStatus.running)
async def worker() -> None:
try:
await asyncio.Event().wait()
except asyncio.CancelledError:
# The run had effectively finished; swallow the cancellation and
# record success, like a run that completed in the same tick the
# shutdown cancelled it.
pass
await rm.set_status(record.run_id, RunStatus.success)
record.task = asyncio.create_task(worker())
try:
await asyncio.sleep(0) # let the task reach its await point
await rm.shutdown(timeout=5.0)
assert record.status == RunStatus.success, f"shutdown overwrote in-memory status: {record.status}"
persisted = await store.get(record.run_id)
assert persisted is not None and persisted["status"] == "success", f"shutdown overwrote persisted status: {persisted}"
finally:
if not record.task.done():
record.task.cancel()
with suppress(asyncio.CancelledError):
await record.task
@pytest.mark.asyncio
async def test_shutdown_surfaces_failed_interrupted_persist(caplog):
"""A failed interrupted-status persist during the drain must be surfaced (with
the run_id), not silently swallowed by the gather (maintainer review on
PR #3381)."""
import logging
from deerflow.runtime.runs.store.memory import MemoryRunStore
class _FailingStore(MemoryRunStore):
async def update_status(self, *args, **kwargs):
raise RuntimeError("store unavailable")
rm = RunManager(store=_FailingStore())
record = await rm.create("t-failpersist")
record.status = RunStatus.running # set in memory; the failing store is exercised by the drain
started = asyncio.Event()
async def worker() -> None:
started.set()
await asyncio.Event().wait() # blocks until cancelled by the drain
record.task = asyncio.create_task(worker())
try:
await asyncio.wait_for(started.wait(), timeout=1.0)
with caplog.at_level(logging.WARNING, logger="deerflow.runtime.runs.manager"):
await rm.shutdown(timeout=5.0)
assert "Could not persist interrupted status for run" in caplog.text, caplog.text
finally:
if not record.task.done():
record.task.cancel()
with suppress(asyncio.CancelledError):
await record.task

@@ -32,6 +32,7 @@ class _FakeRunManager:
self.store = store
self.reconcile_calls: list[dict] = []
self.list_by_thread_calls: list[dict] = []
self.shutdown_calls: int = 0
_FakeRunManager.instances.append(self)
async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
@@ -42,6 +43,11 @@ class _FakeRunManager:
self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit})
return self.latest_by_thread.get(thread_id, self.recovered_runs[:limit])
async def shutdown(self, *, timeout: float = 5.0) -> None:
# No in-flight tasks in these startup-recovery tests; langgraph_runtime
# drains the manager on teardown, so the double must accept the call.
self.shutdown_calls += 1
class _FakeThreadStore:
def __init__(self) -> None: