mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-30 12:28:10 +00:00
fix(gateway): honour on_disconnect on /wait endpoints (#3267)
* fix(gateway): honour on_disconnect on /wait endpoints (#3265) The non-streaming /threads/{tid}/runs/wait and /runs/wait handlers used to await record.task directly with no disconnect handling and silently swallow CancelledError. When a long tool call (e.g. pip install inside a custom skill) kept the connection idle long enough for an intermediate HTTP layer to time out, the handler would still read the in-progress checkpoint and return it as if the run had completed normally -- masking a half-finished run as a successful response. Add wait_for_run_completion in app.gateway.services that mirrors sse_consumer's bridge-consumption pattern: subscribe to the stream bridge until END_SENTINEL, poll request.is_disconnected on every wake-up, and on real client disconnect cancel the background run when record.on_disconnect is "cancel". Wire it into both wait endpoints. The streaming path was unaffected because sse_consumer already has this loop; this just brings /wait to parity. * fix(gateway): skip checkpoint serialization on /wait disconnect Copilot review on #3267 caught a follow-on of the same #3265 bug: when the client disconnects, wait_for_run_completion breaks out of the bridge loop and cancels the run, but the /wait endpoint then continues to read the checkpointer and serializes whatever partial checkpoint exists as a normal 200 response. Have the helper return a bool — True only when END_SENTINEL was observed — and skip the checkpoint serialization path on False. Also reorder the inner check so END_SENTINEL is honoured even when is_disconnected() flips true in the same iteration; the run truly finished so the real final checkpoint is still valid.
This commit is contained in:
parent
9e332c594a
commit
a5599c100c
@ -277,6 +277,7 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S
|
||||
- When a persistent `RunStore` is configured, `get()` and `list_by_thread()` hydrate historical runs from the store. In-memory records win for the same `run_id` so task, abort, and stream-control state stays attached to active local runs.
|
||||
- `cancel()` and `create_or_reject(..., multitask_strategy="interrupt"|"rollback")` persist interrupted status through `RunStore.update_status()`, matching normal `set_status()` transitions.
|
||||
- Store-only hydrated runs are readable history. If the current worker has no in-memory task/control state for that run, cancellation APIs can return 409 because this worker cannot stop the task.
|
||||
- `POST /wait` (both thread-scoped and `/api/runs/wait`) drains the stream bridge via `wait_for_run_completion()` instead of a bare `await record.task`, so it honours the run's `on_disconnect` setting: on a real client disconnect the background run is cancelled when `on_disconnect` is `cancel`, and in either mode the endpoint skips checkpoint serialization and returns the run's status/error instead of a stale checkpoint (issue #3265).
|
||||
|
||||
Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs.
|
||||
|
||||
|
||||
@ -7,7 +7,6 @@ is reused so that conversation history is preserved across calls.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
@ -17,7 +16,7 @@ from fastapi.responses import StreamingResponse
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -66,24 +65,25 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
||||
Otherwise a new temporary thread is created.
|
||||
"""
|
||||
thread_id = _resolve_thread_id(body)
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
completed = True
|
||||
if record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
if completed:
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -175,24 +175,25 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
|
||||
@require_permission("runs", "create", owner_check=True, require_existing=True)
|
||||
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
|
||||
"""Create a run and block until it completes, returning the final state."""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
completed = True
|
||||
if record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
if completed:
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
|
||||
|
||||
@ -402,3 +402,51 @@ async def sse_consumer(
|
||||
if record.status in (RunStatus.pending, RunStatus.running):
|
||||
if record.on_disconnect == DisconnectMode.cancel:
|
||||
await run_mgr.cancel(record.run_id)
|
||||
|
||||
|
||||
async def wait_for_run_completion(
    bridge: StreamBridge,
    record: RunRecord,
    request: Request,
    run_mgr: RunManager,
) -> bool:
    """Wait for a run to finish, applying its ``on_disconnect`` policy.

    Entries are read from ``bridge.subscribe(record.run_id)`` until one of
    two things happens:

    * ``END_SENTINEL`` arrives, meaning the run reached a terminal state.
    * ``request.is_disconnected()`` reports that the client has gone away.

    The sentinel is checked before the disconnect poll, so a run that ends
    on the same wake-up in which the client drops is still reported as
    completed and the caller may serialize its final checkpoint.

    This replaces the bare ``await record.task`` previously used by the
    non-streaming ``/wait`` endpoints, which had no disconnect handling and
    could return a half-finished checkpoint as if the run had completed
    normally (issue #3265).

    A disconnect is only noticed when the bridge yields an entry.
    NOTE(review): this relies on the bridge emitting periodic heartbeat
    entries while the agent is quiet -- confirm against the ``StreamBridge``
    implementation in use.

    Args:
        bridge: Stream bridge to subscribe to for this run's entries.
        record: The run being waited on; its ``run_id``, ``status`` and
            ``on_disconnect`` are read.
        request: Incoming HTTP request, polled for client disconnect.
        run_mgr: Run manager used to cancel the run when required.

    Returns:
        ``True`` when ``END_SENTINEL`` was observed, ``False`` when the loop
        stopped for any other reason (client disconnect, or the subscription
        ending without the sentinel). Callers must skip checkpoint
        serialization on ``False`` so a partial checkpoint is not returned
        as a normal response.
    """
    reached_end = False
    try:
        async for item in bridge.subscribe(record.run_id):
            if item is END_SENTINEL:
                # Terminal state wins over a simultaneous disconnect.
                reached_end = True
                break
            if await request.is_disconnected():
                break
            # Anything else (heartbeat or regular event): keep waiting.
    finally:
        # Runs on disconnect, on an early end of the subscription, and when
        # this coroutine is unwound by an exception or cancellation. Only a
        # run that is still active and opted into "cancel" is stopped.
        if not reached_end:
            run_still_active = record.status in (RunStatus.pending, RunStatus.running)
            if run_still_active and record.on_disconnect == DisconnectMode.cancel:
                await run_mgr.cancel(record.run_id)
    return reached_end
|
||||
|
||||
177
backend/tests/test_wait_disconnect_handling.py
Normal file
177
backend/tests/test_wait_disconnect_handling.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""Regression tests for issue #3265.
|
||||
|
||||
The non-streaming ``/wait`` endpoints used to ``await record.task`` with no
|
||||
disconnect handling and silently swallow ``CancelledError``. When a long
|
||||
tool call (e.g. ``pip install`` inside a custom skill) kept the connection
|
||||
idle long enough for an intermediate HTTP layer to time out, the handler
|
||||
would return a stale checkpoint that looked like a normal completion.
|
||||
|
||||
The fix introduces ``wait_for_run_completion`` in ``app.gateway.services``:
|
||||
it subscribes to the stream bridge until ``END_SENTINEL``, polls
|
||||
``request.is_disconnected()`` on every wake-up, and honours the record's
|
||||
``on_disconnect`` mode by cancelling the background run on real client
|
||||
disconnect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from deerflow.runtime import RunManager, RunStatus
|
||||
from deerflow.runtime.runs.schemas import DisconnectMode
|
||||
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
|
||||
|
||||
THREAD_ID = "thread-wait-3265"
|
||||
|
||||
|
||||
@dataclass
class _FakeRequest:
    """Test double for the single ``Request`` method the wait helper uses.

    The helper awaits ``is_disconnected()`` on each loop iteration. Counting
    those calls makes the transition deterministic: the first
    ``disconnect_after`` polls answer "still connected" and every later poll
    answers "disconnected", with no dependence on event-loop timing.
    """

    # Number of polls answered with False; the huge default means the client
    # effectively never disconnects.
    disconnect_after: int = 10**9
    # Running count of is_disconnected() calls made so far.
    _polls: int = 0

    async def is_disconnected(self) -> bool:
        """Record one more poll and report whether the client has dropped."""
        polls_so_far = self._polls + 1
        self._polls = polls_so_far
        return polls_so_far > self.disconnect_after
|
||||
|
||||
|
||||
async def _create_running_record(mgr: RunManager, *, on_disconnect: DisconnectMode) -> Any:
    """Register a run on ``THREAD_ID`` and move it to ``RunStatus.running``.

    Args:
        mgr: Run manager that owns the new record.
        on_disconnect: Disconnect policy stored on the record.

    Returns:
        The record returned by ``mgr.create_or_reject``, after its status
        has been set to running through ``mgr.set_status``.
    """
    new_record = await mgr.create_or_reject(THREAD_ID, assistant_id=None, on_disconnect=on_disconnect)
    run_id = new_record.run_id
    await mgr.set_status(run_id, RunStatus.running)
    return new_record
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper-level unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWaitForRunCompletion:
    """Unit tests for ``wait_for_run_completion`` against an in-memory bridge.

    Every test drives its coroutine with ``asyncio.run`` and bounds the
    helper call with ``asyncio.wait_for`` so a regression fails fast instead
    of hanging the suite.

    Background publisher tasks are bound to a local name and awaited after
    the helper returns. The event loop keeps only weak references to tasks,
    so an unreferenced task may be garbage-collected before it finishes, and
    awaiting it also surfaces any exception raised while publishing.
    """

    def test_returns_when_run_publishes_end(self) -> None:
        """Happy path: helper returns once the bridge publishes END_SENTINEL."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
            request = _FakeRequest()

            async def finish_soon() -> None:
                # Yield once so the helper subscribes before anything is
                # published, then emit one event followed by the end marker.
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"messages": []})
                await mgr.set_status(record.run_id, RunStatus.success)
                await bridge.publish_end(record.run_id)

            publisher = asyncio.create_task(finish_soon())
            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            await asyncio.wait_for(publisher, timeout=1.0)

            assert completed is True
            assert record.status == RunStatus.success

        asyncio.run(run())

    def test_cancels_run_on_disconnect_when_cancel_mode(self) -> None:
        """on_disconnect=cancel: real disconnect must call run_mgr.cancel()."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
            # Attach a real (idle) task so cancel() actually has something to cancel.
            sleeper = asyncio.create_task(asyncio.sleep(30))
            record.task = sleeper
            request = _FakeRequest(disconnect_after=0)  # disconnected on first poll

            async def publish_one_event() -> None:
                # A single event wakes the subscriber; the helper polls
                # is_disconnected() only after an entry has been yielded.
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"step": 1})

            publisher = asyncio.create_task(publish_one_event())
            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            await asyncio.wait_for(publisher, timeout=1.0)

            assert completed is False
            assert record.status == RunStatus.interrupted
            # Drain the cancelled sleeper so it does not linger past the test.
            try:
                await asyncio.wait_for(sleeper, timeout=1.0)
            except asyncio.CancelledError:
                pass
            assert sleeper.done()

        asyncio.run(run())

    def test_does_not_cancel_when_continue_mode(self) -> None:
        """on_disconnect=continue: disconnect must NOT cancel the run."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.continue_)
            sleeper = asyncio.create_task(asyncio.sleep(30))
            record.task = sleeper
            request = _FakeRequest(disconnect_after=0)

            async def publish_one_event() -> None:
                # No END is published: the run is meant to still be in
                # flight when the client disconnects.
                await asyncio.sleep(0)
                await bridge.publish(record.run_id, "values", {"step": 1})

            publisher = asyncio.create_task(publish_one_event())
            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )
            await asyncio.wait_for(publisher, timeout=1.0)

            # Disconnected before END: the helper still reports incomplete so
            # the caller skips checkpoint serialization, but the run keeps going.
            assert completed is False
            assert record.status == RunStatus.running
            sleeper.cancel()

        asyncio.run(run())

    def test_no_cancel_when_run_already_finished(self) -> None:
        """If the run ended (END_SENTINEL) before disconnect is observed, the
        finally block must not call cancel -- the run is already terminal."""
        from app.gateway.services import wait_for_run_completion

        async def run() -> None:
            mgr = RunManager()
            bridge = MemoryStreamBridge()
            record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
            # END is published before the helper subscribes. The helper checks
            # for END_SENTINEL ahead of the disconnect poll, so it should
            # return True even though this request reports "disconnected" on
            # its very first poll. This relies on the bridge delivering END
            # to a subscriber that attaches after it was published.
            await mgr.set_status(record.run_id, RunStatus.success)
            await bridge.publish_end(record.run_id)
            request = _FakeRequest(disconnect_after=0)

            completed = await asyncio.wait_for(
                wait_for_run_completion(bridge, record, request, mgr),
                timeout=2.0,
            )

            assert completed is True
            assert record.status == RunStatus.success

        asyncio.run(run())
|
||||
Loading…
x
Reference in New Issue
Block a user