deer-flow/backend/tests/test_gateway_run_drain_shutdown.py
Xinmin Zeng 268fdd6968
fix(gateway): drain in-flight runs before closing checkpointer on shutdown (#3381)
* fix(gateway): drain in-flight runs before closing checkpointer on shutdown

Chat runs execute in fire-and-forget background asyncio tasks that write
checkpoints through a shared checkpointer. On shutdown, langgraph_runtime's
AsyncExitStack tore down the checkpointer's postgres connection pool while
those run tasks were still mid-graph. langgraph's
AsyncPregelLoop._checkpointer_put_after_previous then ran its
`finally: await checkpointer.aput(...)` against the closed pool, raising
psycopg_pool.PoolClosed. Because that put runs in a langgraph-internal task
(not on run_agent's call stack), run_agent's try/except cannot catch it and it
surfaces as "unhandled exception during asyncio.run() shutdown".
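
The failure mode above can be reproduced in miniature without langgraph. This is a simplified analogy, not langgraph's code: `internal_put`, `caller`, and the `state` dict are hypothetical stand-ins for the internal checkpoint-put task, `run_agent`, and the connection pool.

```python
import asyncio


async def demo() -> tuple[bool, str]:
    state = {"open": True}  # hypothetical stand-in for the connection pool
    spawned: list[asyncio.Task] = []

    async def internal_put() -> None:
        # Stand-in for a library-internal write that runs in its own task.
        await asyncio.sleep(0.01)
        if not state["open"]:
            raise RuntimeError("pool is closed")

    async def caller() -> bool:
        # Stand-in for run_agent: the try/except only covers its own stack.
        try:
            spawned.append(asyncio.create_task(internal_put()))
            await asyncio.sleep(0)
        except RuntimeError:
            return True
        return False

    caught = await caller()
    state["open"] = False  # teardown closes the pool while the put is pending
    results = await asyncio.gather(*spawned, return_exceptions=True)
    return caught, type(results[0]).__name__


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

The caller's `except` never fires because the error is raised inside a different task, after the caller has already returned.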

Add RunManager.shutdown() to cancel and bounded-await all in-flight runs, and
call it from langgraph_runtime BEFORE the AsyncExitStack closes the
checkpointer, so the final checkpoint write lands while the pool is still open.
The drain is bounded by a timeout so a stuck run cannot hang worker shutdown,
and is shielded so a second shutdown signal cannot abandon it mid-drain and
reopen the race.
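
A minimal sketch of the cancel, bounded-await, and shield pattern described above. It is not the actual `RunManager.shutdown()` implementation: `drain_runs` and `shutdown_shielded` are hypothetical names, and the real method also tracks run records and statuses.

```python
import asyncio


async def drain_runs(tasks: set[asyncio.Task], timeout: float) -> set[asyncio.Task]:
    """Cancel every task, wait at most `timeout`, return tasks still pending."""
    if not tasks:
        return set()  # asyncio.wait() rejects an empty collection
    for task in tasks:
        task.cancel()
    _done, pending = await asyncio.wait(tasks, timeout=timeout)
    return pending


async def shutdown_shielded(tasks: set[asyncio.Task], timeout: float) -> set[asyncio.Task]:
    # shield(): cancelling the caller does not cancel the drain itself.
    return await asyncio.shield(drain_runs(tasks, timeout))


async def demo() -> tuple[int, bool]:
    stop = asyncio.Event()

    async def cooperative() -> None:
        await asyncio.Event().wait()  # ends as soon as it is cancelled

    async def stubborn() -> None:
        while not stop.is_set():
            try:
                await asyncio.sleep(3600)
            except asyncio.CancelledError:
                if stop.is_set():
                    raise  # otherwise swallow the cancellation

    tasks = {asyncio.create_task(cooperative()), asyncio.create_task(stubborn())}
    await asyncio.sleep(0)  # let both tasks reach their first await
    loop = asyncio.get_running_loop()
    t0 = loop.time()
    pending = await shutdown_shielded(tasks, timeout=0.2)
    bounded = loop.time() - t0 < 2.0
    stop.set()  # release the stubborn task so the loop can exit
    for task in pending:
        task.cancel()
    await asyncio.gather(*pending, return_exceptions=True)
    return len(pending), bounded


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

`asyncio.wait` with a timeout returns instead of raising, so a task that ignores cancellation is reported as pending rather than blocking shutdown.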

Closes #3373

* fix(gateway): address review — preserve completed-run status, bound drain persistence

Addresses Copilot review on #3381:

- RunManager.shutdown(): decide run status AFTER the drain. Under the lock it
  now only requests cancellation; after asyncio.wait it marks/persists
  `interrupted` only for runs still pending or ended cancelled. A run that
  completes (e.g. `success`) during the drain window keeps its real terminal
  status instead of being unconditionally overwritten.
- Bound the trailing status persistence within the timeout budget
  (deadline = loop.time()+timeout; gather wrapped in asyncio.wait_for) so a slow
  store backing off under DB pressure cannot push shutdown past the deadline.
- deps: use asyncio.create_task instead of asyncio.ensure_future.
- tests: wait deterministically for the run to be in-flight (poll the first
  checkpoint) instead of a fixed sleep; init shutdown_calls explicitly in the
  recovery test double; add regression test asserting a run completing during
  the drain keeps its status (in memory and in the store).
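
The two behavioural changes in this round (status decided after the drain, persistence bounded by the same deadline) might look roughly like the sketch below. The `Run` dataclass, the `persist` callback, and the string statuses are assumptions for illustration and do not mirror the real `RunManager` types.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable


@dataclass
class Run:  # hypothetical stand-in for a run record
    run_id: str
    task: asyncio.Task
    status: str = "running"


async def shutdown(runs: list[Run], persist: Callable[[Run], Awaitable[None]], timeout: float) -> None:
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    inflight = [r for r in runs if not r.task.done()]
    for run in inflight:
        run.task.cancel()  # only request cancellation; no status change yet
    if inflight:
        await asyncio.wait([r.task for r in inflight], timeout=timeout)
    # Decide AFTER the drain: mark only runs still pending or ended cancelled.
    to_mark = [r for r in inflight if not r.task.done() or r.task.cancelled()]
    for run in to_mark:
        run.status = "interrupted"
    if to_mark:
        remaining = max(0.0, deadline - loop.time())
        try:
            await asyncio.wait_for(
                asyncio.gather(*(persist(r) for r in to_mark), return_exceptions=True),
                timeout=remaining,
            )
        except asyncio.TimeoutError:
            pass  # a slow store must not push shutdown past the deadline


async def demo() -> tuple[dict[str, str], dict[str, str]]:
    store: dict[str, str] = {}
    runs: dict[str, Run] = {}

    async def persist(run: Run) -> None:
        store[run.run_id] = run.status

    async def completes() -> None:
        try:
            await asyncio.Event().wait()
        except asyncio.CancelledError:
            pass  # the run had effectively finished
        runs["a"].status = "success"

    async def cancellable() -> None:
        await asyncio.Event().wait()

    runs["a"] = Run("a", asyncio.create_task(completes()))
    runs["b"] = Run("b", asyncio.create_task(cancellable()))
    await asyncio.sleep(0)  # let both tasks reach their await point
    await shutdown(list(runs.values()), persist, timeout=1.0)
    return {rid: r.status for rid, r in runs.items()}, store


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Run `a` keeps `success` because its task finished normally during the drain, while run `b` ended cancelled and is the only one marked and persisted as `interrupted`.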

* fix(gateway): address maintainer review — surface failed drain persists, clarify timeout constant

Addresses @WillemJiang review on #3381:

- shutdown(): inspect the gather result of the trailing interrupted-status
  persistence. _persist_status is best-effort (it catches + logs its own
  failure with exc_info and returns False, so it never raises out of the
  gather), but the aggregate result was never checked — a partial failure had
  no shutdown-level visibility. Now any escaped Exception is logged, and any
  False (a persist that did not confirm) is logged with the run_id. Added
  regression test test_shutdown_surfaces_failed_interrupted_persist.
- deps: clarify the _RUN_DRAIN_TIMEOUT_SECONDS comment — state the actual value
  of _SHUTDOWN_HOOK_TIMEOUT_SECONDS (5.0s) and that both count toward the
  lifespan shutdown window. Kept as two separate constants (independent teardown
  steps that may diverge) rather than one shared "must match" value.
- Verified no other test fake needs the shutdown stub: _FakeRunManager in
  test_worker_langfuse_metadata.py is a run_agent() argument (worker path),
  never injected into langgraph_runtime, so it never receives shutdown().
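
A hedged sketch of inspecting the aggregate gather result, as the first bullet of this round describes. `persist_all` and the `persist_status` callback are hypothetical; only the "Could not persist interrupted status for run" wording is taken from the regression test.

```python
import asyncio
import logging
from typing import Awaitable, Callable

logger = logging.getLogger("drain_sketch")  # hypothetical logger name


async def persist_all(run_ids: list[str], persist_status: Callable[[str], Awaitable[bool]]) -> list[str]:
    """Persist every run and return the run_ids that were not confirmed."""
    results = await asyncio.gather(
        *(persist_status(run_id) for run_id in run_ids),
        return_exceptions=True,
    )
    unconfirmed: list[str] = []
    for run_id, result in zip(run_ids, results):
        if isinstance(result, Exception):
            # An exception escaped the best-effort persist helper.
            logger.warning("Persist for run %s raised: %r", run_id, result)
            unconfirmed.append(run_id)
        elif result is False:
            # The helper swallowed its own failure and reported False.
            logger.warning("Could not persist interrupted status for run %s", run_id)
            unconfirmed.append(run_id)
    return unconfirmed


async def demo() -> list[str]:
    async def persist_status(run_id: str) -> bool:
        if run_id == "r-raise":
            raise RuntimeError("store unavailable")
        return run_id != "r-false"

    return await persist_all(["r-ok", "r-false", "r-raise"], persist_status)


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

`gather` preserves input order, so zipping `run_ids` with the results is enough to attach the right run_id to each warning.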
2026-06-07 11:24:30 +08:00

354 lines
13 KiB
Python

"""Regression tests for graceful run-task drain on Gateway shutdown.

Guards bytedance/deer-flow issue #3373:

    psycopg_pool.PoolClosed: the pool 'pool-1' is already closed

Root cause: chat runs are fire-and-forget background ``asyncio`` tasks
(``app/gateway/services.py`` -> ``asyncio.create_task(run_agent(...))``) owned
by nobody. On shutdown, ``langgraph_runtime``'s ``AsyncExitStack`` tore down the
checkpointer's postgres pool while those tasks were still mid-graph. langgraph's
``AsyncPregelLoop._checkpointer_put_after_previous`` then ran its
``finally: await checkpointer.aput(...)`` against the already-closed pool.

Fix: ``RunManager.shutdown()`` cancels and *bounded*-awaits every in-flight run,
and ``langgraph_runtime`` calls it BEFORE the ``AsyncExitStack`` closes the
checkpointer, so the final checkpoint write lands while the pool is still open.
The drain must stay bounded: a stuck run must not hang the worker, because a
hung worker is the precondition for the signal-reentrancy deadlock guarded by
``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``.
"""

from __future__ import annotations

import asyncio
import operator
from contextlib import asynccontextmanager, suppress
from types import SimpleNamespace
from typing import Annotated, TypedDict

import pytest
from langgraph.checkpoint.memory import InMemorySaver

from deerflow.runtime import RunManager, RunStatus


# Module-level so langgraph's get_type_hints (which resolves annotations against
# module globals under `from __future__ import annotations`) can see Annotated.
class _CountState(TypedDict):
    count: Annotated[int, operator.add]


class _CloseableSaver(InMemorySaver):
    """InMemorySaver that fails writes once closed, like a closed pool."""

    def __init__(self) -> None:
        super().__init__()
        self._closed = False
        self.writes_after_close: list[str] = []

    def close(self) -> None:
        self._closed = True

    async def aput(self, *args, **kwargs):
        if self._closed:
            self.writes_after_close.append("aput")
            raise RuntimeError("checkpointer is closed")
        return await super().aput(*args, **kwargs)

    async def aput_writes(self, *args, **kwargs):
        if self._closed:
            self.writes_after_close.append("aput_writes")
            raise RuntimeError("checkpointer is closed")
        return await super().aput_writes(*args, **kwargs)


@pytest.mark.asyncio
async def test_shutdown_cancels_and_awaits_inflight_run():
    """shutdown() cancels the in-flight task, waits for it, marks it interrupted."""
    rm = RunManager()
    record = await rm.create("t-drain")
    await rm.set_status(record.run_id, RunStatus.running)

    started = asyncio.Event()
    cancelled = asyncio.Event()

    async def worker() -> None:
        try:
            started.set()
            await asyncio.Event().wait()
        except asyncio.CancelledError:
            cancelled.set()
            raise

    record.task = asyncio.create_task(worker())
    try:
        await asyncio.wait_for(started.wait(), timeout=1.0)
        await rm.shutdown(timeout=5.0)
        assert record.task.done()
        assert cancelled.is_set()
        assert record.status == RunStatus.interrupted
    finally:
        if not record.task.done():
            record.task.cancel()
            with suppress(asyncio.CancelledError):
                await record.task


@pytest.mark.asyncio
async def test_shutdown_is_bounded_when_run_ignores_cancellation():
    """A run that swallows cancellation must not make shutdown() hang."""
    rm = RunManager()
    record = await rm.create("t-stubborn")
    await rm.set_status(record.run_id, RunStatus.running)

    started = asyncio.Event()
    stop = asyncio.Event()

    async def stubborn() -> None:
        started.set()
        while not stop.is_set():
            try:
                await asyncio.sleep(3600)
            except asyncio.CancelledError:
                if stop.is_set():
                    raise
                # else: swallow; simulates a run stuck in slow cleanup

    record.task = asyncio.create_task(stubborn())
    try:
        await asyncio.wait_for(started.wait(), timeout=1.0)
        loop = asyncio.get_running_loop()
        t0 = loop.time()
        await rm.shutdown(timeout=0.3)
        elapsed = loop.time() - t0
        assert elapsed < 2.0, f"shutdown took {elapsed:.2f}s; drain is not bounded"
    finally:
        # cleanup the deliberately-stubborn task
        stop.set()
        record.task.cancel()
        with suppress(asyncio.CancelledError):
            await record.task


@pytest.mark.asyncio
async def test_shutdown_is_noop_without_inflight_runs():
    """shutdown() on an idle manager completes cleanly and is idempotent."""
    rm = RunManager()
    await rm.shutdown(timeout=1.0)

    # already-finished runs must not be re-cancelled or error out
    record = await rm.create("t-done")
    await rm.set_status(record.run_id, RunStatus.success)
    await rm.shutdown(timeout=1.0)


@pytest.mark.asyncio
async def test_langgraph_runtime_drains_runs_before_closing_checkpointer(monkeypatch):
    """The wiring order lock for #3373: drain in-flight runs, THEN close the pool.

    Patches every ``langgraph_runtime`` collaborator down to trivial stand-ins so
    only the bootstrap/teardown ordering runs. The checkpointer probe records when
    its context manager exits (pool close); a ``RunManager.shutdown`` spy records
    when the drain happens. The drain MUST come first.
    """
    from fastapi import FastAPI

    from app.gateway.deps import langgraph_runtime

    events: list[str] = []

    @asynccontextmanager
    async def probe_checkpointer(_config):
        try:
            yield object()
        finally:
            events.append("checkpointer_closed")

    @asynccontextmanager
    async def fake_stream_bridge(_config):
        yield object()

    @asynccontextmanager
    async def fake_store(_config):
        yield object()

    async def fake_init_engine(_db):
        return None

    async def fake_close_engine():
        return None

    async def spy_shutdown(self, *, timeout):  # noqa: ANN001
        events.append("runs_drained")

    monkeypatch.setattr("deerflow.runtime.checkpointer.async_provider.make_checkpointer", probe_checkpointer)
    monkeypatch.setattr("deerflow.runtime.make_stream_bridge", fake_stream_bridge)
    monkeypatch.setattr("deerflow.runtime.make_store", fake_store)
    monkeypatch.setattr("deerflow.persistence.engine.init_engine_from_config", fake_init_engine)
    monkeypatch.setattr("deerflow.persistence.engine.close_engine", fake_close_engine)
    monkeypatch.setattr("deerflow.persistence.engine.get_session_factory", lambda: None)
    monkeypatch.setattr("deerflow.runtime.events.store.make_run_event_store", lambda _cfg: object())
    monkeypatch.setattr("deerflow.persistence.thread_meta.make_thread_store", lambda _sf, _store: object())
    monkeypatch.setattr(RunManager, "shutdown", spy_shutdown, raising=False)

    app = FastAPI()
    startup_config = SimpleNamespace(database=SimpleNamespace(backend="memory"), run_events=None)
    async with langgraph_runtime(app, startup_config):
        pass

    assert "runs_drained" in events, "langgraph_runtime never drained in-flight runs on shutdown"
    assert "checkpointer_closed" in events
    assert events.index("runs_drained") < events.index("checkpointer_closed"), f"runs must be drained before the checkpointer pool is closed; got order {events}"


@pytest.mark.asyncio
async def test_drain_flushes_real_graph_checkpoint_before_close():
    """End-to-end #3373 guard with a REAL langgraph graph + checkpointer.

    A real run is driven through ``graph.astream`` in a background task, then
    ``RunManager.shutdown()`` drains it. The checkpointer raises once closed
    (mirroring ``psycopg_pool.PoolClosed``). Closing only happens AFTER the
    drain, as the gateway's AsyncExitStack does. The drain must let langgraph
    flush its final checkpoint while the checkpointer is still open, so no write
    lands against a closed checkpointer.

    Unlike the unit/spy tests above, this exercises the real langgraph
    checkpoint-put machinery, so a future langgraph change that cancels (rather
    than awaits) its checkpoint-put task on executor exit would fail this test
    instead of silently regressing #3373.
    """
    from langgraph.graph import END, START, StateGraph

    async def slow(_state: _CountState) -> dict:
        await asyncio.sleep(0.1)
        return {"count": 1}

    saver = _CloseableSaver()
    builder = StateGraph(_CountState)
    for name in ("a", "b", "c"):
        builder.add_node(name, slow)
    builder.add_edge(START, "a")
    builder.add_edge("a", "b")
    builder.add_edge("b", "c")
    builder.add_edge("c", END)
    graph = builder.compile(checkpointer=saver)

    rm = RunManager()
    record = await rm.create("t-e2e")
    await rm.set_status(record.run_id, RunStatus.running)
    thread_cfg = {"configurable": {"thread_id": "t-e2e"}}

    started = asyncio.Event()

    async def run() -> None:
        started.set()
        async for _ in graph.astream({"count": 0}, config=thread_cfg):
            pass

    record.task = asyncio.create_task(run())
    try:
        await asyncio.wait_for(started.wait(), timeout=1.0)

        # Deterministically wait until the run is genuinely in-flight: poll for
        # the first persisted checkpoint instead of a fixed sleep (avoids CI
        # flakiness on slow runners / under event-loop contention).
        async def _await_first_checkpoint() -> None:
            while (await saver.aget_tuple(thread_cfg)) is None:
                await asyncio.sleep(0.01)

        await asyncio.wait_for(_await_first_checkpoint(), timeout=5.0)

        # The fix: drain while the checkpointer is still open ...
        await rm.shutdown(timeout=5.0)
        # ... and only then close it (mirrors langgraph_runtime's ExitStack).
        saver.close()

        assert saver.writes_after_close == [], f"a checkpoint write raced a closed checkpointer: {saver.writes_after_close}"
        # The final checkpoint landed before close.
        snapshot = await saver.aget_tuple(thread_cfg)
        assert snapshot is not None
    finally:
        if not record.task.done():
            record.task.cancel()
            with suppress(asyncio.CancelledError):
                await record.task


@pytest.mark.asyncio
async def test_shutdown_preserves_status_of_run_completed_during_drain():
    """A run that finishes (e.g. success) during the drain window must keep its
    real terminal status; shutdown must not blanket-overwrite it to
    ``interrupted`` in memory or in the store (Copilot review on PR #3381)."""
    from deerflow.runtime.runs.store.memory import MemoryRunStore

    store = MemoryRunStore()
    rm = RunManager(store=store)
    record = await rm.create("t-complete")
    await rm.set_status(record.run_id, RunStatus.running)

    async def worker() -> None:
        try:
            await asyncio.Event().wait()
        except asyncio.CancelledError:
            # The run had effectively finished; swallow the cancellation and
            # record success, like a run that completed in the same tick the
            # shutdown cancelled it.
            pass
        await rm.set_status(record.run_id, RunStatus.success)

    record.task = asyncio.create_task(worker())
    try:
        await asyncio.sleep(0)  # let the task reach its await point
        await rm.shutdown(timeout=5.0)
        assert record.status == RunStatus.success, f"shutdown overwrote in-memory status: {record.status}"
        persisted = await store.get(record.run_id)
        assert persisted is not None and persisted["status"] == "success", f"shutdown overwrote persisted status: {persisted}"
    finally:
        if not record.task.done():
            record.task.cancel()
            with suppress(asyncio.CancelledError):
                await record.task


@pytest.mark.asyncio
async def test_shutdown_surfaces_failed_interrupted_persist(caplog):
    """A failed interrupted-status persist during the drain must be surfaced (with
    the run_id), not silently swallowed by the gather (maintainer review on
    PR #3381)."""
    import logging

    from deerflow.runtime.runs.store.memory import MemoryRunStore

    class _FailingStore(MemoryRunStore):
        async def update_status(self, *args, **kwargs):
            raise RuntimeError("store unavailable")

    rm = RunManager(store=_FailingStore())
    record = await rm.create("t-failpersist")
    record.status = RunStatus.running  # set in memory; the failing store is exercised by the drain

    started = asyncio.Event()

    async def worker() -> None:
        started.set()
        await asyncio.Event().wait()  # blocks until cancelled by the drain

    record.task = asyncio.create_task(worker())
    try:
        await asyncio.wait_for(started.wait(), timeout=1.0)
        with caplog.at_level(logging.WARNING, logger="deerflow.runtime.runs.manager"):
            await rm.shutdown(timeout=5.0)
        assert "Could not persist interrupted status for run" in caplog.text, caplog.text
    finally:
        if not record.task.done():
            record.task.cancel()
            with suppress(asyncio.CancelledError):
                await record.task