ChatDev/server/services/websocket_manager.py
voidborne-d 511d05e545 fix: use run_coroutine_threadsafe in send_message_sync to prevent cross-loop crash
WebSocketManager.send_message_sync is called from background worker threads
(via asyncio.get_event_loop().run_in_executor) during workflow execution — by
WebSocketLogger, ArtifactDispatcher, and WebPromptChannel.

Previous implementation:
  try:
      loop = asyncio.get_running_loop()
      if loop.is_running():
          asyncio.create_task(...)          # path only reachable from main thread
      else:
          asyncio.run(...)                  # creates a NEW event loop
  except RuntimeError:
      asyncio.run(...)                      # also creates a new event loop

The problem: WebSocket objects are bound to the *main* uvicorn event loop.
asyncio.run() spins up a separate event loop and calls websocket.send_text()
there, which in Python 3.12 raises:

  RuntimeError: Task got Future attached to a different loop

...causing all log/artifact/prompt messages emitted from workflow threads to be
silently dropped or to crash the worker thread.

Fix:
- Store the event loop that created the first WebSocket connection as
  self._owner_loop (captured in connect(), which always runs on the main loop).
- send_message_sync schedules the coroutine on that loop via
  asyncio.run_coroutine_threadsafe(), then waits with a 10 s timeout.
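- Calling from the main thread still schedules the send: run_coroutine_threadsafe
  may be called from any thread, including the loop thread itself. However,
  blocking on the returned future from the loop thread stalls the loop until the
  timeout expires, because the scheduled coroutine cannot run while that thread
  is waiting.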

Added 7 tests covering:
- send from main thread
- send from worker thread (verifies send_text runs on the owner loop thread)
- 8 concurrent workers with no lost messages
- send after disconnect does not crash
- send before connect (no owner loop) does not crash
- owner loop captured on first connect
- owner loop stable across multiple connects
2026-04-02 16:01:46 +00:00

177 lines
6.9 KiB
Python
Executable File

"""WebSocket connection manager used by FastAPI app."""
import asyncio
import concurrent.futures
import json
import logging
import time
import traceback
import uuid
from typing import Any, Dict, Optional

from fastapi import WebSocket

from server.services.message_handler import MessageHandler
from server.services.attachment_service import AttachmentService
from server.services.session_execution import SessionExecutionController
from server.services.session_store import WorkflowSessionStore, SessionStatus
from server.services.workflow_run_service import WorkflowRunService
def _json_default(value):
to_dict = getattr(value, "to_dict", None)
if callable(to_dict):
try:
return to_dict()
except Exception:
pass
if hasattr(value, "__dict__"):
try:
return vars(value)
except Exception:
pass
return str(value)
def _encode_ws_message(message: Any) -> str:
if isinstance(message, str):
return message
return json.dumps(message, default=_json_default)
class WebSocketManager:
    """Manage WebSocket connections and message routing for workflow sessions.

    Accepted sockets are kept in ``active_connections`` keyed by session id.
    Incoming text is parsed as JSON and delegated to ``MessageHandler``.
    Outgoing messages are sent with :meth:`send_message` from coroutines on
    the owner event loop, or with :meth:`send_message_sync` from synchronous
    code on any thread.
    """

    def __init__(
        self,
        *,
        session_store: WorkflowSessionStore | None = None,
        session_controller: SessionExecutionController | None = None,
        attachment_service: AttachmentService | None = None,
        workflow_run_service: WorkflowRunService | None = None,
    ):
        """Wire up collaborators, creating defaults for any that are omitted.

        Args:
            session_store: Session store; a new ``WorkflowSessionStore`` when
                omitted.
            session_controller: Execution controller; by default built on
                ``session_store``.
            attachment_service: Attachment service; a new
                ``AttachmentService`` when omitted.
            workflow_run_service: Run service; by default built from the
                three collaborators above.
        """
        # session_id -> accepted WebSocket.
        self.active_connections: Dict[str, WebSocket] = {}
        # session_id -> time.time() recorded when the socket connected.
        self.connection_timestamps: Dict[str, float] = {}
        # Event loop that accepted the sockets; captured in connect().
        self._owner_loop: Optional[asyncio.AbstractEventLoop] = None
        # Strong references to send tasks that send_message_sync starts
        # without waiting on them; the event loop keeps only weak references
        # to tasks, so an unreferenced task could be collected mid-send.
        self._pending_sends: set[asyncio.Task] = set()
        self.session_store = session_store or WorkflowSessionStore()
        self.session_controller = session_controller or SessionExecutionController(self.session_store)
        self.attachment_service = attachment_service or AttachmentService()
        self.workflow_run_service = workflow_run_service or WorkflowRunService(
            self.session_store,
            self.session_controller,
            self.attachment_service,
        )
        self.message_handler = MessageHandler(
            self.session_store,
            self.session_controller,
            self.workflow_run_service,
        )

    async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) -> str:
        """Accept *websocket* and register it under *session_id*.

        A new UUID4 string is used when *session_id* is empty or ``None``.
        Once registered, a ``{"type": "connection", ...}`` message is sent to
        the client.

        Returns:
            The session id the connection was registered under.
        """
        await websocket.accept()
        # Remember the loop that owns the sockets so send_message_sync can
        # schedule sends onto it from other threads. Re-capture it if the
        # previously captured loop has since been closed; otherwise every
        # later sync send would be dropped as "owner event loop unavailable".
        if self._owner_loop is None or self._owner_loop.is_closed():
            self._owner_loop = asyncio.get_running_loop()
        if not session_id:
            session_id = str(uuid.uuid4())
        self.active_connections[session_id] = websocket
        self.connection_timestamps[session_id] = time.time()
        logging.info("WebSocket connected: %s", session_id)
        await self.send_message(
            session_id,
            {
                "type": "connection",
                "data": {"session_id": session_id, "status": "connected"},
            },
        )
        return session_id

    def disconnect(self, session_id: str) -> None:
        """Unregister *session_id* and release its per-session resources.

        If the session is still running or waiting for input, cancellation is
        requested first. After controller cleanup, the session is removed from
        the store only if it no longer has an executor attached. Attachment
        cleanup always runs.
        """
        session = self.session_store.get_session(session_id)
        if session and session.status in {SessionStatus.RUNNING, SessionStatus.WAITING_FOR_INPUT}:
            self.workflow_run_service.request_cancel(
                session_id,
                reason="WebSocket disconnected",
            )
        self.active_connections.pop(session_id, None)
        self.connection_timestamps.pop(session_id, None)
        self.session_controller.cleanup_session(session_id)
        remaining_session = self.session_store.get_session(session_id)
        if remaining_session and remaining_session.executor is None:
            self.session_store.pop_session(session_id)
        self.attachment_service.cleanup_session(session_id)
        logging.info("WebSocket disconnected: %s", session_id)

    async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
        """Send *message* to the socket registered for *session_id*.

        Unknown session ids are ignored. Non-string messages are JSON-encoded.
        A send failure is logged with its traceback and is not re-raised; the
        connection stays registered.
        """
        websocket = self.active_connections.get(session_id)
        if websocket is None:
            return
        try:
            await websocket.send_text(_encode_ws_message(message))
        except Exception as exc:
            logging.exception("Failed to send message to %s: %s", session_id, exc)

    def send_message_sync(
        self,
        session_id: str,
        message: Dict[str, Any],
        *,
        timeout: float = 10.0,
    ) -> None:
        """Send a WebSocket message from synchronous code on any thread.

        Sockets are bound to the loop that accepted them (captured in
        :meth:`connect`), so the send always runs on that loop:

        * From any other thread, the coroutine is scheduled with
          ``asyncio.run_coroutine_threadsafe`` and this call blocks for up to
          *timeout* seconds. On timeout, the pending send is cancelled and a
          warning is logged.
        * From the owner loop's own thread, blocking would stall the loop,
          because the scheduled coroutine cannot run while that thread waits.
          The send is therefore started as a task and this call returns
          immediately.

        Args:
            session_id: Target session.
            message: Payload; encoded the same way as in :meth:`send_message`.
            timeout: Seconds to wait for delivery when called off the owner
                loop's thread.

        Delivery problems are logged rather than raised.
        """
        loop = self._owner_loop
        if loop is None or loop.is_closed():
            logging.warning(
                "Cannot send sync message to %s: owner event loop unavailable",
                session_id,
            )
            return

        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:  # no loop running in this thread
            current_loop = None

        if current_loop is loop:
            # Already on the owner loop's thread: waiting on a future here
            # would block the loop until the timeout, so fire and forget.
            task = loop.create_task(self.send_message(session_id, message))
            self._pending_sends.add(task)
            task.add_done_callback(self._pending_sends.discard)
            return

        coro = self.send_message(session_id, message)
        try:
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError as exc:
            # The loop was closed after the check above. Close the coroutine
            # so it does not trigger a "never awaited" warning.
            coro.close()
            logging.warning("Cannot send sync message to %s: %s", session_id, exc)
            return

        try:
            future.result(timeout=timeout)
        except concurrent.futures.TimeoutError:
            # concurrent.futures.TimeoutError is an alias of the builtin
            # TimeoutError only from Python 3.11, so it is caught explicitly.
            # Cancel so the message is not delivered after we report failure.
            future.cancel()
            logging.warning(
                "Timed out sending sync WS message to %s", session_id
            )
        except Exception as exc:
            logging.error(
                "Error sending sync WS message to %s: %s", session_id, exc
            )

    async def broadcast(self, message: Dict[str, Any]) -> None:
        """Send *message* to every currently connected session.

        Iterates over a snapshot of the session ids, so connections added or
        removed while a send is awaited do not break the iteration.
        """
        for session_id in list(self.active_connections.keys()):
            await self.send_message(session_id, message)

    async def handle_heartbeat(self, session_id: str) -> None:
        """Reply to a heartbeat with a ``pong`` carrying the server timestamp.

        A heartbeat from a session without an active connection is logged and
        otherwise ignored.
        """
        if session_id in self.active_connections:
            await self.send_message(
                session_id,
                {"type": "pong", "data": {"timestamp": time.time()}},
            )
        else:
            logging.warning("Heartbeat request from disconnected session: %s", session_id)

    async def handle_message(self, session_id: str, message: str) -> None:
        """Parse an incoming text frame as JSON and dispatch it to the handler.

        Malformed JSON produces an ``"Invalid JSON format"`` error message to
        the client. Any other failure, whether during parsing or in the
        handler, is logged and its text is sent back as an ``error`` message.
        """
        # Parsing gets its own try so that a JSONDecodeError raised inside the
        # handler is not misreported to the client as invalid input.
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            await self.send_message(
                session_id,
                {"type": "error", "data": {"message": "Invalid JSON format"}},
            )
            return
        except Exception as exc:
            await self._report_handling_error(session_id, exc)
            return
        try:
            await self.message_handler.handle_message(session_id, data, self)
        except Exception as exc:
            await self._report_handling_error(session_id, exc)

    async def _report_handling_error(self, session_id: str, exc: Exception) -> None:
        """Log a message-handling failure and forward its text to the client."""
        logging.error("Error handling message from %s: %s", session_id, exc)
        await self.send_message(
            session_id,
            {"type": "error", "data": {"message": str(exc)}},
        )