Merge pull request #600 from voidborne-d/fix/send-message-sync-cross-loop

fix: use run_coroutine_threadsafe in send_message_sync to prevent cross-loop RuntimeError
This commit is contained in:
Shu Yao 2026-04-04 18:23:24 +08:00 committed by GitHub
commit 117e4d8d0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 267 additions and 7 deletions

View File

@ -49,6 +49,7 @@ class WebSocketManager:
):
self.active_connections: Dict[str, WebSocket] = {}
self.connection_timestamps: Dict[str, float] = {}
self._owner_loop: Optional[asyncio.AbstractEventLoop] = None
self.session_store = session_store or WorkflowSessionStore()
self.session_controller = session_controller or SessionExecutionController(self.session_store)
self.attachment_service = attachment_service or AttachmentService()
@ -65,6 +66,10 @@ class WebSocketManager:
async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) -> str:
await websocket.accept()
# Capture the event loop that owns the WebSocket connections so that
# worker threads can safely schedule sends via run_coroutine_threadsafe.
if self._owner_loop is None:
self._owner_loop = asyncio.get_running_loop()
if not session_id:
session_id = str(uuid.uuid4())
self.active_connections[session_id] = websocket
@ -108,14 +113,38 @@ class WebSocketManager:
# self.disconnect(session_id)
def send_message_sync(self, session_id: str, message: Dict[str, Any]) -> None:
"""Send a WebSocket message from any thread (including worker threads).
WebSocket objects are bound to the event loop that created them (the main
uvicorn loop). Previous code called ``asyncio.run()`` from worker threads
which spins up a *new* event loop, causing ``RuntimeError: attached to a
different loop`` or silent delivery failures.
The fix: always schedule the coroutine on the loop that owns the sockets
via ``asyncio.run_coroutine_threadsafe`` and wait for the result with a
short timeout so the caller knows if delivery failed.
"""
loop = self._owner_loop
if loop is None or loop.is_closed():
logging.warning(
"Cannot send sync message to %s: owner event loop unavailable",
session_id,
)
return
future = asyncio.run_coroutine_threadsafe(
self.send_message(session_id, message), loop
)
try:
loop = asyncio.get_running_loop()
if loop.is_running():
asyncio.create_task(self.send_message(session_id, message))
else:
asyncio.run(self.send_message(session_id, message))
except RuntimeError:
asyncio.run(self.send_message(session_id, message))
future.result(timeout=10)
except TimeoutError:
logging.warning(
"Timed out sending sync WS message to %s", session_id
)
except Exception as exc:
logging.error(
"Error sending sync WS message to %s: %s", session_id, exc
)
async def broadcast(self, message: Dict[str, Any]) -> None:
for session_id in list(self.active_connections.keys()):

View File

