mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-30 12:28:10 +00:00
* fix(gateway): honour on_disconnect on /wait endpoints (#3265) The non-streaming /threads/{tid}/runs/wait and /runs/wait handlers used to await record.task directly with no disconnect handling and silently swallow CancelledError. When a long tool call (e.g. pip install inside a custom skill) kept the connection idle long enough for an intermediate HTTP layer to time out, the handler would still read the in-progress checkpoint and return it as if the run had completed normally -- masking a half-finished run as a successful response. Add wait_for_run_completion in app.gateway.services that mirrors sse_consumer's bridge-consumption pattern: subscribe to the stream bridge until END_SENTINEL, poll request.is_disconnected on every wake-up, and on real client disconnect cancel the background run when record.on_disconnect is "cancel". Wire it into both wait endpoints. The streaming path was unaffected because sse_consumer already has this loop; this just brings /wait to parity. * fix(gateway): skip checkpoint serialization on /wait disconnect Copilot review on #3267 caught a follow-on of the same #3265 bug: when the client disconnects, wait_for_run_completion breaks out of the bridge loop and cancels the run, but the /wait endpoint then continues to read the checkpointer and serializes whatever partial checkpoint exists as a normal 200 response. Have the helper return a bool — True only when END_SENTINEL was observed — and skip the checkpoint serialization path on False. Also reorder the inner check so END_SENTINEL is honoured even when is_disconnected() flips true in the same iteration; the run truly finished so the real final checkpoint is still valid.
178 lines
6.9 KiB
Python
178 lines
6.9 KiB
Python
"""Regression tests for issue #3265.

The non-streaming ``/wait`` endpoints used to ``await record.task`` with no
disconnect handling and silently swallow ``CancelledError``. When a long
tool call (e.g. ``pip install`` inside a custom skill) kept the connection
idle long enough for an intermediate HTTP layer to time out, the handler
would return a stale checkpoint that looked like a normal completion.

The fix introduces ``wait_for_run_completion`` in ``app.gateway.services``:
it subscribes to the stream bridge until ``END_SENTINEL``, polls
``request.is_disconnected()`` on every wake-up, and honours the record's
``on_disconnect`` mode by cancelling the background run on real client
disconnect.

The helper returns ``True`` only when ``END_SENTINEL`` was observed and
``False`` when it stopped because the client disconnected, so the endpoint
can skip serializing a partial checkpoint as if it were a normal response.
``END_SENTINEL`` is honoured even if a disconnect is observed in the same
iteration, because in that case the run really did finish.
"""
|
|
|
|
from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from typing import Any

from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.schemas import DisconnectMode
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
|
|
|
|
# Single thread id used for every run these tests create; the suffix ties it to issue #3265.
THREAD_ID: str = "thread-wait-3265"
|
|
|
|
|
|
@dataclass
class _FakeRequest:
    """Minimal stand-in for FastAPI ``Request`` with a controllable disconnect.

    ``is_disconnected`` is awaited on each iteration of the helper's loop, so
    the poll counter lets a test switch from "still connected" to
    "disconnected" after a fixed number of polls without racing the event
    loop. The first ``disconnect_after`` polls report connected; every later
    poll reports disconnected. ``disconnect_after=0`` means the very first poll
    already sees a disconnect.
    """

    # Number of polls that still report "connected"; the default is large
    # enough to mean "never disconnects" within these short tests.
    disconnect_after: int = 10**9
    # Internal poll counter. ``init=False`` keeps it out of the constructor so
    # a test cannot accidentally pre-seed it, and ``repr=False`` keeps failure
    # output focused on the configured threshold.
    _polls: int = field(default=0, init=False, repr=False)

    async def is_disconnected(self) -> bool:
        """Count one poll and report whether the threshold has been passed."""
        self._polls += 1
        return self._polls > self.disconnect_after
|
|
|
|
|
|
async def _create_running_record(mgr: RunManager, *, on_disconnect: DisconnectMode) -> Any:
    """Create a run on ``THREAD_ID`` and move it into the running state.

    ``on_disconnect`` is forwarded to ``RunManager.create_or_reject``; the
    tests use it to choose between cancelling and continuing the run when the
    client goes away.

    Returns the record produced by ``create_or_reject`` after its status has
    been set to ``RunStatus.running``.
    """
    new_record = await mgr.create_or_reject(THREAD_ID, assistant_id=None, on_disconnect=on_disconnect)
    run_id = new_record.run_id
    await mgr.set_status(run_id, RunStatus.running)
    return new_record
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helper-level unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestWaitForRunCompletion:
    """Unit tests for ``app.gateway.services.wait_for_run_completion``.

    Each test drives the helper with an in-memory ``RunManager`` and
    ``MemoryStreamBridge`` plus a ``_FakeRequest`` whose disconnect point is
    fixed up front. Every helper call is bounded by ``asyncio.wait_for`` so a
    hang fails the test instead of blocking the suite. The helper is imported
    lazily inside each test (presumably so an import problem surfaces as a
    per-test failure rather than a collection error).
    """

    def test_returns_when_run_publishes_end(self) -> None:
        """Happy path: helper returns once the bridge publishes END_SENTINEL."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
            request = _FakeRequest()

            async def finish_soon() -> None:
                # Yield to the event loop once so the helper gets a chance to
                # start waiting before anything is published.
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"messages": []})
                await mgr.set_status(record.run_id, RunStatus.success)
                await bridge.publish_end(record.run_id)

            # Keep a reference: the event loop only holds weak references to
            # tasks, and awaiting it below surfaces any exception it raised
            # instead of losing it as "Task exception was never retrieved".
            publisher = asyncio.create_task(finish_soon())
            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            await publisher

            assert completed is True
            assert record.status == RunStatus.success

        asyncio.run(run())

    def test_cancels_run_on_disconnect_when_cancel_mode(self) -> None:
        """on_disconnect=cancel: real disconnect must call run_mgr.cancel()."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
            # Attach a real (idle) task so cancel() actually has something to cancel.
            sleeper = asyncio.create_task(asyncio.sleep(30))
            record.task = sleeper
            request = _FakeRequest(disconnect_after=0)  # disconnected on first poll

            async def publish_one_event() -> None:
                # Emit one event so subscribe wakes up immediately; helper polls
                # is_disconnected after each yield.
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"step": 1})

            # Keep a reference so the task cannot be collected mid-flight and
            # so its errors fail the test when awaited below.
            publisher = asyncio.create_task(publish_one_event())
            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            await publisher

            assert completed is False
            assert record.status == RunStatus.interrupted
            # Drain the cancelled sleeper so it does not linger past the test.
            try:
                await asyncio.wait_for(sleeper, timeout=1.0)
            except asyncio.CancelledError:
                pass
            assert sleeper.done()

        asyncio.run(run())

    def test_does_not_cancel_when_continue_mode(self) -> None:
        """on_disconnect=continue: disconnect must NOT cancel the run."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.continue_)
            sleeper = asyncio.create_task(asyncio.sleep(30))
            record.task = sleeper
            request = _FakeRequest(disconnect_after=0)

            # Publishes a single event and never END, so the helper can only
            # stop because of the disconnect.
            async def publish_one_event() -> None:
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"step": 1})

            publisher = asyncio.create_task(publish_one_event())
            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            await publisher

            # Disconnected before END — helper still reports incomplete so the
            # caller skips checkpoint serialization, but the run keeps going.
            assert completed is False
            assert record.status == RunStatus.running

            # Test cleanup only: stop the stand-in run task and let the
            # cancellation finish inside this test rather than at loop shutdown.
            sleeper.cancel()
            try:
                await sleeper
            except asyncio.CancelledError:
                pass

        asyncio.run(run())

    def test_no_cancel_when_run_already_finished(self) -> None:
        """If the run ended (END_SENTINEL) before disconnect is observed, the
        finally block must not call cancel — the run is already terminal."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
            # Publish END before subscribe — helper should see ended=True first
            # poll and return without ever observing the "disconnect".
            await mgr.set_status(record.run_id, RunStatus.success)
            await bridge.publish_end(record.run_id)
            request = _FakeRequest(disconnect_after=0)

            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )

            assert completed is True
            assert record.status == RunStatus.success

        asyncio.run(run())
|