mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-09 17:12:01 +00:00
fix(gateway): drain in-flight runs before closing checkpointer on shutdown (#3381)
* fix(gateway): drain in-flight runs before closing checkpointer on shutdown Chat runs execute in fire-and-forget background asyncio tasks that write checkpoints through a shared checkpointer. On shutdown, langgraph_runtime's AsyncExitStack tore down the checkpointer's postgres connection pool while those run tasks were still mid-graph. langgraph's AsyncPregelLoop._checkpointer_put_after_previous then ran its `finally: await checkpointer.aput(...)` against the closed pool, raising psycopg_pool.PoolClosed. Because that put runs in a langgraph-internal task (not on run_agent's call stack), run_agent's try/except cannot catch it and it surfaces as "unhandled exception during asyncio.run() shutdown". Add RunManager.shutdown() to cancel and bounded-await all in-flight runs, and call it from langgraph_runtime BEFORE the AsyncExitStack closes the checkpointer, so the final checkpoint write lands while the pool is still open. The drain is bounded by a timeout so a stuck run cannot hang worker shutdown, and is shielded so a second shutdown signal cannot abandon it mid-drain and reopen the race. Closes #3373 * fix(gateway): address review — preserve completed-run status, bound drain persistence Addresses Copilot review on #3381: - RunManager.shutdown(): decide run status AFTER the drain. Under the lock it now only requests cancellation; after asyncio.wait it marks/persists `interrupted` only for runs still pending or ended cancelled. A run that completes (e.g. `success`) during the drain window keeps its real terminal status instead of being unconditionally overwritten. - Bound the trailing status persistence within the timeout budget (deadline = loop.time()+timeout; gather wrapped in asyncio.wait_for) so a slow store backing off under DB pressure cannot push shutdown past the deadline. - deps: use asyncio.create_task instead of asyncio.ensure_future. - tests: wait deterministically for the run to be in-flight (poll the first checkpoint) instead of a fixed sleep; init shutdown_calls explicitly in the recovery test double; add regression test asserting a run completing during the drain keeps its status (in memory and in the store). * fix(gateway): address maintainer review — surface failed drain persists, clarify timeout constant Addresses @WillemJiang review on #3381: - shutdown(): inspect the gather result of the trailing interrupted-status persistence. _persist_status is best-effort (it catches + logs its own failure with exc_info and returns False, so it never raises out of the gather), but the aggregate result was never checked — a partial failure had no shutdown-level visibility. Now any escaped Exception is logged, and any False (a persist that did not confirm) is logged with the run_id. Added regression test test_shutdown_surfaces_failed_interrupted_persist. - deps: clarify the _RUN_DRAIN_TIMEOUT_SECONDS comment — state the actual value of _SHUTDOWN_HOOK_TIMEOUT_SECONDS (5.0s) and that both count toward the lifespan shutdown window. Kept as two separate constants (independent teardown steps that may diverge) rather than one shared "must match" value. - Verified no other test fake needs the shutdown stub: _FakeRunManager in test_worker_langfuse_metadata.py is a run_agent() argument (worker path), never injected into langgraph_runtime, so it never receives shutdown().
This commit is contained in:
parent
9a5de8d6a5
commit
268fdd6968
@ -17,6 +17,7 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
@ -33,6 +34,43 @@ from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Upper bound (seconds) for draining in-flight runs during shutdown, before the
|
||||
# AsyncExitStack tears down the checkpointer (and its connection pool). Kept
|
||||
# local to avoid an app -> deps -> app import cycle. This is a *separate* budget
|
||||
# from ``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS`` (currently also 5.0s,
|
||||
# which bounds channel-service stop): the two govern independent teardown steps
|
||||
# and may diverge, but both count toward the lifespan shutdown window — revisit
|
||||
# them together if their sum must stay within the server's graceful-shutdown
|
||||
# timeout.
|
||||
_RUN_DRAIN_TIMEOUT_SECONDS = 5.0
|
||||
|
||||
|
||||
async def _drain_inflight_runs(run_manager: RunManager) -> None:
|
||||
"""Drain in-flight runs before the checkpointer is torn down (issue #3373).
|
||||
|
||||
Shields the (internally-bounded) drain so that even if the lifespan
|
||||
coroutine is itself cancelled mid-shutdown — a second SIGINT or the server's
|
||||
graceful-shutdown timeout, i.e. the same signal storm behind #3373 — the
|
||||
checkpointer pool is not closed while run tasks are still writing
|
||||
checkpoints. On such a cancellation we let the already-running drain finish
|
||||
(it is bounded by ``RunManager.shutdown``'s own timeout) and then propagate
|
||||
the cancellation.
|
||||
"""
|
||||
drain = asyncio.create_task(run_manager.shutdown(timeout=_RUN_DRAIN_TIMEOUT_SECONDS))
|
||||
try:
|
||||
await asyncio.shield(drain)
|
||||
except asyncio.CancelledError:
|
||||
# Re-shield so this second wait does not abandon the in-flight drain;
|
||||
# it is bounded, so this cannot hang. Then re-raise to honour shutdown.
|
||||
try:
|
||||
await asyncio.shield(drain)
|
||||
except Exception:
|
||||
logger.exception("In-flight run drain failed after shutdown cancellation")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to drain in-flight runs during shutdown")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
@ -177,6 +215,14 @@ async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGen
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Drain in-flight run tasks BEFORE the AsyncExitStack tears down the
|
||||
# checkpointer (and its connection pool). A run still mid-graph would
|
||||
# otherwise leak into asyncio.run() shutdown, where langgraph's
|
||||
# _checkpointer_put_after_previous aput races the closed pool and
|
||||
# raises PoolClosed (issue #3373).
|
||||
run_manager = getattr(app.state, "run_manager", None)
|
||||
if run_manager is not None:
|
||||
await _drain_inflight_runs(run_manager)
|
||||
await close_engine()
|
||||
|
||||
|
||||
|
||||
@ -645,6 +645,98 @@ class RunManager:
|
||||
self._runs.pop(run_id, None)
|
||||
logger.debug("Run record %s cleaned up", run_id)
|
||||
|
||||
async def shutdown(self, *, timeout: float = 5.0) -> None:
|
||||
"""Cancel and bounded-await all in-flight runs on process shutdown.
|
||||
|
||||
Chat runs execute in fire-and-forget background ``asyncio`` tasks that
|
||||
write checkpoints through a shared checkpointer. On shutdown the
|
||||
checkpointer's resources (e.g. the postgres connection pool owned by the
|
||||
gateway's ``AsyncExitStack``) are torn down; if a run task is still
|
||||
mid-graph at that point, langgraph's
|
||||
``AsyncPregelLoop._checkpointer_put_after_previous`` runs its
|
||||
``finally: await checkpointer.aput(...)`` against the closed pool. Because
|
||||
that put runs in a langgraph-internal task (not on ``run_agent``'s call
|
||||
stack), the resulting ``psycopg_pool.PoolClosed`` is not catchable by the
|
||||
worker and surfaces as an unhandled exception during ``asyncio.run()``
|
||||
shutdown (bytedance/deer-flow issue #3373).
|
||||
|
||||
Draining in-flight runs *before* the checkpointer is closed lets each
|
||||
run that settles within ``timeout`` flush its final checkpoint while
|
||||
resources are still open. Only runs that do **not** settle on their own
|
||||
are marked ``interrupted`` — a run that completes (e.g. ``success``)
|
||||
during the drain keeps its real terminal status instead of being
|
||||
blanket-overwritten. The whole drain, including the trailing status
|
||||
persistence, is bounded by ``timeout`` so a run stuck in cleanup (or a
|
||||
slow store under DB pressure) cannot hang worker shutdown — the
|
||||
precondition for the signal-reentrancy deadlock guarded by
|
||||
``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``. Runs still active
|
||||
after ``timeout`` are logged and may still race teardown.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = loop.time() + timeout
|
||||
|
||||
async with self._lock:
|
||||
inflight = [record for record in self._runs.values() if record.status in (RunStatus.pending, RunStatus.running) and record.task is not None and not record.task.done()]
|
||||
for record in inflight:
|
||||
record.abort_action = "interrupt"
|
||||
record.abort_event.set()
|
||||
record.task.cancel() # type: ignore[union-attr] # filtered above
|
||||
# Status is decided AFTER the drain (below), not here: a run that
|
||||
# completes on its own during the drain must keep its real status.
|
||||
|
||||
if not inflight:
|
||||
return
|
||||
|
||||
tasks = [record.task for record in inflight]
|
||||
_, pending = await asyncio.wait(tasks, timeout=timeout)
|
||||
|
||||
# Only mark/persist ``interrupted`` for runs that did not settle on their
|
||||
# own (still pending after the timeout, or ended cancelled). A run that
|
||||
# finished normally during the drain keeps the status it set for itself.
|
||||
to_persist: list[RunRecord] = []
|
||||
async with self._lock:
|
||||
for record in inflight:
|
||||
task = record.task
|
||||
if task not in pending and not task.cancelled():
|
||||
# Completed on its own — retrieve any surfaced exception so it
|
||||
# is not reported as "never retrieved", and keep its status.
|
||||
task.exception() # type: ignore[union-attr] # done & not cancelled
|
||||
continue
|
||||
if record.status in (RunStatus.pending, RunStatus.running):
|
||||
record.status = RunStatus.interrupted
|
||||
record.updated_at = _now_iso()
|
||||
to_persist.append(record)
|
||||
|
||||
# Bound the trailing status persistence within the remaining budget so a
|
||||
# slow store (``_call_store_with_retry`` can back off under DB pressure)
|
||||
# cannot push shutdown past ``timeout``.
|
||||
if to_persist:
|
||||
remaining = deadline - loop.time()
|
||||
if remaining <= 0:
|
||||
logger.warning("Run drain budget exhausted before persisting %d interrupted run(s) on shutdown", len(to_persist))
|
||||
else:
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*(self._persist_status(record, RunStatus.interrupted) for record in to_persist), return_exceptions=True),
|
||||
timeout=remaining,
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning("Run drain status persistence exceeded the %.1fs budget; %d record(s) may not be persisted", timeout, len(to_persist))
|
||||
else:
|
||||
# ``_persist_status`` is best-effort: it catches and logs its
|
||||
# own failures, returning ``False``. Inspect the aggregate so a
|
||||
# partial failure is surfaced at shutdown level (with the
|
||||
# run_id) instead of being silently swallowed by the gather.
|
||||
for record, result in zip(to_persist, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning("Unexpected error persisting interrupted status for run %s during shutdown: %r", record.run_id, result)
|
||||
elif result is False:
|
||||
logger.warning("Could not persist interrupted status for run %s during shutdown", record.run_id)
|
||||
|
||||
if pending:
|
||||
logger.warning("Run drain exceeded %.1fs on shutdown; %d run task(s) still active and may race checkpointer teardown", timeout, len(pending))
|
||||
logger.info("Drained %d in-flight run(s) on shutdown (%d settled within %.1fs)", len(inflight), len(inflight) - len(pending), timeout)
|
||||
|
||||
|
||||
class ConflictError(Exception):
|
||||
"""Raised when multitask_strategy=reject and thread has inflight runs."""
|
||||
|
||||
353
backend/tests/test_gateway_run_drain_shutdown.py
Normal file
353
backend/tests/test_gateway_run_drain_shutdown.py
Normal file
@ -0,0 +1,353 @@
|
||||
"""Regression tests for graceful run-task drain on Gateway shutdown.
|
||||
|
||||
Guards bytedance/deer-flow issue #3373:
|
||||
|
||||
psycopg_pool.PoolClosed: the pool 'pool-1' is already closed
|
||||
|
||||
Root cause: chat runs are fire-and-forget background ``asyncio`` tasks
|
||||
(``app/gateway/services.py`` -> ``asyncio.create_task(run_agent(...))``) owned
|
||||
by nobody. On shutdown, ``langgraph_runtime``'s ``AsyncExitStack`` tore down the
|
||||
checkpointer's postgres pool while those tasks were still mid-graph. langgraph's
|
||||
``AsyncPregelLoop._checkpointer_put_after_previous`` then ran its
|
||||
``finally: await checkpointer.aput(...)`` against the already-closed pool.
|
||||
|
||||
Fix: ``RunManager.shutdown()`` cancels and *bounded*-awaits every in-flight run,
|
||||
and ``langgraph_runtime`` calls it BEFORE the ``AsyncExitStack`` closes the
|
||||
checkpointer — so the final checkpoint write lands while the pool is still open.
|
||||
The drain must stay bounded (a stuck run must not hang the worker, the
|
||||
precondition for the signal-reentrancy deadlock guarded by
|
||||
``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import operator
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from types import SimpleNamespace
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
import pytest
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from deerflow.runtime import RunManager, RunStatus
|
||||
|
||||
|
||||
# Module-level so langgraph's get_type_hints (which resolves annotations against
|
||||
# module globals under `from __future__ import annotations`) can see Annotated.
|
||||
class _CountState(TypedDict):
|
||||
count: Annotated[int, operator.add]
|
||||
|
||||
|
||||
class _CloseableSaver(InMemorySaver):
|
||||
"""InMemorySaver that fails writes once closed, like a closed pool."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._closed = False
|
||||
self.writes_after_close: list[str] = []
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
async def aput(self, *args, **kwargs):
|
||||
if self._closed:
|
||||
self.writes_after_close.append("aput")
|
||||
raise RuntimeError("checkpointer is closed")
|
||||
return await super().aput(*args, **kwargs)
|
||||
|
||||
async def aput_writes(self, *args, **kwargs):
|
||||
if self._closed:
|
||||
self.writes_after_close.append("aput_writes")
|
||||
raise RuntimeError("checkpointer is closed")
|
||||
return await super().aput_writes(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_cancels_and_awaits_inflight_run():
|
||||
"""shutdown() cancels the in-flight task, waits for it, marks it interrupted."""
|
||||
rm = RunManager()
|
||||
record = await rm.create("t-drain")
|
||||
await rm.set_status(record.run_id, RunStatus.running)
|
||||
|
||||
started = asyncio.Event()
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def worker() -> None:
|
||||
try:
|
||||
started.set()
|
||||
await asyncio.Event().wait()
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
record.task = asyncio.create_task(worker())
|
||||
try:
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
|
||||
await rm.shutdown(timeout=5.0)
|
||||
|
||||
assert record.task.done()
|
||||
assert cancelled.is_set()
|
||||
assert record.status == RunStatus.interrupted
|
||||
finally:
|
||||
if not record.task.done():
|
||||
record.task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await record.task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_is_bounded_when_run_ignores_cancellation():
|
||||
"""A run that swallows cancellation must not make shutdown() hang."""
|
||||
rm = RunManager()
|
||||
record = await rm.create("t-stubborn")
|
||||
await rm.set_status(record.run_id, RunStatus.running)
|
||||
|
||||
started = asyncio.Event()
|
||||
stop = asyncio.Event()
|
||||
|
||||
async def stubborn() -> None:
|
||||
started.set()
|
||||
while not stop.is_set():
|
||||
try:
|
||||
await asyncio.sleep(3600)
|
||||
except asyncio.CancelledError:
|
||||
if stop.is_set():
|
||||
raise
|
||||
# else: swallow — simulates a run stuck in slow cleanup
|
||||
|
||||
record.task = asyncio.create_task(stubborn())
|
||||
try:
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
t0 = loop.time()
|
||||
await rm.shutdown(timeout=0.3)
|
||||
elapsed = loop.time() - t0
|
||||
|
||||
assert elapsed < 2.0, f"shutdown took {elapsed:.2f}s; drain is not bounded"
|
||||
finally:
|
||||
# cleanup the deliberately-stubborn task
|
||||
stop.set()
|
||||
record.task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await record.task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_is_noop_without_inflight_runs():
|
||||
"""shutdown() on an idle manager completes cleanly and is idempotent."""
|
||||
rm = RunManager()
|
||||
await rm.shutdown(timeout=1.0)
|
||||
# already-finished runs must not be re-cancelled or error out
|
||||
record = await rm.create("t-done")
|
||||
await rm.set_status(record.run_id, RunStatus.success)
|
||||
await rm.shutdown(timeout=1.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_langgraph_runtime_drains_runs_before_closing_checkpointer(monkeypatch):
|
||||
"""The wiring order lock for #3373: drain in-flight runs, THEN close the pool.
|
||||
|
||||
Patches every ``langgraph_runtime`` collaborator down to trivial stand-ins so
|
||||
only the bootstrap/teardown ordering runs. The checkpointer probe records when
|
||||
its context manager exits (pool close); a ``RunManager.shutdown`` spy records
|
||||
when the drain happens. The drain MUST come first.
|
||||
"""
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.gateway.deps import langgraph_runtime
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
@asynccontextmanager
|
||||
async def probe_checkpointer(_config):
|
||||
try:
|
||||
yield object()
|
||||
finally:
|
||||
events.append("checkpointer_closed")
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_stream_bridge(_config):
|
||||
yield object()
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_store(_config):
|
||||
yield object()
|
||||
|
||||
async def fake_init_engine(_db):
|
||||
return None
|
||||
|
||||
async def fake_close_engine():
|
||||
return None
|
||||
|
||||
async def spy_shutdown(self, *, timeout): # noqa: ANN001
|
||||
events.append("runs_drained")
|
||||
|
||||
monkeypatch.setattr("deerflow.runtime.checkpointer.async_provider.make_checkpointer", probe_checkpointer)
|
||||
monkeypatch.setattr("deerflow.runtime.make_stream_bridge", fake_stream_bridge)
|
||||
monkeypatch.setattr("deerflow.runtime.make_store", fake_store)
|
||||
monkeypatch.setattr("deerflow.persistence.engine.init_engine_from_config", fake_init_engine)
|
||||
monkeypatch.setattr("deerflow.persistence.engine.close_engine", fake_close_engine)
|
||||
monkeypatch.setattr("deerflow.persistence.engine.get_session_factory", lambda: None)
|
||||
monkeypatch.setattr("deerflow.runtime.events.store.make_run_event_store", lambda _cfg: object())
|
||||
monkeypatch.setattr("deerflow.persistence.thread_meta.make_thread_store", lambda _sf, _store: object())
|
||||
monkeypatch.setattr(RunManager, "shutdown", spy_shutdown, raising=False)
|
||||
|
||||
app = FastAPI()
|
||||
startup_config = SimpleNamespace(database=SimpleNamespace(backend="memory"), run_events=None)
|
||||
|
||||
async with langgraph_runtime(app, startup_config):
|
||||
pass
|
||||
|
||||
assert "runs_drained" in events, "langgraph_runtime never drained in-flight runs on shutdown"
|
||||
assert "checkpointer_closed" in events
|
||||
assert events.index("runs_drained") < events.index("checkpointer_closed"), f"runs must be drained before the checkpointer pool is closed; got order {events}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_flushes_real_graph_checkpoint_before_close():
|
||||
"""End-to-end #3373 guard with a REAL langgraph graph + checkpointer.
|
||||
|
||||
A real run is driven through ``graph.astream`` in a background task, then
|
||||
``RunManager.shutdown()`` drains it. The checkpointer raises once closed
|
||||
(mirroring ``psycopg_pool.PoolClosed``). Closing only happens AFTER the
|
||||
drain — as the gateway's AsyncExitStack does. The drain must let langgraph
|
||||
flush its final checkpoint while the checkpointer is still open, so no write
|
||||
lands against a closed checkpointer.
|
||||
|
||||
Unlike the unit/spy tests above, this exercises the real langgraph
|
||||
checkpoint-put machinery, so a future langgraph change that cancels (rather
|
||||
than awaits) its checkpoint-put task on executor exit would fail this test
|
||||
instead of silently regressing #3373.
|
||||
"""
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
|
||||
async def slow(_state: _CountState) -> dict:
|
||||
await asyncio.sleep(0.1)
|
||||
return {"count": 1}
|
||||
|
||||
saver = _CloseableSaver()
|
||||
builder = StateGraph(_CountState)
|
||||
for name in ("a", "b", "c"):
|
||||
builder.add_node(name, slow)
|
||||
builder.add_edge(START, "a")
|
||||
builder.add_edge("a", "b")
|
||||
builder.add_edge("b", "c")
|
||||
builder.add_edge("c", END)
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
|
||||
rm = RunManager()
|
||||
record = await rm.create("t-e2e")
|
||||
await rm.set_status(record.run_id, RunStatus.running)
|
||||
thread_cfg = {"configurable": {"thread_id": "t-e2e"}}
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def run() -> None:
|
||||
started.set()
|
||||
async for _ in graph.astream({"count": 0}, config=thread_cfg):
|
||||
pass
|
||||
|
||||
record.task = asyncio.create_task(run())
|
||||
try:
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
|
||||
# Deterministically wait until the run is genuinely in-flight — poll for
|
||||
# the first persisted checkpoint instead of a fixed sleep (avoids CI
|
||||
# flakiness on slow runners / under event-loop contention).
|
||||
async def _await_first_checkpoint() -> None:
|
||||
while (await saver.aget_tuple(thread_cfg)) is None:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await asyncio.wait_for(_await_first_checkpoint(), timeout=5.0)
|
||||
|
||||
# The fix: drain while the checkpointer is still open ...
|
||||
await rm.shutdown(timeout=5.0)
|
||||
# ... and only then close it (mirrors langgraph_runtime's ExitStack).
|
||||
saver.close()
|
||||
|
||||
assert saver.writes_after_close == [], f"a checkpoint write raced a closed checkpointer: {saver.writes_after_close}"
|
||||
# The final checkpoint landed before close.
|
||||
snapshot = await saver.aget_tuple(thread_cfg)
|
||||
assert snapshot is not None
|
||||
finally:
|
||||
if not record.task.done():
|
||||
record.task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await record.task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_preserves_status_of_run_completed_during_drain():
|
||||
"""A run that finishes (e.g. success) during the drain window must keep its
|
||||
real terminal status — shutdown must not blanket-overwrite it to
|
||||
``interrupted`` in memory or in the store (Copilot review on PR #3381)."""
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
store = MemoryRunStore()
|
||||
rm = RunManager(store=store)
|
||||
record = await rm.create("t-complete")
|
||||
await rm.set_status(record.run_id, RunStatus.running)
|
||||
|
||||
async def worker() -> None:
|
||||
try:
|
||||
await asyncio.Event().wait()
|
||||
except asyncio.CancelledError:
|
||||
# The run had effectively finished; swallow the cancellation and
|
||||
# record success, like a run that completed in the same tick the
|
||||
# shutdown cancelled it.
|
||||
pass
|
||||
await rm.set_status(record.run_id, RunStatus.success)
|
||||
|
||||
record.task = asyncio.create_task(worker())
|
||||
try:
|
||||
await asyncio.sleep(0) # let the task reach its await point
|
||||
|
||||
await rm.shutdown(timeout=5.0)
|
||||
|
||||
assert record.status == RunStatus.success, f"shutdown overwrote in-memory status: {record.status}"
|
||||
persisted = await store.get(record.run_id)
|
||||
assert persisted is not None and persisted["status"] == "success", f"shutdown overwrote persisted status: {persisted}"
|
||||
finally:
|
||||
if not record.task.done():
|
||||
record.task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await record.task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_surfaces_failed_interrupted_persist(caplog):
|
||||
"""A failed interrupted-status persist during the drain must be surfaced (with
|
||||
the run_id), not silently swallowed by the gather (maintainer review on
|
||||
PR #3381)."""
|
||||
import logging
|
||||
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
class _FailingStore(MemoryRunStore):
|
||||
async def update_status(self, *args, **kwargs):
|
||||
raise RuntimeError("store unavailable")
|
||||
|
||||
rm = RunManager(store=_FailingStore())
|
||||
record = await rm.create("t-failpersist")
|
||||
record.status = RunStatus.running # set in memory; the failing store is exercised by the drain
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def worker() -> None:
|
||||
started.set()
|
||||
await asyncio.Event().wait() # blocks until cancelled by the drain
|
||||
|
||||
record.task = asyncio.create_task(worker())
|
||||
try:
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
with caplog.at_level(logging.WARNING, logger="deerflow.runtime.runs.manager"):
|
||||
await rm.shutdown(timeout=5.0)
|
||||
assert "Could not persist interrupted status for run" in caplog.text, caplog.text
|
||||
finally:
|
||||
if not record.task.done():
|
||||
record.task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await record.task
|
||||
@ -32,6 +32,7 @@ class _FakeRunManager:
|
||||
self.store = store
|
||||
self.reconcile_calls: list[dict] = []
|
||||
self.list_by_thread_calls: list[dict] = []
|
||||
self.shutdown_calls: int = 0
|
||||
_FakeRunManager.instances.append(self)
|
||||
|
||||
async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
|
||||
@ -42,6 +43,11 @@ class _FakeRunManager:
|
||||
self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit})
|
||||
return self.latest_by_thread.get(thread_id, self.recovered_runs[:limit])
|
||||
|
||||
async def shutdown(self, *, timeout: float = 5.0) -> None:
|
||||
# No in-flight tasks in these startup-recovery tests; langgraph_runtime
|
||||
# drains the manager on teardown, so the double must accept the call.
|
||||
self.shutdown_calls += 1
|
||||
|
||||
|
||||
class _FakeThreadStore:
|
||||
def __init__(self) -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user