@ -0,0 +1,231 @@
"""Tests for WebSocketManager.send_message_sync cross-thread safety.
Verifies that send_message_sync correctly delivers messages when called
from worker threads (the common case during workflow execution).
The test avoids importing the full server stack (which has circular import
issues) by patching only the WebSocketManager class directly.
"""
import asyncio
import concurrent.futures
import json
import sys
import threading
import time
from typing import List
from unittest.mock import MagicMock
import pytest
# ---------------------------------------------------------------------------
# Isolate WebSocketManager from the circular-import chain
# ---------------------------------------------------------------------------
# Stub out heavy modules so we can import websocket_manager in isolation
_stubs = {}
for mod_name in (
"check", "check.check",
"runtime", "runtime.sdk", "runtime.bootstrap", "runtime.bootstrap.schema",
"server.services.workflow_run_service",
"server.services.message_handler",
"server.services.attachment_service",
"server.services.session_execution",
"server.services.session_store",
"server.services.artifact_events",
):
if mod_name not in sys.modules:
_stubs[mod_name] = MagicMock()
sys.modules[mod_name] = _stubs[mod_name]
from server.services.websocket_manager import WebSocketManager # noqa: E402
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_manager() -> WebSocketManager:
"""Create a WebSocketManager with minimal mocks."""
return WebSocketManager(
session_store=MagicMock(),
session_controller=MagicMock(),
attachment_service=MagicMock(),
workflow_run_service=MagicMock(),
)
class FakeWebSocket:
"""Lightweight fake that records sent messages and the thread they arrived on."""
def __init__(self) -> None:
self.sent: List[str] = []
self.send_threads: List[int] = []
async def accept(self) -> None:
pass
async def send_text(self, data: str) -> None:
self.sent.append(data)
self.send_threads.append(threading.get_ident())
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestSendMessageSync:
"""send_message_sync must deliver messages regardless of calling thread."""
def test_send_from_main_thread(self):
"""Message sent from the main (event-loop) thread is delivered."""
manager = _make_manager()
ws = FakeWebSocket()
delivered = []
async def run():
sid = await manager.connect(ws, session_id="s1")
# Drain the initial "connection" message
ws.sent.clear()
manager.send_message_sync(sid, {"type": "test", "data": "hello"})
# Give the scheduled coroutine a moment to execute
await asyncio.sleep(0.05)
delivered.extend(ws.sent)
asyncio.run(run())
assert len(delivered) == 1
assert '"test"' in delivered[0]
def test_send_from_worker_thread(self):
"""Message sent from a background (worker) thread is delivered on the owner loop."""
manager = _make_manager()
ws = FakeWebSocket()
worker_errors: List[Exception] = []
async def run():
sid = await manager.connect(ws, session_id="s2")
ws.sent.clear()
main_thread = threading.get_ident()
def worker():
try:
manager.send_message_sync(sid, {"type": "from_worker"})
except Exception as exc:
worker_errors.append(exc)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
future = pool.submit(worker)
# Let the worker thread finish and the scheduled coro run
while not future.done():
await asyncio.sleep(0.01)
future.result() # re-raise if worker threw
await asyncio.sleep(0.1)
# Verify delivery
assert len(ws.sent) == 1, f"Expected 1 message, got {len(ws.sent)}"
assert '"from_worker"' in ws.sent[0]
# Verify send_text ran on the main loop thread, not the worker
assert ws.send_threads[0] == main_thread
asyncio.run(run())
assert not worker_errors, f"Worker thread raised: {worker_errors}"
def test_concurrent_workers_no_lost_messages(self):
"""Multiple concurrent workers should each have their message delivered.
In production, the event loop is free while workers run (the main coroutine
awaits ``run_in_executor``). We replicate that by polling workers via
``asyncio.sleep`` so the loop can process the scheduled sends.
"""
manager = _make_manager()
ws = FakeWebSocket()
num_workers = 8
async def run():
sid = await manager.connect(ws, session_id="s3")
ws.sent.clear()
barrier = threading.Barrier(num_workers)
done_count = threading.atomic(0) if hasattr(threading, "atomic") else None
done_flags = [False] * num_workers
def worker(idx: int):
barrier.wait(timeout=5)
manager.send_message_sync(sid, {"type": "msg", "idx": idx})
done_flags[idx] = True
pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
futures = [pool.submit(worker, i) for i in range(num_workers)]
# Yield control so the loop can process sends while workers run
deadline = time.time() + 15
while not all(done_flags) and time.time() < deadline:
await asyncio.sleep(0.05)
# Collect any worker exceptions
for f in futures:
f.result(timeout=1)
# Let remaining coros drain
await asyncio.sleep(0.3)
pool.shutdown(wait=False)
assert len(ws.sent) == num_workers, (
f"Expected {num_workers} messages, got {len(ws.sent)}"
)
asyncio.run(run())
def test_send_after_disconnect_does_not_crash(self):
"""Sending after disconnection should not raise."""
manager = _make_manager()
ws = FakeWebSocket()
async def run():
sid = await manager.connect(ws, session_id="s4")
manager.disconnect(sid)
# Should silently skip, not crash
manager.send_message_sync(sid, {"type": "late"})
await asyncio.sleep(0.05)
asyncio.run(run()) # no exception == pass
def test_send_before_any_connection_no_crash(self):
"""Calling send_message_sync before any connect() should not crash."""
manager = _make_manager()
# _owner_loop is None
manager.send_message_sync("nonexistent", {"type": "orphan"})
# Should log a warning, not crash
class TestOwnerLoopCapture:
"""The manager must capture the event loop on first connect."""
def test_owner_loop_captured_on_connect(self):
manager = _make_manager()
ws = FakeWebSocket()
async def run():
assert manager._owner_loop is None
await manager.connect(ws, session_id="cap1")
assert manager._owner_loop is asyncio.get_running_loop()
asyncio.run(run())
def test_owner_loop_stable_across_connections(self):
"""Subsequent connects should not reset the owner loop."""
manager = _make_manager()
ws1 = FakeWebSocket()
ws2 = FakeWebSocket()
async def run():
await manager.connect(ws1, session_id="cap2")
loop1 = manager._owner_loop
await manager.connect(ws2, session_id="cap3")
assert manager._owner_loop is loop1
asyncio.run(run())