mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
Merge pull request #600 from voidborne-d/fix/send-message-sync-cross-loop
fix: use run_coroutine_threadsafe in send_message_sync to prevent cross-loop RuntimeError
This commit is contained in:
commit
117e4d8d0c
@ -49,6 +49,7 @@ class WebSocketManager:
|
||||
):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.connection_timestamps: Dict[str, float] = {}
|
||||
self._owner_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self.session_store = session_store or WorkflowSessionStore()
|
||||
self.session_controller = session_controller or SessionExecutionController(self.session_store)
|
||||
self.attachment_service = attachment_service or AttachmentService()
|
||||
@ -65,6 +66,10 @@ class WebSocketManager:
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) -> str:
|
||||
await websocket.accept()
|
||||
# Capture the event loop that owns the WebSocket connections so that
|
||||
# worker threads can safely schedule sends via run_coroutine_threadsafe.
|
||||
if self._owner_loop is None:
|
||||
self._owner_loop = asyncio.get_running_loop()
|
||||
if not session_id:
|
||||
session_id = str(uuid.uuid4())
|
||||
self.active_connections[session_id] = websocket
|
||||
@ -108,14 +113,38 @@ class WebSocketManager:
|
||||
# self.disconnect(session_id)
|
||||
|
||||
def send_message_sync(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||
"""Send a WebSocket message from any thread (including worker threads).
|
||||
|
||||
WebSocket objects are bound to the event loop that created them (the main
|
||||
uvicorn loop). Previous code called ``asyncio.run()`` from worker threads
|
||||
which spins up a *new* event loop, causing ``RuntimeError: … attached to a
|
||||
different loop`` or silent delivery failures.
|
||||
|
||||
The fix: always schedule the coroutine on the loop that owns the sockets
|
||||
via ``asyncio.run_coroutine_threadsafe`` and wait for the result with a
|
||||
short timeout so the caller knows if delivery failed.
|
||||
"""
|
||||
loop = self._owner_loop
|
||||
if loop is None or loop.is_closed():
|
||||
logging.warning(
|
||||
"Cannot send sync message to %s: owner event loop unavailable",
|
||||
session_id,
|
||||
)
|
||||
return
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self.send_message(session_id, message), loop
|
||||
)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop.is_running():
|
||||
asyncio.create_task(self.send_message(session_id, message))
|
||||
else:
|
||||
asyncio.run(self.send_message(session_id, message))
|
||||
except RuntimeError:
|
||||
asyncio.run(self.send_message(session_id, message))
|
||||
future.result(timeout=10)
|
||||
except TimeoutError:
|
||||
logging.warning(
|
||||
"Timed out sending sync WS message to %s", session_id
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.error(
|
||||
"Error sending sync WS message to %s: %s", session_id, exc
|
||||
)
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]) -> None:
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
|
||||
231
tests/test_websocket_send_message_sync.py
Normal file
231
tests/test_websocket_send_message_sync.py
Normal file
@ -0,0 +1,231 @@
|
||||
"""Tests for WebSocketManager.send_message_sync cross-thread safety.
|
||||
|
||||
Verifies that send_message_sync correctly delivers messages when called
|
||||
from worker threads (the common case during workflow execution).
|
||||
|
||||
The test avoids importing the full server stack (which has circular import
|
||||
issues) by patching only the WebSocketManager class directly.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Isolate WebSocketManager from the circular-import chain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Stub out heavy modules so we can import websocket_manager in isolation
|
||||
_stubs = {}
|
||||
for mod_name in (
|
||||
"check", "check.check",
|
||||
"runtime", "runtime.sdk", "runtime.bootstrap", "runtime.bootstrap.schema",
|
||||
"server.services.workflow_run_service",
|
||||
"server.services.message_handler",
|
||||
"server.services.attachment_service",
|
||||
"server.services.session_execution",
|
||||
"server.services.session_store",
|
||||
"server.services.artifact_events",
|
||||
):
|
||||
if mod_name not in sys.modules:
|
||||
_stubs[mod_name] = MagicMock()
|
||||
sys.modules[mod_name] = _stubs[mod_name]
|
||||
|
||||
from server.services.websocket_manager import WebSocketManager # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_manager() -> WebSocketManager:
|
||||
"""Create a WebSocketManager with minimal mocks."""
|
||||
return WebSocketManager(
|
||||
session_store=MagicMock(),
|
||||
session_controller=MagicMock(),
|
||||
attachment_service=MagicMock(),
|
||||
workflow_run_service=MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
class FakeWebSocket:
|
||||
"""Lightweight fake that records sent messages and the thread they arrived on."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sent: List[str] = []
|
||||
self.send_threads: List[int] = []
|
||||
|
||||
async def accept(self) -> None:
|
||||
pass
|
||||
|
||||
async def send_text(self, data: str) -> None:
|
||||
self.sent.append(data)
|
||||
self.send_threads.append(threading.get_ident())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendMessageSync:
|
||||
"""send_message_sync must deliver messages regardless of calling thread."""
|
||||
|
||||
def test_send_from_main_thread(self):
|
||||
"""Message sent from the main (event-loop) thread is delivered."""
|
||||
manager = _make_manager()
|
||||
ws = FakeWebSocket()
|
||||
delivered = []
|
||||
|
||||
async def run():
|
||||
sid = await manager.connect(ws, session_id="s1")
|
||||
# Drain the initial "connection" message
|
||||
ws.sent.clear()
|
||||
|
||||
manager.send_message_sync(sid, {"type": "test", "data": "hello"})
|
||||
# Give the scheduled coroutine a moment to execute
|
||||
await asyncio.sleep(0.05)
|
||||
delivered.extend(ws.sent)
|
||||
|
||||
asyncio.run(run())
|
||||
assert len(delivered) == 1
|
||||
assert '"test"' in delivered[0]
|
||||
|
||||
def test_send_from_worker_thread(self):
|
||||
"""Message sent from a background (worker) thread is delivered on the owner loop."""
|
||||
manager = _make_manager()
|
||||
ws = FakeWebSocket()
|
||||
worker_errors: List[Exception] = []
|
||||
|
||||
async def run():
|
||||
sid = await manager.connect(ws, session_id="s2")
|
||||
ws.sent.clear()
|
||||
main_thread = threading.get_ident()
|
||||
|
||||
def worker():
|
||||
try:
|
||||
manager.send_message_sync(sid, {"type": "from_worker"})
|
||||
except Exception as exc:
|
||||
worker_errors.append(exc)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(worker)
|
||||
# Let the worker thread finish and the scheduled coro run
|
||||
while not future.done():
|
||||
await asyncio.sleep(0.01)
|
||||
future.result() # re-raise if worker threw
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Verify delivery
|
||||
assert len(ws.sent) == 1, f"Expected 1 message, got {len(ws.sent)}"
|
||||
assert '"from_worker"' in ws.sent[0]
|
||||
|
||||
# Verify send_text ran on the main loop thread, not the worker
|
||||
assert ws.send_threads[0] == main_thread
|
||||
|
||||
asyncio.run(run())
|
||||
assert not worker_errors, f"Worker thread raised: {worker_errors}"
|
||||
|
||||
def test_concurrent_workers_no_lost_messages(self):
|
||||
"""Multiple concurrent workers should each have their message delivered.
|
||||
|
||||
In production, the event loop is free while workers run (the main coroutine
|
||||
awaits ``run_in_executor``). We replicate that by polling workers via
|
||||
``asyncio.sleep`` so the loop can process the scheduled sends.
|
||||
"""
|
||||
manager = _make_manager()
|
||||
ws = FakeWebSocket()
|
||||
num_workers = 8
|
||||
|
||||
async def run():
|
||||
sid = await manager.connect(ws, session_id="s3")
|
||||
ws.sent.clear()
|
||||
|
||||
barrier = threading.Barrier(num_workers)
|
||||
done_count = threading.atomic(0) if hasattr(threading, "atomic") else None
|
||||
done_flags = [False] * num_workers
|
||||
|
||||
def worker(idx: int):
|
||||
barrier.wait(timeout=5)
|
||||
manager.send_message_sync(sid, {"type": "msg", "idx": idx})
|
||||
done_flags[idx] = True
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
|
||||
futures = [pool.submit(worker, i) for i in range(num_workers)]
|
||||
|
||||
# Yield control so the loop can process sends while workers run
|
||||
deadline = time.time() + 15
|
||||
while not all(done_flags) and time.time() < deadline:
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Collect any worker exceptions
|
||||
for f in futures:
|
||||
f.result(timeout=1)
|
||||
|
||||
# Let remaining coros drain
|
||||
await asyncio.sleep(0.3)
|
||||
pool.shutdown(wait=False)
|
||||
|
||||
assert len(ws.sent) == num_workers, (
|
||||
f"Expected {num_workers} messages, got {len(ws.sent)}"
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_send_after_disconnect_does_not_crash(self):
|
||||
"""Sending after disconnection should not raise."""
|
||||
manager = _make_manager()
|
||||
ws = FakeWebSocket()
|
||||
|
||||
async def run():
|
||||
sid = await manager.connect(ws, session_id="s4")
|
||||
manager.disconnect(sid)
|
||||
|
||||
# Should silently skip, not crash
|
||||
manager.send_message_sync(sid, {"type": "late"})
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
asyncio.run(run()) # no exception == pass
|
||||
|
||||
def test_send_before_any_connection_no_crash(self):
|
||||
"""Calling send_message_sync before any connect() should not crash."""
|
||||
manager = _make_manager()
|
||||
# _owner_loop is None
|
||||
manager.send_message_sync("nonexistent", {"type": "orphan"})
|
||||
# Should log a warning, not crash
|
||||
|
||||
|
||||
class TestOwnerLoopCapture:
|
||||
"""The manager must capture the event loop on first connect."""
|
||||
|
||||
def test_owner_loop_captured_on_connect(self):
|
||||
manager = _make_manager()
|
||||
ws = FakeWebSocket()
|
||||
|
||||
async def run():
|
||||
assert manager._owner_loop is None
|
||||
await manager.connect(ws, session_id="cap1")
|
||||
assert manager._owner_loop is asyncio.get_running_loop()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_owner_loop_stable_across_connections(self):
|
||||
"""Subsequent connects should not reset the owner loop."""
|
||||
manager = _make_manager()
|
||||
ws1 = FakeWebSocket()
|
||||
ws2 = FakeWebSocket()
|
||||
|
||||
async def run():
|
||||
await manager.connect(ws1, session_id="cap2")
|
||||
loop1 = manager._owner_loop
|
||||
await manager.connect(ws2, session_id="cap3")
|
||||
assert manager._owner_loop is loop1
|
||||
|
||||
asyncio.run(run())
|
||||
Loading…
x
Reference in New Issue
Block a user