# Maximum time, in seconds, allowed for draining in-flight runs during shutdown,
# before the AsyncExitStack tears down the checkpointer (and its connection
# pool). Defined locally to avoid an app -> deps -> app import cycle.
#
# This budget is *separate* from
# ``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS`` (currently also 5.0s, which
# bounds channel-service stop): the two govern independent teardown steps and
# may diverge, but both count toward the lifespan shutdown window. Revisit them
# together if their sum must stay within the server's graceful-shutdown timeout.
_RUN_DRAIN_TIMEOUT_SECONDS = 5.0
async def shutdown(self, *, timeout: float = 5.0) -> None:
    """Cancel and bounded-await all in-flight runs on process shutdown.

    Chat runs execute in fire-and-forget background ``asyncio`` tasks that
    write checkpoints through a shared checkpointer. On shutdown the
    checkpointer's resources (e.g. the postgres connection pool owned by the
    gateway's ``AsyncExitStack``) are torn down; if a run task is still
    mid-graph at that point, langgraph's
    ``AsyncPregelLoop._checkpointer_put_after_previous`` runs its
    ``finally: await checkpointer.aput(...)`` against the closed pool. Because
    that put runs in a langgraph-internal task (not on ``run_agent``'s call
    stack), the resulting ``psycopg_pool.PoolClosed`` is not catchable by the
    worker and surfaces as an unhandled exception during ``asyncio.run()``
    shutdown (bytedance/deer-flow issue #3373).

    Draining in-flight runs *before* the checkpointer is closed lets each
    run that settles within ``timeout`` flush its final checkpoint while
    resources are still open. Only runs that do **not** settle on their own
    are marked ``interrupted`` — a run that completes (e.g. ``success``)
    during the drain keeps its real terminal status instead of being
    blanket-overwritten. A run task that ends with an exception during the
    drain also keeps its status; the exception is retrieved and logged with
    the run id.

    The whole drain, including the trailing status persistence, is bounded
    by ``timeout`` so a run stuck in cleanup (or a slow store under DB
    pressure) cannot hang worker shutdown — the precondition for the
    signal-reentrancy deadlock guarded by
    ``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``. Runs still active
    after ``timeout`` are logged and may still race teardown. Note that when
    at least one task is still pending, the wait has already consumed the
    full budget, so the ``interrupted`` status is updated in memory only and
    the skipped persistence is reported with a warning.

    Args:
        timeout: Total budget in seconds shared by the task wait and the
            status persistence that follows it.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout

    async with self._lock:
        inflight = [
            record
            for record in self._runs.values()
            if record.status in (RunStatus.pending, RunStatus.running) and record.task is not None and not record.task.done()
        ]
        for record in inflight:
            record.abort_action = "interrupt"
            record.abort_event.set()
            record.task.cancel()  # type: ignore[union-attr]  # filtered above
            # Status is decided AFTER the drain (below), not here: a run that
            # completes on its own during the drain must keep its real status.
        # Snapshot the task objects while the lock is held. Everything below
        # classifies runs by this snapshot, so a ``record.task`` that is
        # reassigned or cleared while we await cannot be confused with the
        # task that was actually cancelled and waited on.
        tasks = [record.task for record in inflight]

    if not inflight:
        return

    _, pending = await asyncio.wait(tasks, timeout=timeout)

    # Only mark/persist ``interrupted`` for runs that did not settle on their
    # own (still pending after the timeout, or ended cancelled). A run that
    # finished normally during the drain keeps the status it set for itself.
    to_persist: list[RunRecord] = []
    async with self._lock:
        for record, task in zip(inflight, tasks):
            if task not in pending and not task.cancelled():
                # Completed on its own. Retrieving the exception stops asyncio
                # from reporting it as "never retrieved"; log it so a run that
                # crashed during the drain is visible rather than swallowed.
                failure = task.exception()
                if failure is not None:
                    logger.warning("Run %s raised during shutdown drain: %r", record.run_id, failure)
                continue
            if record.status in (RunStatus.pending, RunStatus.running):
                record.status = RunStatus.interrupted
                record.updated_at = _now_iso()
                to_persist.append(record)

    # Bound the trailing status persistence within the remaining budget so a
    # slow store (``_call_store_with_retry`` can back off under DB pressure)
    # cannot push shutdown past ``timeout``.
    if to_persist:
        remaining = deadline - loop.time()
        if remaining <= 0:
            logger.warning("Run drain budget exhausted before persisting %d interrupted run(s) on shutdown", len(to_persist))
        else:
            try:
                results = await asyncio.wait_for(
                    asyncio.gather(*(self._persist_status(record, RunStatus.interrupted) for record in to_persist), return_exceptions=True),
                    timeout=remaining,
                )
            except TimeoutError:
                logger.warning("Run drain status persistence exceeded the %.1fs budget; %d record(s) may not be persisted", timeout, len(to_persist))
            else:
                # ``_persist_status`` is best-effort: it catches and logs its
                # own failures, returning ``False``. Inspect the aggregate so a
                # partial failure is surfaced at shutdown level (with the
                # run_id) instead of being silently swallowed by the gather.
                for record, result in zip(to_persist, results):
                    if isinstance(result, Exception):
                        logger.warning("Unexpected error persisting interrupted status for run %s during shutdown: %r", record.run_id, result)
                    elif result is False:
                        logger.warning("Could not persist interrupted status for run %s during shutdown", record.run_id)

    if pending:
        logger.warning("Run drain exceeded %.1fs on shutdown; %d run task(s) still active and may race checkpointer teardown", timeout, len(pending))
    logger.info("Drained %d in-flight run(s) on shutdown (%d settled within %.1fs)", len(inflight), len(inflight) - len(pending), timeout)
class _CloseableSaver(InMemorySaver):
    """In-memory checkpointer that rejects writes after ``close()``.

    Stands in for a checkpointer backed by a closed pool: once closed, every
    ``aput`` / ``aput_writes`` call is recorded in ``writes_after_close`` and
    fails with ``RuntimeError``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Names of the write operations attempted after close(), in call order.
        self.writes_after_close: list[str] = []
        self._closed = False

    def close(self) -> None:
        self._closed = True

    def _reject_if_closed(self, operation: str) -> None:
        # Record the offending operation before raising so the test can assert
        # on exactly which writes raced the close.
        if not self._closed:
            return
        self.writes_after_close.append(operation)
        raise RuntimeError("checkpointer is closed")

    async def aput(self, *args, **kwargs):
        self._reject_if_closed("aput")
        return await super().aput(*args, **kwargs)

    async def aput_writes(self, *args, **kwargs):
        self._reject_if_closed("aput_writes")
        return await super().aput_writes(*args, **kwargs)
