fix(gateway): honour on_disconnect on /wait endpoints (#3267)

* fix(gateway): honour on_disconnect on /wait endpoints (#3265)

The non-streaming /threads/{tid}/runs/wait and /runs/wait handlers used
to await record.task directly with no disconnect handling and silently
swallow CancelledError. When a long tool call (e.g. pip install inside
a custom skill) kept the connection idle long enough for an
intermediate HTTP layer to time out, the handler would still read the
in-progress checkpoint and return it as if the run had completed
normally -- masking a half-finished run as a successful response.

Add wait_for_run_completion in app.gateway.services that mirrors
sse_consumer's bridge-consumption pattern: subscribe to the stream
bridge until END_SENTINEL, poll request.is_disconnected on every
wake-up, and on real client disconnect cancel the background run when
record.on_disconnect is "cancel". Wire it into both wait endpoints.

The streaming path was unaffected because sse_consumer already has
this loop; this just brings /wait to parity.

* fix(gateway): skip checkpoint serialization on /wait disconnect

Copilot review on #3267 caught a follow-on of the same #3265 bug: when
the client disconnects, wait_for_run_completion breaks out of the bridge
loop and cancels the run, but the /wait endpoint then continues to read
the checkpointer and serializes whatever partial checkpoint exists as a
normal 200 response.

Have the helper return a bool — True only when END_SENTINEL was observed
— and skip the checkpoint serialization path on False. Also reorder the
inner check so END_SENTINEL is honoured even when is_disconnected() flips
true in the same iteration; the run truly finished so the real final
checkpoint is still valid.
This commit is contained in:
AochenShen99 2026-05-28 07:22:39 +08:00 committed by GitHub
parent 9e332c594a
commit a5599c100c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 258 additions and 31 deletions

View File

@ -277,6 +277,7 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S
- When a persistent `RunStore` is configured, `get()` and `list_by_thread()` hydrate historical runs from the store. In-memory records win for the same `run_id` so task, abort, and stream-control state stays attached to active local runs.
- `cancel()` and `create_or_reject(..., multitask_strategy="interrupt"|"rollback")` persist interrupted status through `RunStore.update_status()`, matching normal `set_status()` transitions.
- Store-only hydrated runs are readable history. If the current worker has no in-memory task/control state for that run, cancellation APIs can return 409 because this worker cannot stop the task.
- `POST /wait` (both thread-scoped and `/api/runs/wait`) drains the stream bridge via `wait_for_run_completion()` instead of bare `await record.task`, so it honours the run's `on_disconnect` setting and cancels the background run on real client disconnect rather than returning a stale checkpoint (issue #3265).
Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs.

View File

@ -7,7 +7,6 @@ is reused so that conversation history is preserved across calls.
from __future__ import annotations
import asyncio
import logging
import uuid
@ -17,7 +16,7 @@ from fastapi.responses import StreamingResponse
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.routers.thread_runs import RunCreateRequest
from app.gateway.services import sse_consumer, start_run
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
from deerflow.runtime import serialize_channel_values
logger = logging.getLogger(__name__)
@ -66,24 +65,25 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
Otherwise a new temporary thread is created.
"""
thread_id = _resolve_thread_id(body)
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
completed = True
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
if completed:
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}

View File

@ -21,7 +21,7 @@ from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
logger = logging.getLogger(__name__)
@ -175,24 +175,25 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state."""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
completed = True
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
if completed:
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}

View File

@ -402,3 +402,51 @@ async def sse_consumer(
if record.status in (RunStatus.pending, RunStatus.running):
if record.on_disconnect == DisconnectMode.cancel:
await run_mgr.cancel(record.run_id)
async def wait_for_run_completion(
    bridge: StreamBridge,
    record: RunRecord,
    request: Request,
    run_mgr: RunManager,
) -> bool:
    """Wait on the stream bridge until the run ends or the client goes away.

    Background (issue #3265): the non-streaming ``/wait`` endpoints used to
    ``await record.task`` directly and swallow ``CancelledError``. If the
    client (or an intermediate HTTP proxy) timed out during a long tool
    call, the handler serialized whatever checkpoint happened to exist --
    masking a half-finished run as a normal completion.

    This helper iterates ``bridge.subscribe(record.run_id)`` and, for every
    entry that is not ``END_SENTINEL``, polls ``request.is_disconnected()``.
    A disconnect is therefore only noticed when the bridge yields an entry.
    NOTE(review): the original author states the bridge emits heartbeat
    entries at least once per ``heartbeat_interval``; that is not visible
    from this function -- confirm against the bridge implementation.

    Args:
        bridge: Stream bridge whose subscription for ``record.run_id`` is
            consumed.
        record: The run being waited on; its ``status`` and
            ``on_disconnect`` decide whether the run is cancelled on exit.
        request: Object exposing an awaitable ``is_disconnected()``.
        run_mgr: Manager used to cancel the run.

    Returns:
        ``True`` when ``END_SENTINEL`` was observed. ``False`` when the loop
        stopped without it (the client disconnected, or the subscription
        ended first). Callers must skip checkpoint serialization on
        ``False`` so a partial checkpoint is not returned as a normal
        response.
    """
    saw_end = False
    try:
        async for item in bridge.subscribe(record.run_id):
            # Test for END_SENTINEL before polling the connection: if the
            # run finished, the final checkpoint is valid even when the
            # client dropped in the same iteration.
            if item is END_SENTINEL:
                saw_end = True
                break
            if await request.is_disconnected():
                break
            # Any other entry: keep waiting for END_SENTINEL.
        return saw_end
    finally:
        # Runs on every exit path, including exceptions raised while
        # waiting. Only a still-active run whose record asks for
        # cancel-on-disconnect is cancelled.
        if (
            not saw_end
            and record.status in (RunStatus.pending, RunStatus.running)
            and record.on_disconnect == DisconnectMode.cancel
        ):
            await run_mgr.cancel(record.run_id)

View File

@ -0,0 +1,177 @@
"""Regression tests for issue #3265.
The non-streaming ``/wait`` endpoints used to ``await record.task`` with no
disconnect handling and silently swallow ``CancelledError``. When a long
tool call (e.g. ``pip install`` inside a custom skill) kept the connection
idle long enough for an intermediate HTTP layer to time out, the handler
would return a stale checkpoint that looked like a normal completion.
The fix introduces ``wait_for_run_completion`` in ``app.gateway.services``:
it subscribes to the stream bridge until ``END_SENTINEL``, polls
``request.is_disconnected()`` on every wake-up, and honours the record's
``on_disconnect`` mode by cancelling the background run on real client
disconnect.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Any
from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.schemas import DisconnectMode
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
THREAD_ID = "thread-wait-3265"
@dataclass
class _FakeRequest:
    """Tiny stand-in for a FastAPI ``Request`` with a scripted disconnect.

    Every call to ``is_disconnected`` bumps an internal poll counter and
    reports ``True`` once that counter exceeds ``disconnect_after``. A test
    can therefore flip from "connected" to "disconnected" after a fixed
    number of polls, independent of event-loop timing.
    """

    # Number of polls that still report "connected"; the default is large
    # enough to mean "never disconnects".
    disconnect_after: int = 1_000_000_000
    # Count of is_disconnected() calls made so far.
    _polls: int = 0

    async def is_disconnected(self) -> bool:
        """Record one poll and report whether the scripted limit is passed."""
        poll_number = self._polls + 1
        self._polls = poll_number
        return poll_number > self.disconnect_after
async def _create_running_record(mgr: RunManager, *, on_disconnect: DisconnectMode) -> Any:
    """Create a run on ``THREAD_ID`` through ``mgr`` and mark it running.

    Args:
        mgr: Run manager that creates the record and sets its status.
        on_disconnect: Disconnect mode passed to ``create_or_reject``.

    Returns:
        The record returned by ``mgr.create_or_reject``, after
        ``set_status(..., RunStatus.running)`` has been awaited for it.
    """
    created = await mgr.create_or_reject(THREAD_ID, assistant_id=None, on_disconnect=on_disconnect)
    await mgr.set_status(created.run_id, RunStatus.running)
    return created
# ---------------------------------------------------------------------------
# Helper-level unit tests
# ---------------------------------------------------------------------------
class TestWaitForRunCompletion:
    """Helper-level regression tests for ``wait_for_run_completion`` (#3265).

    Each test drives the helper inside its own ``asyncio.run`` call.
    Background tasks are kept in a local variable and awaited or drained
    before the test coroutine returns: the event loop holds only weak
    references to tasks, and awaiting them also surfaces any exception
    raised inside the publisher.
    """

    def test_returns_when_run_publishes_end(self) -> None:
        """Happy path: helper returns once the bridge publishes END_SENTINEL."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
            request = _FakeRequest()

            async def finish_soon() -> None:
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"messages": []})
                await mgr.set_status(record.run_id, RunStatus.success)
                await bridge.publish_end(record.run_id)

            publisher = asyncio.create_task(finish_soon())

            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            # Re-raise anything the publisher hit instead of losing it.
            await asyncio.wait_for(publisher, timeout=1.0)

            assert completed is True
            assert record.status == RunStatus.success

        asyncio.run(run())

    def test_cancels_run_on_disconnect_when_cancel_mode(self) -> None:
        """on_disconnect=cancel: real disconnect must call run_mgr.cancel()."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
            # Attach a real (idle) task so cancel() actually has something to cancel.
            sleeper = asyncio.create_task(asyncio.sleep(30))
            record.task = sleeper
            request = _FakeRequest(disconnect_after=0)  # disconnected on first poll

            async def publish_one_event() -> None:
                # Emit one event so subscribe wakes up; the helper polls
                # is_disconnected after each non-END entry.
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"step": 1})

            publisher = asyncio.create_task(publish_one_event())

            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            await asyncio.wait_for(publisher, timeout=1.0)

            assert completed is False
            assert record.status == RunStatus.interrupted

            # Drain the cancelled sleeper so it does not linger past the test.
            try:
                await asyncio.wait_for(sleeper, timeout=1.0)
            except asyncio.CancelledError:
                pass
            assert sleeper.done()

        asyncio.run(run())

    def test_does_not_cancel_when_continue_mode(self) -> None:
        """on_disconnect=continue: disconnect must NOT cancel the run."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.continue_)
            sleeper = asyncio.create_task(asyncio.sleep(30))
            record.task = sleeper
            request = _FakeRequest(disconnect_after=0)

            async def publish_one_event() -> None:
                # A single non-END event; END is deliberately never published.
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"step": 1})

            publisher = asyncio.create_task(publish_one_event())

            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            await asyncio.wait_for(publisher, timeout=1.0)

            # Disconnected before END -- helper still reports incomplete so the
            # caller skips checkpoint serialization, but the run keeps going.
            assert completed is False
            assert record.status == RunStatus.running

            # Test-owned cleanup: stop and drain the stand-in run task.
            sleeper.cancel()
            try:
                await sleeper
            except asyncio.CancelledError:
                pass

        asyncio.run(run())

    def test_no_cancel_when_run_already_finished(self) -> None:
        """If END_SENTINEL is seen before the disconnect is observed, the
        finally block must not cancel the run -- it is already terminal."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
            # Publish END before subscribing. The helper is expected to see
            # END_SENTINEL as its first entry and return True without ever
            # polling the "disconnected" request.
            await mgr.set_status(record.run_id, RunStatus.success)
            await bridge.publish_end(record.run_id)
            request = _FakeRequest(disconnect_after=0)

            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )

            assert completed is True
            assert record.status == RunStatus.success

        asyncio.run(run())