From a5599c100cd74989fe919bfb78056b301dd05189 Mon Sep 17 00:00:00 2001 From: AochenShen99 Date: Thu, 28 May 2026 07:22:39 +0800 Subject: [PATCH] fix(gateway): honour on_disconnect on /wait endpoints (#3267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(gateway): honour on_disconnect on /wait endpoints (#3265) The non-streaming /threads/{tid}/runs/wait and /runs/wait handlers used to await record.task directly with no disconnect handling and silently swallow CancelledError. When a long tool call (e.g. pip install inside a custom skill) kept the connection idle long enough for an intermediate HTTP layer to time out, the handler would still read the in-progress checkpoint and return it as if the run had completed normally -- masking a half-finished run as a successful response. Add wait_for_run_completion in app.gateway.services that mirrors sse_consumer's bridge-consumption pattern: subscribe to the stream bridge until END_SENTINEL, poll request.is_disconnected on every wake-up, and on real client disconnect cancel the background run when record.on_disconnect is "cancel". Wire it into both wait endpoints. The streaming path was unaffected because sse_consumer already has this loop; this just brings /wait to parity. * fix(gateway): skip checkpoint serialization on /wait disconnect Copilot review on #3267 caught a follow-on of the same #3265 bug: when the client disconnects, wait_for_run_completion breaks out of the bridge loop and cancels the run, but the /wait endpoint then continues to read the checkpointer and serializes whatever partial checkpoint exists as a normal 200 response. Have the helper return a bool — True only when END_SENTINEL was observed — and skip the checkpoint serialization path on False. Also reorder the inner check so END_SENTINEL is honoured even when is_disconnected() flips true in the same iteration; the run truly finished so the real final checkpoint is still valid. 
--- backend/CLAUDE.md | 1 + backend/app/gateway/routers/runs.py | 32 ++-- backend/app/gateway/routers/thread_runs.py | 31 +-- backend/app/gateway/services.py | 48 +++++ .../tests/test_wait_disconnect_handling.py | 177 ++++++++++++++++++ 5 files changed, 258 insertions(+), 31 deletions(-) create mode 100644 backend/tests/test_wait_disconnect_handling.py diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 38e8e1d26..b655b2225 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -277,6 +277,7 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S - When a persistent `RunStore` is configured, `get()` and `list_by_thread()` hydrate historical runs from the store. In-memory records win for the same `run_id` so task, abort, and stream-control state stays attached to active local runs. - `cancel()` and `create_or_reject(..., multitask_strategy="interrupt"|"rollback")` persist interrupted status through `RunStore.update_status()`, matching normal `set_status()` transitions. - Store-only hydrated runs are readable history. If the current worker has no in-memory task/control state for that run, cancellation APIs can return 409 because this worker cannot stop the task. +- `POST /wait` (both thread-scoped and `/api/runs/wait`) drains the stream bridge via `wait_for_run_completion()` instead of bare `await record.task`, so it honours the run's `on_disconnect` setting and cancels the background run on real client disconnect rather than returning a stale checkpoint (issue #3265). Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs. diff --git a/backend/app/gateway/routers/runs.py b/backend/app/gateway/routers/runs.py index f2775466c..1e61ffd25 100644 --- a/backend/app/gateway/routers/runs.py +++ b/backend/app/gateway/routers/runs.py @@ -7,7 +7,6 @@ is reused so that conversation history is preserved across calls. 
from __future__ import annotations -import asyncio import logging import uuid @@ -17,7 +16,7 @@ from fastapi.responses import StreamingResponse from app.gateway.authz import require_permission from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge from app.gateway.routers.thread_runs import RunCreateRequest -from app.gateway.services import sse_consumer, start_run +from app.gateway.services import sse_consumer, start_run, wait_for_run_completion from deerflow.runtime import serialize_channel_values logger = logging.getLogger(__name__) @@ -66,24 +65,25 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict: Otherwise a new temporary thread is created. """ thread_id = _resolve_thread_id(body) + bridge = get_stream_bridge(request) + run_mgr = get_run_manager(request) record = await start_run(body, thread_id, request) + completed = True if record.task is not None: - try: - await record.task - except asyncio.CancelledError: - pass + completed = await wait_for_run_completion(bridge, record, request, run_mgr) - checkpointer = get_checkpointer(request) - config = {"configurable": {"thread_id": thread_id}} - try: - checkpoint_tuple = await checkpointer.aget_tuple(config) - if checkpoint_tuple is not None: - checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} - channel_values = checkpoint.get("channel_values", {}) - return serialize_channel_values(channel_values) - except Exception: - logger.exception("Failed to fetch final state for run %s", record.run_id) + if completed: + checkpointer = get_checkpointer(request) + config = {"configurable": {"thread_id": thread_id}} + try: + checkpoint_tuple = await checkpointer.aget_tuple(config) + if checkpoint_tuple is not None: + checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} + channel_values = checkpoint.get("channel_values", {}) + return serialize_channel_values(channel_values) + except Exception: + 
logger.exception("Failed to fetch final state for run %s", record.run_id) return {"status": record.status.value, "error": record.error} diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index a542593b2..8a7cade4d 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -21,7 +21,7 @@ from pydantic import BaseModel, Field from app.gateway.authz import require_permission from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge -from app.gateway.services import sse_consumer, start_run +from app.gateway.services import sse_consumer, start_run, wait_for_run_completion from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values logger = logging.getLogger(__name__) @@ -175,24 +175,25 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) - @require_permission("runs", "create", owner_check=True, require_existing=True) async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict: """Create a run and block until it completes, returning the final state.""" + bridge = get_stream_bridge(request) + run_mgr = get_run_manager(request) record = await start_run(body, thread_id, request) + completed = True if record.task is not None: - try: - await record.task - except asyncio.CancelledError: - pass + completed = await wait_for_run_completion(bridge, record, request, run_mgr) - checkpointer = get_checkpointer(request) - config = {"configurable": {"thread_id": thread_id}} - try: - checkpoint_tuple = await checkpointer.aget_tuple(config) - if checkpoint_tuple is not None: - checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} - channel_values = checkpoint.get("channel_values", {}) - return serialize_channel_values(channel_values) - except Exception: - logger.exception("Failed to fetch final state for run %s", 
record.run_id) + if completed: + checkpointer = get_checkpointer(request) + config = {"configurable": {"thread_id": thread_id}} + try: + checkpoint_tuple = await checkpointer.aget_tuple(config) + if checkpoint_tuple is not None: + checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} + channel_values = checkpoint.get("channel_values", {}) + return serialize_channel_values(channel_values) + except Exception: + logger.exception("Failed to fetch final state for run %s", record.run_id) return {"status": record.status.value, "error": record.error} diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 95e26144a..63ac0c1bf 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -402,3 +402,51 @@ async def sse_consumer( if record.status in (RunStatus.pending, RunStatus.running): if record.on_disconnect == DisconnectMode.cancel: await run_mgr.cancel(record.run_id) + + +async def wait_for_run_completion( + bridge: StreamBridge, + record: RunRecord, + request: Request, + run_mgr: RunManager, +) -> bool: + """Block until the run publishes ``END_SENTINEL``, honouring on_disconnect. + + The non-streaming ``/wait`` endpoints used to ``await record.task`` + directly with no disconnect handling. When the client (or an + intermediate HTTP proxy) timed out during a long tool call such as + ``pip install``, the handler would swallow ``CancelledError`` and + serialize whatever checkpoint happened to exist — masking a half-finished + run as a normal completion (issue #3265). + + This helper consumes the same bridge that ``sse_consumer`` does so the + wait path shares its disconnect semantics: each wake-up polls + ``request.is_disconnected()``; on a real disconnect it cancels the + background run when ``record.on_disconnect`` is ``cancel``. The bridge's + heartbeat sentinels guarantee at least one wake-up per + ``heartbeat_interval`` even when the agent emits no events for a while. 
+ + Returns: + ``True`` when ``END_SENTINEL`` was observed (run reached a terminal + state), ``False`` when the loop exited because the client + disconnected. Callers must skip checkpoint serialization on + ``False`` so a partial checkpoint is not returned as a normal + response. + """ + completed = False + try: + async for entry in bridge.subscribe(record.run_id): + # END_SENTINEL means the run reached a terminal state; honour it + # even if the client just disconnected so the caller still serializes + # the real final checkpoint. + if entry is END_SENTINEL: + completed = True + return True + if await request.is_disconnected(): + break + # Heartbeats and regular events: keep waiting for END_SENTINEL. + return completed + finally: + if not completed and record.status in (RunStatus.pending, RunStatus.running): + if record.on_disconnect == DisconnectMode.cancel: + await run_mgr.cancel(record.run_id) diff --git a/backend/tests/test_wait_disconnect_handling.py b/backend/tests/test_wait_disconnect_handling.py new file mode 100644 index 000000000..62eec8ecf --- /dev/null +++ b/backend/tests/test_wait_disconnect_handling.py @@ -0,0 +1,177 @@ +"""Regression tests for issue #3265. + +The non-streaming ``/wait`` endpoints used to ``await record.task`` with no +disconnect handling and silently swallow ``CancelledError``. When a long +tool call (e.g. ``pip install`` inside a custom skill) kept the connection +idle long enough for an intermediate HTTP layer to time out, the handler +would return a stale checkpoint that looked like a normal completion. + +The fix introduces ``wait_for_run_completion`` in ``app.gateway.services``: +it subscribes to the stream bridge until ``END_SENTINEL``, polls +``request.is_disconnected()`` on every wake-up, and honours the record's +``on_disconnect`` mode by cancelling the background run on real client +disconnect. 
+""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import Any + +from deerflow.runtime import RunManager, RunStatus +from deerflow.runtime.runs.schemas import DisconnectMode +from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge + +THREAD_ID = "thread-wait-3265" + + +@dataclass +class _FakeRequest: + """Minimal stand-in for FastAPI ``Request`` with controllable disconnect. + + ``is_disconnected`` is awaited each iteration of the helper's loop, so the + counter lets a test transition from "still connected" to "disconnected" + after N polls without racing the event loop. + """ + + disconnect_after: int = 10**9 # effectively "never" by default + _polls: int = 0 + + async def is_disconnected(self) -> bool: + self._polls += 1 + return self._polls > self.disconnect_after + + +async def _create_running_record(mgr: RunManager, *, on_disconnect: DisconnectMode) -> Any: + record = await mgr.create_or_reject( + THREAD_ID, + assistant_id=None, + on_disconnect=on_disconnect, + ) + await mgr.set_status(record.run_id, RunStatus.running) + return record + + +# --------------------------------------------------------------------------- +# Helper-level unit tests +# --------------------------------------------------------------------------- + + +class TestWaitForRunCompletion: + def test_returns_when_run_publishes_end(self) -> None: + """Happy path: helper returns once the bridge publishes END_SENTINEL.""" + from app.gateway.services import wait_for_run_completion + + async def run() -> None: + mgr = RunManager() + bridge = MemoryStreamBridge() + record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel) + request = _FakeRequest() + + async def finish_soon() -> None: + await asyncio.sleep(0) + await bridge.publish(record.run_id, "values", {"messages": []}) + await mgr.set_status(record.run_id, RunStatus.success) + await bridge.publish_end(record.run_id) + + 
asyncio.create_task(finish_soon()) + completed = await asyncio.wait_for( + wait_for_run_completion(bridge, record, request, mgr), + timeout=2.0, + ) + assert completed is True + assert record.status == RunStatus.success + + asyncio.run(run()) + + def test_cancels_run_on_disconnect_when_cancel_mode(self) -> None: + """on_disconnect=cancel: real disconnect must call run_mgr.cancel().""" + from app.gateway.services import wait_for_run_completion + + async def run() -> None: + mgr = RunManager() + bridge = MemoryStreamBridge() + record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel) + # Attach a real (idle) task so cancel() actually has something to cancel. + sleeper = asyncio.create_task(asyncio.sleep(30)) + record.task = sleeper + request = _FakeRequest(disconnect_after=0) # disconnected on first poll + + async def publish_until_cancel() -> None: + # Emit one event so subscribe wakes up immediately; helper polls + # is_disconnected after each yield. + await asyncio.sleep(0) + await bridge.publish(record.run_id, "values", {"step": 1}) + + asyncio.create_task(publish_until_cancel()) + completed = await asyncio.wait_for( + wait_for_run_completion(bridge, record, request, mgr), + timeout=2.0, + ) + + assert completed is False + assert record.status == RunStatus.interrupted + # Drain the cancelled sleeper so it does not linger past the test. 
+ try: + await asyncio.wait_for(sleeper, timeout=1.0) + except asyncio.CancelledError: + pass + assert sleeper.done() + + asyncio.run(run()) + + def test_does_not_cancel_when_continue_mode(self) -> None: + """on_disconnect=continue: disconnect must NOT cancel the run.""" + from app.gateway.services import wait_for_run_completion + + async def run() -> None: + mgr = RunManager() + bridge = MemoryStreamBridge() + record = await _create_running_record(mgr, on_disconnect=DisconnectMode.continue_) + sleeper = asyncio.create_task(asyncio.sleep(30)) + record.task = sleeper + request = _FakeRequest(disconnect_after=0) + + async def publish_then_end() -> None: + await asyncio.sleep(0) + await bridge.publish(record.run_id, "values", {"step": 1}) + + asyncio.create_task(publish_then_end()) + completed = await asyncio.wait_for( + wait_for_run_completion(bridge, record, request, mgr), + timeout=2.0, + ) + + # Disconnected before END — helper still reports incomplete so the + # caller skips checkpoint serialization, but the run keeps going. + assert completed is False + assert record.status == RunStatus.running + sleeper.cancel() + + asyncio.run(run()) + + def test_no_cancel_when_run_already_finished(self) -> None: + """If the run ended (END_SENTINEL) before disconnect is observed, the + finally block must not call cancel — the run is already terminal.""" + from app.gateway.services import wait_for_run_completion + + async def run() -> None: + mgr = RunManager() + bridge = MemoryStreamBridge() + record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel) + # Publish END before subscribe — helper should see END_SENTINEL on its + # first wake-up and return without ever observing the "disconnect". 
+ await mgr.set_status(record.run_id, RunStatus.success) + await bridge.publish_end(record.run_id) + request = _FakeRequest(disconnect_after=0) + + completed = await asyncio.wait_for( + wait_for_run_completion(bridge, record, request, mgr), + timeout=2.0, + ) + + assert completed is True + assert record.status == RunStatus.success + + asyncio.run(run())