mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
169 lines
6.5 KiB
Python
Executable File
169 lines
6.5 KiB
Python
Executable File
"""WebSocket connection manager used by the FastAPI app."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
import traceback
|
|
import uuid
|
|
from typing import Any, Dict, Optional
|
|
|
|
from fastapi import WebSocket
|
|
|
|
from server.services.message_handler import MessageHandler
|
|
from server.services.attachment_service import AttachmentService
|
|
from server.services.session_execution import SessionExecutionController
|
|
from server.services.session_store import WorkflowSessionStore, SessionStatus
|
|
from server.services.workflow_run_service import WorkflowRunService
|
|
|
|
|
|
def _json_default(value):
|
|
to_dict = getattr(value, "to_dict", None)
|
|
if callable(to_dict):
|
|
try:
|
|
return to_dict()
|
|
except Exception:
|
|
pass
|
|
if hasattr(value, "__dict__"):
|
|
try:
|
|
return vars(value)
|
|
except Exception:
|
|
pass
|
|
return str(value)
|
|
|
|
|
|
def _encode_ws_message(message: Any) -> str:
|
|
if isinstance(message, str):
|
|
return message
|
|
return json.dumps(message, default=_json_default)
|
|
|
|
|
|
class WebSocketManager:
    """Track live WebSocket connections and route messages to and from them.

    Connections are keyed by a session id. The manager owns (or is handed)
    the session store, session execution controller, attachment service and
    workflow run service, and delegates decoded client messages to a
    ``MessageHandler``.
    """

    def __init__(
        self,
        *,
        session_store: WorkflowSessionStore | None = None,
        session_controller: SessionExecutionController | None = None,
        attachment_service: AttachmentService | None = None,
        workflow_run_service: WorkflowRunService | None = None,
    ):
        """Create a manager, building a default for every collaborator not supplied."""
        self.active_connections: Dict[str, WebSocket] = {}
        # ``time.time()`` at which each session connected.
        self.connection_timestamps: Dict[str, float] = {}
        # Per-session lock serializing ``send_text`` calls on one socket.
        self.send_locks: Dict[str, asyncio.Lock] = {}
        # Loop that accepted the connections; send_message_sync schedules
        # sends onto it when called from elsewhere.
        self.loop: asyncio.AbstractEventLoop | None = None
        # Strong references to fire-and-forget send tasks. The event loop only
        # keeps weak references to tasks, so an unreferenced task can be
        # garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

        self.session_store = session_store or WorkflowSessionStore()
        self.session_controller = session_controller or SessionExecutionController(self.session_store)
        self.attachment_service = attachment_service or AttachmentService()
        self.workflow_run_service = workflow_run_service or WorkflowRunService(
            self.session_store,
            self.session_controller,
            self.attachment_service,
        )
        self.message_handler = MessageHandler(
            self.session_store,
            self.session_controller,
            self.workflow_run_service,
        )

    async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) -> str:
        """Accept *websocket*, register it, and send a ``connection`` greeting.

        Args:
            websocket: The incoming WebSocket to accept.
            session_id: Id to register the connection under; a new UUID4
                string is generated when it is empty or ``None``.

        Returns:
            The session id the connection was registered under.
        """
        await websocket.accept()

        # Record the serving loop for send_message_sync. Refresh it if the
        # previously recorded loop has since been closed (e.g. a new event
        # loop was started in the same process).
        if self.loop is None or self.loop.is_closed():
            self.loop = asyncio.get_running_loop()

        if not session_id:
            session_id = str(uuid.uuid4())
        if session_id in self.active_connections:
            logging.warning("Replacing existing WebSocket connection for session: %s", session_id)

        self.active_connections[session_id] = websocket
        self.connection_timestamps[session_id] = time.time()
        self.send_locks[session_id] = asyncio.Lock()
        logging.info("WebSocket connected: %s", session_id)

        await self.send_message(
            session_id,
            {
                "type": "connection",
                "data": {"session_id": session_id, "status": "connected"},
            },
        )
        return session_id

    def disconnect(self, session_id: str) -> None:
        """Unregister *session_id* and release the resources tied to it.

        If the session's workflow is RUNNING or WAITING_FOR_INPUT, cancellation
        is requested first. Afterwards the controller and attachment service
        clean up, and the session record is popped from the store when no
        executor remains attached to it.
        """
        session = self.session_store.get_session(session_id)
        if session and session.status in {SessionStatus.RUNNING, SessionStatus.WAITING_FOR_INPUT}:
            self.workflow_run_service.request_cancel(
                session_id,
                reason="WebSocket disconnected",
            )

        self.active_connections.pop(session_id, None)
        self.connection_timestamps.pop(session_id, None)
        self.send_locks.pop(session_id, None)

        self.session_controller.cleanup_session(session_id)
        remaining_session = self.session_store.get_session(session_id)
        if remaining_session and remaining_session.executor is None:
            self.session_store.pop_session(session_id)
        self.attachment_service.cleanup_session(session_id)
        logging.info("WebSocket disconnected: %s", session_id)

    async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
        """Send *message* to *session_id* as a text frame; no-op if not connected.

        Strings are sent verbatim and anything else is JSON-encoded. Any error
        while encoding or sending is logged with its traceback and swallowed,
        so a broken client connection never propagates into the caller.
        """
        websocket = self.active_connections.get(session_id)
        if websocket is None:
            return
        try:
            payload = _encode_ws_message(message)
            lock = self.send_locks.get(session_id)
            if lock is None:
                await websocket.send_text(payload)
            else:
                async with lock:
                    await websocket.send_text(payload)
        except Exception as exc:
            logging.exception("Failed to send message to %s: %s", session_id, exc)

    def send_message_sync(self, session_id: str, message: Dict[str, Any]) -> None:
        """Deliver *message* from synchronous code, possibly on another thread.

        The send is routed as follows:

        * If the connection-owning loop (``self.loop``) is running and is not
          the caller's running loop, the send is submitted to it with
          ``run_coroutine_threadsafe``.
        * Otherwise, if the caller is inside a running loop, the send is
          scheduled there as a task. A reference to the task is kept until
          it completes.
        * With no usable loop at all, the send runs to completion via
          ``asyncio.run``.

        In the first two cases the call returns without waiting for delivery.
        """
        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            current_loop = None

        owner_loop = self.loop
        if owner_loop is not None and owner_loop is not current_loop and owner_loop.is_running():
            # The sockets belong to owner_loop; never drive them from a foreign loop.
            asyncio.run_coroutine_threadsafe(self.send_message(session_id, message), owner_loop)
        elif current_loop is not None:
            task = current_loop.create_task(self.send_message(session_id, message))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
        else:
            asyncio.run(self.send_message(session_id, message))

    async def broadcast(self, message: Dict[str, Any]) -> None:
        """Send *message* to every currently connected session concurrently.

        The session ids are snapshotted first so that connects and disconnects
        during the broadcast do not disturb iteration. ``send_message``
        swallows its own errors, so one failing client does not stop delivery
        to the others. Because sends run concurrently, a slow client does not
        delay the rest.
        """
        session_ids = list(self.active_connections)
        await asyncio.gather(*(self.send_message(sid, message) for sid in session_ids))

    async def handle_heartbeat(self, session_id: str) -> None:
        """Reply to a client heartbeat with a ``pong`` carrying the current time."""
        if session_id in self.active_connections:
            await self.send_message(
                session_id,
                {"type": "pong", "data": {"timestamp": time.time()}},
            )
        else:
            logging.warning("Heartbeat request from disconnected session: %s", session_id)

    async def _send_error(self, session_id: str, text: str) -> None:
        """Send an ``error`` frame whose ``data.message`` is *text*."""
        await self.send_message(session_id, {"type": "error", "data": {"message": text}})

    async def handle_message(self, session_id: str, message: str) -> None:
        """Decode a raw client frame and dispatch it to the message handler.

        Malformed JSON is reported to the client as ``"Invalid JSON format"``.
        Any other exception, whether from decoding or from the handler, is
        logged with its traceback and reported to the client as ``str(exc)``.
        """
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            await self._send_error(session_id, "Invalid JSON format")
            return
        except Exception as exc:
            # e.g. TypeError when *message* is not str/bytes/bytearray.
            logging.exception("Error handling message from %s: %s", session_id, exc)
            await self._send_error(session_id, str(exc))
            return

        # Dispatch outside the decode ``try`` so that a JSONDecodeError raised
        # inside the handler is not misreported as a malformed client frame.
        try:
            await self.message_handler.handle_message(session_id, data, self)
        except Exception as exc:
            logging.exception("Error handling message from %s: %s", session_id, exc)
            await self._send_error(session_id, str(exc))
|