ChatDev/tests/test_websocket_send_message_sync.py
voidborne-d 511d05e545 fix: use run_coroutine_threadsafe in send_message_sync to prevent cross-loop crash
WebSocketManager.send_message_sync is called by WebSocketLogger, ArtifactDispatcher,
and WebPromptChannel from background worker threads during workflow execution
(the workers are dispatched via asyncio.get_event_loop().run_in_executor).

Previous implementation:
  try:
      loop = asyncio.get_running_loop()
      if loop.is_running():
          asyncio.create_task(...)          # path only reachable from main thread
      else:
          asyncio.run(...)                  # creates a NEW event loop
  except RuntimeError:
      asyncio.run(...)                      # also creates a new event loop

The problem: WebSocket objects are bound to the *main* uvicorn event loop.
asyncio.run() spins up a separate event loop and calls websocket.send_text()
there, which in Python 3.12 raises:

  RuntimeError: Task got Future attached to a different loop

...causing all log/artifact/prompt messages emitted from workflow threads to be
silently dropped or to crash the worker thread.

Fix:
- Store the event loop that created the first WebSocket connection as
  self._owner_loop (captured in connect(), which always runs on the main loop).
- send_message_sync schedules the coroutine on that loop via
  asyncio.run_coroutine_threadsafe(), then waits with a 10 s timeout.
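- Calling from the main thread still works: run_coroutine_threadsafe may be
  called from any thread, including the loop thread itself. On the loop thread,
  however, the caller must not block on the returned future — the loop cannot
  run the coroutine while its own thread is waiting, so such a wait would only
  end when the timeout expires.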

Added 7 tests covering:
- send from main thread
- send from worker thread (verifies send_text runs on the owner loop thread)
- 8 concurrent workers with no lost messages
- send after disconnect does not crash
- send before connect (no owner loop) does not crash
- owner loop captured on first connect
- owner loop stable across multiple connects
2026-04-02 16:01:46 +00:00

232 lines
7.9 KiB
Python

"""Tests for WebSocketManager.send_message_sync cross-thread safety.
Verifies that send_message_sync correctly delivers messages when called
from worker threads (the common case during workflow execution).
The test avoids importing the full server stack (which has circular import
issues) by registering MagicMock stubs in ``sys.modules`` for the heavy
dependency modules before importing ``server.services.websocket_manager``.
"""
import asyncio
import concurrent.futures
import json
import sys
import threading
import time
from typing import List
from unittest.mock import MagicMock
import pytest
# ---------------------------------------------------------------------------
# Isolate WebSocketManager from the circular-import chain
# ---------------------------------------------------------------------------
# Heavy modules that would otherwise be pulled in while importing
# websocket_manager. Every name not already present in sys.modules is replaced
# by a MagicMock placeholder, so the import at the end of this section can
# succeed in isolation. Modules that are already imported are left untouched.
# NOTE(review): the placeholders are never removed from sys.modules, so any
# module imported later in the same process (e.g. another test file) that
# imports one of these names will receive the MagicMock instead of the real
# module — confirm that is acceptable for the rest of the suite.
_STUBBED_MODULE_NAMES = (
    "check",
    "check.check",
    "runtime",
    "runtime.sdk",
    "runtime.bootstrap",
    "runtime.bootstrap.schema",
    "server.services.workflow_run_service",
    "server.services.message_handler",
    "server.services.attachment_service",
    "server.services.session_execution",
    "server.services.session_store",
    "server.services.artifact_events",
)
# Placeholders this module installed, keyed by module name.
_stubs = {}
for mod_name in _STUBBED_MODULE_NAMES:
    if mod_name in sys.modules:
        continue
    placeholder = MagicMock()
    _stubs[mod_name] = placeholder
    sys.modules[mod_name] = placeholder
from server.services.websocket_manager import WebSocketManager  # noqa: E402
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_manager() -> WebSocketManager:
    """Build a WebSocketManager whose four collaborators are fresh MagicMocks.

    Each collaborator is passed by keyword and gets its own independent mock.
    """
    collaborators = {
        role: MagicMock()
        for role in (
            "session_store",
            "session_controller",
            "attachment_service",
            "workflow_run_service",
        )
    }
    return WebSocketManager(**collaborators)
class FakeWebSocket:
    """In-memory WebSocket double used by the tests below.

    Every payload handed to ``send_text`` is appended to ``sent``; the
    ``threading.get_ident()`` of the thread that executed that call is
    appended to ``send_threads`` at the same index.
    """

    def __init__(self) -> None:
        # Payloads in the order they were sent.
        self.sent: List[str] = []
        # Thread identifiers, parallel to ``sent``.
        self.send_threads: List[int] = []

    async def accept(self) -> None:
        """Accept the handshake; the fake has nothing to negotiate."""
        return None

    async def send_text(self, data: str) -> None:
        """Record *data* together with the identity of the executing thread."""
        executing_thread = threading.get_ident()
        self.sent.append(data)
        self.send_threads.append(executing_thread)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestSendMessageSync:
    """send_message_sync must deliver messages regardless of calling thread."""

    def test_send_from_main_thread(self):
        """A message sent from the owner (event-loop) thread is delivered promptly.

        When send_message_sync runs on the loop thread it must not block on the
        scheduled send: the loop cannot execute that coroutine while its own
        thread is waiting, so a blocking wait would stall the whole loop until
        the send timeout expired. The call duration is therefore checked too.
        """
        manager = _make_manager()
        ws = FakeWebSocket()
        delivered: List[str] = []
        call_durations: List[float] = []

        async def run():
            sid = await manager.connect(ws, session_id="s1")
            # Discard whatever connect() sent (the initial "connection" message)
            ws.sent.clear()
            started = time.monotonic()
            manager.send_message_sync(sid, {"type": "test", "data": "hello"})
            call_durations.append(time.monotonic() - started)
            # Give the scheduled coroutine a moment to execute
            await asyncio.sleep(0.05)
            delivered.extend(ws.sent)

        asyncio.run(run())
        assert len(delivered) == 1
        assert '"test"' in delivered[0]
        # Generous bound, well below the 10 s wait used for cross-thread sends.
        assert call_durations[0] < 5.0, (
            f"send_message_sync blocked the event loop for {call_durations[0]:.1f}s"
        )

    def test_send_from_worker_thread(self):
        """Message sent from a background (worker) thread is delivered on the owner loop."""
        manager = _make_manager()
        ws = FakeWebSocket()
        worker_errors: List[Exception] = []

        async def run():
            sid = await manager.connect(ws, session_id="s2")
            ws.sent.clear()
            main_thread = threading.get_ident()

            def worker():
                try:
                    manager.send_message_sync(sid, {"type": "from_worker"})
                except Exception as exc:  # recorded and reported below
                    worker_errors.append(exc)

            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                future = pool.submit(worker)
                # Keep the loop free so it can run the send the worker scheduled
                while not future.done():
                    await asyncio.sleep(0.01)
                # worker() records its own exceptions, so this only surfaces
                # failures of the executor machinery itself.
                future.result()
            await asyncio.sleep(0.1)
            # Report a worker failure first; otherwise the delivery checks below
            # would fail with a less informative message.
            assert not worker_errors, f"Worker thread raised: {worker_errors}"
            # Verify delivery
            assert len(ws.sent) == 1, f"Expected 1 message, got {len(ws.sent)}"
            assert '"from_worker"' in ws.sent[0]
            # Verify send_text ran on the main loop thread, not the worker
            assert ws.send_threads[0] == main_thread

        asyncio.run(run())

    def test_concurrent_workers_no_lost_messages(self):
        """Multiple concurrent workers should each have their message delivered.

        In production, the event loop is free while workers run (the main
        coroutine awaits ``run_in_executor``). We replicate that by polling the
        worker futures via ``asyncio.sleep`` so the loop can process the
        scheduled sends.
        """
        manager = _make_manager()
        ws = FakeWebSocket()
        num_workers = 8

        def idx_values(payload):
            """Yield every value stored under an ``"idx"`` key anywhere in *payload*.

            The wire format is decided by WebSocketManager, so the whole decoded
            message is searched instead of assuming a top-level key.
            """
            if isinstance(payload, dict):
                if "idx" in payload:
                    yield payload["idx"]
                for value in payload.values():
                    yield from idx_values(value)
            elif isinstance(payload, list):
                for item in payload:
                    yield from idx_values(item)

        async def run():
            sid = await manager.connect(ws, session_id="s3")
            ws.sent.clear()
            # Release all workers at once to maximise contention on the loop.
            barrier = threading.Barrier(num_workers)

            def worker(idx: int):
                barrier.wait(timeout=5)
                manager.send_message_sync(sid, {"type": "msg", "idx": idx})

            pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
            try:
                futures = [pool.submit(worker, i) for i in range(num_workers)]
                # Yield control so the loop can process sends while workers run.
                # Polling the futures (not per-worker flags) also stops waiting
                # as soon as any worker has failed.
                deadline = time.monotonic() + 15
                while time.monotonic() < deadline and not all(
                    f.done() for f in futures
                ):
                    await asyncio.sleep(0.05)
                # Re-raise any worker exception (or time out on a stuck worker).
                for f in futures:
                    f.result(timeout=1)
            finally:
                # Never block the loop thread here: a stuck worker may be
                # waiting on a coroutine that only this loop can run.
                pool.shutdown(wait=False)
            # Let remaining coros drain
            await asyncio.sleep(0.3)
            assert len(ws.sent) == num_workers, (
                f"Expected {num_workers} messages, got {len(ws.sent)}"
            )
            # Each worker's own message must arrive exactly once.
            received = sorted(
                value for raw in ws.sent for value in idx_values(json.loads(raw))
            )
            assert received == list(range(num_workers)), (
                f"Delivered idx values: {received}"
            )

        asyncio.run(run())

    def test_send_after_disconnect_does_not_crash(self):
        """Sending after disconnection should be skipped silently, not raise."""
        manager = _make_manager()
        ws = FakeWebSocket()

        async def run():
            sid = await manager.connect(ws, session_id="s4")
            manager.disconnect(sid)
            # Should silently skip, not crash
            manager.send_message_sync(sid, {"type": "late"})
            await asyncio.sleep(0.05)

        asyncio.run(run())  # reaching this line means nothing raised
        # "Silently skip" also means the disconnected socket received nothing.
        assert not any('"late"' in message for message in ws.sent)

    def test_send_before_any_connection_no_crash(self):
        """Calling send_message_sync before any connect() should not crash."""
        manager = _make_manager()
        # _owner_loop is None
        manager.send_message_sync("nonexistent", {"type": "orphan"})
        # Should log a warning, not crash
class TestOwnerLoopCapture:
    """connect() must remember the event loop it first ran on."""

    def test_owner_loop_captured_on_connect(self):
        """The first connect() stores the currently running loop as the owner."""
        manager = _make_manager()
        socket = FakeWebSocket()

        async def scenario():
            assert manager._owner_loop is None
            await manager.connect(socket, session_id="cap1")
            current_loop = asyncio.get_running_loop()
            assert manager._owner_loop is current_loop

        asyncio.run(scenario())

    def test_owner_loop_stable_across_connections(self):
        """Subsequent connects should not reset the owner loop."""
        manager = _make_manager()
        first_socket, second_socket = FakeWebSocket(), FakeWebSocket()

        async def scenario():
            await manager.connect(first_socket, session_id="cap2")
            captured_loop = manager._owner_loop
            await manager.connect(second_socket, session_id="cap3")
            assert manager._owner_loop is captured_loop

        asyncio.run(scenario())