mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
WebSocketManager.send_message_sync is called from background worker threads
(via asyncio.get_event_loop().run_in_executor) during workflow execution — by
WebSocketLogger, ArtifactDispatcher, and WebPromptChannel.
Previous implementation:
    try:
        loop = asyncio.get_running_loop()
        if loop.is_running():
            asyncio.create_task(...)  # path only reachable from main thread
        else:
            asyncio.run(...)  # creates a NEW event loop
    except RuntimeError:
        asyncio.run(...)  # also creates a new event loop
The problem: WebSocket objects are bound to the *main* uvicorn event loop.
asyncio.run() spins up a separate event loop and calls websocket.send_text()
there, which in Python 3.12 raises:
RuntimeError: Task got Future attached to a different loop
...causing all log/artifact/prompt messages emitted from workflow threads to be
silently dropped or to crash the worker thread.
Fix:
- Store the event loop that created the first WebSocket connection as
self._owner_loop (captured in connect(), which always runs on the main loop).
- send_message_sync schedules the coroutine on that loop via
asyncio.run_coroutine_threadsafe(), then waits with a 10 s timeout.
- Calling from the main thread still works (run_coroutine_threadsafe is safe
when called from any thread, including the loop thread itself).
Added 7 tests covering:
- send from main thread
- send from worker thread (verifies send_text runs on the owner loop thread)
- 8 concurrent workers with no lost messages
- send after disconnect does not crash
- send before connect (no owner loop) does not crash
- owner loop captured on first connect
- owner loop stable across multiple connects
232 lines
7.9 KiB
Python
232 lines
7.9 KiB
Python
"""Tests for WebSocketManager.send_message_sync cross-thread safety.

Verifies that send_message_sync correctly delivers messages when called
from worker threads (the common case during workflow execution).

The test avoids importing the full server stack (which has circular import
issues) by stubbing the heavy dependency modules in ``sys.modules`` before
importing WebSocketManager on its own.
"""
|
|
|
|
import asyncio
|
|
import concurrent.futures
|
|
import json
|
|
import sys
|
|
import threading
|
|
import time
|
|
from typing import List
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Isolate WebSocketManager from the circular-import chain
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Register MagicMock placeholders for the heavy dependency modules so that
# ``server.services.websocket_manager`` can be imported on its own. Modules
# that are already present in ``sys.modules`` are left untouched.
_stubs = {}
for mod_name in (
    "check",
    "check.check",
    "runtime",
    "runtime.sdk",
    "runtime.bootstrap",
    "runtime.bootstrap.schema",
    "server.services.workflow_run_service",
    "server.services.message_handler",
    "server.services.attachment_service",
    "server.services.session_execution",
    "server.services.session_store",
    "server.services.artifact_events",
):
    if mod_name in sys.modules:
        continue
    # Keep a handle on every placeholder we install, keyed by module name.
    placeholder = MagicMock()
    _stubs[mod_name] = placeholder
    sys.modules[mod_name] = placeholder
|
|
|
|
from server.services.websocket_manager import WebSocketManager # noqa: E402
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_manager() -> WebSocketManager:
    """Build a WebSocketManager whose collaborators are all fresh MagicMocks."""
    # One independent mock per constructor dependency.
    collaborators = {
        "session_store": MagicMock(),
        "session_controller": MagicMock(),
        "attachment_service": MagicMock(),
        "workflow_run_service": MagicMock(),
    }
    return WebSocketManager(**collaborators)
