deer-flow/backend/tests/test_wait_disconnect_handling.py
AochenShen99 a5599c100c
fix(gateway): honour on_disconnect on /wait endpoints (#3267)
* fix(gateway): honour on_disconnect on /wait endpoints (#3265)

The non-streaming /threads/{tid}/runs/wait and /runs/wait handlers used
to await record.task directly with no disconnect handling and silently
swallow CancelledError. When a long tool call (e.g. pip install inside
a custom skill) kept the connection idle long enough for an
intermediate HTTP layer to time out, the handler would still read the
in-progress checkpoint and return it as if the run had completed
normally -- masking a half-finished run as a successful response.

Add wait_for_run_completion in app.gateway.services that mirrors
sse_consumer's bridge-consumption pattern: subscribe to the stream
bridge until END_SENTINEL, poll request.is_disconnected on every
wake-up, and on real client disconnect cancel the background run when
record.on_disconnect is "cancel". Wire it into both wait endpoints.

The streaming path was unaffected because sse_consumer already has
this loop; this just brings /wait to parity.

* fix(gateway): skip checkpoint serialization on /wait disconnect

Copilot review on #3267 caught a follow-on of the same #3265 bug: when
the client disconnects, wait_for_run_completion breaks out of the bridge
loop and cancels the run, but the /wait endpoint then continues to read
the checkpointer and serializes whatever partial checkpoint exists as a
normal 200 response.

Have the helper return a bool — True only when END_SENTINEL was observed
— and skip the checkpoint serialization path on False. Also reorder the
inner check so END_SENTINEL is honoured even when is_disconnected() flips
true in the same iteration; the run truly finished so the real final
checkpoint is still valid.
2026-05-28 07:22:39 +08:00

178 lines
6.9 KiB
Python

"""Regression tests for issue #3265.

The non-streaming ``/wait`` endpoints used to ``await record.task`` with no
disconnect handling and silently swallow ``CancelledError``. When a long
tool call (e.g. ``pip install`` inside a custom skill) kept the connection
idle long enough for an intermediate HTTP layer to time out, the handler
would return a stale checkpoint that looked like a normal completion.

The fix introduces ``wait_for_run_completion`` in ``app.gateway.services``:
it subscribes to the stream bridge until ``END_SENTINEL``, polls
``request.is_disconnected()`` on every wake-up, and honours the record's
``on_disconnect`` mode by cancelling the background run on real client
disconnect.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Any
from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.schemas import DisconnectMode
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
# Thread id passed to ``create_or_reject`` by ``_create_running_record``.
THREAD_ID = "thread-wait-3265"
@dataclass
class _FakeRequest:
"""Minimal stand-in for FastAPI ``Request`` with controllable disconnect.
``is_disconnected`` is awaited each iteration of the helper's loop, so the
counter lets a test transition from "still connected" to "disconnected"
after N polls without racing the event loop.
"""
disconnect_after: int = 10**9 # effectively "never" by default
_polls: int = 0
async def is_disconnected(self) -> bool:
self._polls += 1
return self._polls > self.disconnect_after
async def _create_running_record(mgr: RunManager, *, on_disconnect: DisconnectMode) -> Any:
    """Create a run record on ``THREAD_ID`` and set its status to running.

    Args:
        mgr: Run manager used to create the record and update its status.
        on_disconnect: Disconnect mode forwarded to ``mgr.create_or_reject``.

    Returns:
        The record returned by ``mgr.create_or_reject``.
    """
    record = await mgr.create_or_reject(
        THREAD_ID,
        assistant_id=None,
        on_disconnect=on_disconnect,
    )
    await mgr.set_status(record.run_id, RunStatus.running)
    return record
# ---------------------------------------------------------------------------
# Helper-level unit tests
# ---------------------------------------------------------------------------
class TestWaitForRunCompletion:
    """Helper-level tests for ``wait_for_run_completion``.

    Every test builds its own ``RunManager`` and ``MemoryStreamBridge`` and
    drives the scenario inside ``asyncio.run``. Background publisher tasks are
    bound to a local name and awaited after the helper returns: the event loop
    keeps only weak references to tasks, so an unreferenced task may be
    garbage-collected before it runs, and an exception raised inside it would
    otherwise go unnoticed.
    """

    def test_returns_when_run_publishes_end(self) -> None:
        """Happy path: helper returns once the bridge publishes END_SENTINEL."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
            request = _FakeRequest()

            async def finish_soon() -> None:
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"messages": []})
                await mgr.set_status(record.run_id, RunStatus.success)
                await bridge.publish_end(record.run_id)

            # Hold a strong reference so the publisher cannot be collected.
            publisher = asyncio.create_task(finish_soon())

            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            # Surface any exception raised inside the publisher.
            await asyncio.wait_for(publisher, timeout=1.0)

            assert completed is True
            assert record.status == RunStatus.success

        asyncio.run(run())

    def test_cancels_run_on_disconnect_when_cancel_mode(self) -> None:
        """on_disconnect=cancel: real disconnect must call run_mgr.cancel()."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)

            # Attach a real (idle) task so cancel() actually has something to cancel.
            sleeper = asyncio.create_task(asyncio.sleep(30))
            record.task = sleeper

            request = _FakeRequest(disconnect_after=0)  # disconnected on first poll

            async def publish_until_cancel() -> None:
                # Emit one event so subscribe wakes up immediately; helper polls
                # is_disconnected after each yield.
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"step": 1})

            # Hold a strong reference so the publisher cannot be collected.
            publisher = asyncio.create_task(publish_until_cancel())

            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            # Surface any exception raised inside the publisher.
            await asyncio.wait_for(publisher, timeout=1.0)

            assert completed is False
            assert record.status == RunStatus.interrupted

            # Drain the cancelled sleeper so it does not linger past the test.
            try:
                await asyncio.wait_for(sleeper, timeout=1.0)
            except asyncio.CancelledError:
                pass
            assert sleeper.done()

        asyncio.run(run())

    def test_does_not_cancel_when_continue_mode(self) -> None:
        """on_disconnect=continue: disconnect must NOT cancel the run."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.continue_)

            sleeper = asyncio.create_task(asyncio.sleep(30))
            record.task = sleeper

            request = _FakeRequest(disconnect_after=0)

            async def publish_one_event() -> None:
                # Only a single event is published (no END), so the helper
                # wakes up and observes the disconnect while the run is
                # still in progress.
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"step": 1})

            # Hold a strong reference so the publisher cannot be collected.
            publisher = asyncio.create_task(publish_one_event())

            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            # Surface any exception raised inside the publisher.
            await asyncio.wait_for(publisher, timeout=1.0)

            # Disconnected before END — helper still reports incomplete so the
            # caller skips checkpoint serialization, but the run keeps going.
            assert completed is False
            assert record.status == RunStatus.running

            sleeper.cancel()

        asyncio.run(run())

    def test_no_cancel_when_run_already_finished(self) -> None:
        """If the run ended (END_SENTINEL) before disconnect is observed, the
        finally block must not call cancel — the run is already terminal."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)

            # Publish END before subscribe — helper should see ended=True first
            # poll and return without ever observing the "disconnect".
            await mgr.set_status(record.run_id, RunStatus.success)
            await bridge.publish_end(record.run_id)

            request = _FakeRequest(disconnect_after=0)

            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )

            assert completed is True
            assert record.status == RunStatus.success

        asyncio.run(run())