@pytest.mark.asyncio
async def test_langgraph_runtime_drains_runs_before_closing_checkpointer(monkeypatch):
    """Lock the #3373 teardown order: drain in-flight runs, THEN close the pool.

    Every ``langgraph_runtime`` collaborator is swapped for a trivial stand-in
    so that only the bootstrap/teardown sequencing executes. The checkpointer
    probe notes when its context manager exits (pool close) and a
    ``RunManager.shutdown`` spy notes when the drain is requested; the drain
    MUST be recorded first.
    """
    from fastapi import FastAPI

    from app.gateway.deps import langgraph_runtime

    timeline: list[str] = []

    @asynccontextmanager
    async def probe_checkpointer(_config):
        # The ``finally`` clause runs when the context manager is exited.
        try:
            yield object()
        finally:
            timeline.append("checkpointer_closed")

    @asynccontextmanager
    async def placeholder_resource(_config):
        # Shared stand-in for the stream bridge and the store factories.
        yield object()

    async def noop_init_engine(_db):
        return None

    async def noop_close_engine():
        return None

    async def spy_shutdown(self, *, timeout):  # noqa: ANN001
        timeline.append("runs_drained")

    # Dotted-path patches, applied in insertion order.
    stand_ins = {
        "deerflow.runtime.checkpointer.async_provider.make_checkpointer": probe_checkpointer,
        "deerflow.runtime.make_stream_bridge": placeholder_resource,
        "deerflow.runtime.make_store": placeholder_resource,
        "deerflow.persistence.engine.init_engine_from_config": noop_init_engine,
        "deerflow.persistence.engine.close_engine": noop_close_engine,
        "deerflow.persistence.engine.get_session_factory": lambda: None,
        "deerflow.runtime.events.store.make_run_event_store": lambda _cfg: object(),
        "deerflow.persistence.thread_meta.make_thread_store": lambda _sf, _store: object(),
    }
    for dotted_path, replacement in stand_ins.items():
        monkeypatch.setattr(dotted_path, replacement)
    monkeypatch.setattr(RunManager, "shutdown", spy_shutdown, raising=False)

    database = SimpleNamespace(backend="memory")
    startup_config = SimpleNamespace(database=database, run_events=None)
    app = FastAPI()

    async with langgraph_runtime(app, startup_config):
        pass

    assert "runs_drained" in timeline, "langgraph_runtime never drained in-flight runs on shutdown"
    assert "checkpointer_closed" in timeline
    drained_at = timeline.index("runs_drained")
    closed_at = timeline.index("checkpointer_closed")
    assert drained_at < closed_at, f"runs must be drained before the checkpointer pool is closed; got order {timeline}"
@pytest.mark.asyncio
async def test_shutdown_preserves_status_of_run_completed_during_drain():
    """A run that reaches a terminal status (e.g. success) while the drain is
    in progress must keep that status — shutdown must not blanket-overwrite it
    to ``interrupted``, neither in memory nor in the store (Copilot review on
    PR #3381)."""
    from deerflow.runtime.runs.store.memory import MemoryRunStore

    run_store = MemoryRunStore()
    manager = RunManager(store=run_store)
    run = await manager.create("t-complete")
    await manager.set_status(run.run_id, RunStatus.running)

    async def worker() -> None:
        # The run had effectively finished: swallow the cancellation and
        # record success, like a run that completed in the same tick the
        # shutdown cancelled it.
        with suppress(asyncio.CancelledError):
            await asyncio.Event().wait()
        await manager.set_status(run.run_id, RunStatus.success)

    run.task = asyncio.create_task(worker())
    try:
        # One loop iteration is enough for the task to start and park on its
        # await point before the drain cancels it.
        await asyncio.sleep(0)

        await manager.shutdown(timeout=5.0)

        assert run.status == RunStatus.success, f"shutdown overwrote in-memory status: {run.status}"
        stored = await run_store.get(run.run_id)
        assert stored is not None and stored["status"] == "success", f"shutdown overwrote persisted status: {stored}"
    finally:
        # Do not leak a still-running worker into the next test.
        if not run.task.done():
            run.task.cancel()
            with suppress(asyncio.CancelledError):
                await run.task
async def shutdown(self, *, timeout: float = 5.0) -> None:
    """Record that a shutdown drain was requested; do no other work.

    These startup-recovery tests have no in-flight tasks, but
    ``langgraph_runtime`` drains the manager on teardown, so the double must
    accept the call. ``timeout`` is accepted for signature compatibility and
    is not used.
    """
    self.shutdown_calls = self.shutdown_calls + 1