|
|
|
|
|
|
class FakeWebSocket:
    """Minimal WebSocket stand-in.

    Records every payload passed to ``send_text`` together with the ident of
    the thread on which the coroutine body executed.
    """

    def __init__(self) -> None:
        # Payloads in arrival order.
        self.sent: List[str] = []
        # threading.get_ident() of the thread that ran each send_text call.
        self.send_threads: List[int] = []

    async def accept(self) -> None:
        """Accept the connection; nothing to do for the fake."""
        return None

    async def send_text(self, data: str) -> None:
        """Store *data* and the ident of the thread running this coroutine."""
        running_thread = threading.get_ident()
        self.sent.append(data)
        self.send_threads.append(running_thread)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSendMessageSync:
    """send_message_sync must deliver messages regardless of calling thread."""

    def test_send_from_main_thread(self):
        """Message sent from the main (event-loop) thread is delivered."""
        manager = _make_manager()
        ws = FakeWebSocket()
        delivered = []

        async def run():
            sid = await manager.connect(ws, session_id="s1")
            # Discard anything sent to the socket during connect()
            ws.sent.clear()

            manager.send_message_sync(sid, {"type": "test", "data": "hello"})
            # Give the scheduled coroutine a moment to execute
            await asyncio.sleep(0.05)
            delivered.extend(ws.sent)

        asyncio.run(run())
        assert len(delivered) == 1
        assert '"test"' in delivered[0]

    def test_send_from_worker_thread(self):
        """Message sent from a background (worker) thread is delivered on the owner loop."""
        manager = _make_manager()
        ws = FakeWebSocket()

        async def run():
            sid = await manager.connect(ws, session_id="s2")
            ws.sent.clear()
            main_thread = threading.get_ident()

            def worker():
                # Deliberately not caught here: the exception is stored on the
                # future and surfaced by future.result() below, so a failing
                # send is reported as the root cause instead of being masked
                # by the delivery assertions.
                manager.send_message_sync(sid, {"type": "from_worker"})

            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                future = pool.submit(worker)
                # Poll without blocking so the loop can run the scheduled send
                while not future.done():
                    await asyncio.sleep(0.01)
                future.result()  # re-raise if worker threw
                await asyncio.sleep(0.1)

            # Verify delivery
            assert len(ws.sent) == 1, f"Expected 1 message, got {len(ws.sent)}"
            assert '"from_worker"' in ws.sent[0]

            # Verify send_text ran on the main loop thread, not the worker
            assert ws.send_threads[0] == main_thread

        asyncio.run(run())

    def test_concurrent_workers_no_lost_messages(self):
        """Multiple concurrent workers should each have their message delivered.

        In production, the event loop is free while workers run (the main coroutine
        awaits ``run_in_executor``). We replicate that by polling workers via
        ``asyncio.sleep`` so the loop can process the scheduled sends.
        """
        manager = _make_manager()
        ws = FakeWebSocket()
        num_workers = 8

        async def run():
            sid = await manager.connect(ws, session_id="s3")
            ws.sent.clear()

            # Release all workers at once to maximise contention
            barrier = threading.Barrier(num_workers)

            def worker(idx: int):
                barrier.wait(timeout=5)
                manager.send_message_sync(sid, {"type": "msg", "idx": idx})

            pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
            try:
                futures = [pool.submit(worker, i) for i in range(num_workers)]

                # Yield control so the loop can process sends while workers run.
                # Polling the futures (rather than a flag set at the end of the
                # worker) means a worker that raises ends the wait immediately
                # instead of stalling until the deadline.
                deadline = time.monotonic() + 15
                while (
                    not all(f.done() for f in futures)
                    and time.monotonic() < deadline
                ):
                    await asyncio.sleep(0.05)

                # Collect any worker exceptions
                for f in futures:
                    f.result(timeout=1)

                # Let remaining coros drain
                await asyncio.sleep(0.3)
            finally:
                # wait=False: never block the loop thread on worker shutdown
                pool.shutdown(wait=False)

            assert len(ws.sent) == num_workers, (
                f"Expected {num_workers} messages, got {len(ws.sent)}"
            )

        asyncio.run(run())

    def test_send_after_disconnect_does_not_crash(self):
        """Sending after disconnection should not raise."""
        manager = _make_manager()
        ws = FakeWebSocket()

        async def run():
            sid = await manager.connect(ws, session_id="s4")
            manager.disconnect(sid)

            # Should silently skip, not crash
            manager.send_message_sync(sid, {"type": "late"})
            await asyncio.sleep(0.05)

        asyncio.run(run())  # no exception == pass

    def test_send_before_any_connection_no_crash(self):
        """Calling send_message_sync before any connect() should not crash."""
        manager = _make_manager()
        # No connect() has happened yet, so no owner loop has been captured
        manager.send_message_sync("nonexistent", {"type": "orphan"})
        # Passing means the call returned without raising
|
|
|
|
|
|
class TestOwnerLoopCapture:
    """The manager must capture the event loop on first connect."""

    def test_owner_loop_captured_on_connect(self):
        """``_owner_loop`` is None before connect() and the running loop after."""
        manager = _make_manager()
        ws = FakeWebSocket()

        async def run():
            assert manager._owner_loop is None
            await manager.connect(ws, session_id="cap1")
            assert manager._owner_loop is asyncio.get_running_loop()

        asyncio.run(run())

    def test_owner_loop_stable_across_connections(self):
        """Subsequent connects should not reset the owner loop."""
        manager = _make_manager()
        ws1 = FakeWebSocket()
        ws2 = FakeWebSocket()

        async def run():
            await manager.connect(ws1, session_id="cap2")
            loop1 = manager._owner_loop
            # Pin the first capture to the running loop; otherwise the identity
            # check below would pass vacuously (None is None) if connect()
            # never stored a loop at all.
            assert loop1 is asyncio.get_running_loop()
            await manager.connect(ws2, session_id="cap3")
            assert manager._owner_loop is loop1

        asyncio.run(run())
|