ChatDev/server/services/websocket_manager.py
2026-01-07 16:24:01 +08:00

148 lines
5.6 KiB
Python
Executable File

"""WebSocket connection manager used by FastAPI app."""
import asyncio
import json
import logging
import time
import traceback
import uuid
from typing import Any, Dict, Optional
from fastapi import WebSocket
from server.services.message_handler import MessageHandler
from server.services.attachment_service import AttachmentService
from server.services.session_execution import SessionExecutionController
from server.services.session_store import WorkflowSessionStore, SessionStatus
from server.services.workflow_run_service import WorkflowRunService
def _json_default(value):
to_dict = getattr(value, "to_dict", None)
if callable(to_dict):
try:
return to_dict()
except Exception:
pass
if hasattr(value, "__dict__"):
try:
return vars(value)
except Exception:
pass
return str(value)
def _encode_ws_message(message: Any) -> str:
if isinstance(message, str):
return message
return json.dumps(message, default=_json_default)
class WebSocketManager:
    """Registry of live WebSocket connections, keyed by session id.

    Besides tracking sockets, the manager builds (or receives) the session
    store, execution controller, attachment service, workflow run service and
    message handler, and forwards incoming messages to that handler.
    """

    def __init__(
        self,
        *,
        session_store: WorkflowSessionStore | None = None,
        session_controller: SessionExecutionController | None = None,
        attachment_service: AttachmentService | None = None,
        workflow_run_service: WorkflowRunService | None = None,
    ):
        """Create the manager, building every collaborator that is not supplied.

        Args:
            session_store: Store to use; a new ``WorkflowSessionStore`` is
                created when omitted.
            session_controller: Controller to use; built from the session
                store when omitted.
            attachment_service: Service to use; a new ``AttachmentService``
                is created when omitted.
            workflow_run_service: Service to use; built from the three
                collaborators above when omitted.
        """
        # ``is None`` rather than ``or``: an injected collaborator must never
        # be discarded just because it happens to evaluate as falsy.
        if session_store is None:
            session_store = WorkflowSessionStore()
        if session_controller is None:
            session_controller = SessionExecutionController(session_store)
        if attachment_service is None:
            attachment_service = AttachmentService()
        if workflow_run_service is None:
            workflow_run_service = WorkflowRunService(
                session_store,
                session_controller,
                attachment_service,
            )

        self.active_connections: Dict[str, WebSocket] = {}
        self.connection_timestamps: Dict[str, float] = {}
        self.session_store = session_store
        self.session_controller = session_controller
        self.attachment_service = attachment_service
        self.workflow_run_service = workflow_run_service
        self.message_handler = MessageHandler(
            self.session_store,
            self.session_controller,
            self.workflow_run_service,
        )
        # Event loop on which sockets were accepted (recorded by ``connect``).
        self._owner_loop: Optional[asyncio.AbstractEventLoop] = None
        # Strong references to fire-and-forget send tasks; the event loop
        # itself only keeps weak references to tasks.
        self._background_tasks: set = set()

    async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) -> str:
        """Accept *websocket*, register it and acknowledge the connection.

        Args:
            websocket: The socket to accept.
            session_id: Id to register the socket under. A random UUID4
                string is generated when it is ``None`` or empty. A socket
                already registered under the same id is replaced.

        Returns:
            The session id the socket was registered under.
        """
        await websocket.accept()
        # Remember which loop owns the sockets so ``send_message_sync`` can
        # hand work back to it from other threads.
        # NOTE(review): assumes all connections are accepted on one loop.
        self._owner_loop = asyncio.get_running_loop()
        if not session_id:
            session_id = str(uuid.uuid4())
        self.active_connections[session_id] = websocket
        self.connection_timestamps[session_id] = time.time()
        logging.info("WebSocket connected: %s", session_id)
        await self.send_message(
            session_id,
            {
                "type": "connection",
                "data": {"session_id": session_id, "status": "connected"},
            },
        )
        return session_id

    def disconnect(self, session_id: str) -> None:
        """Forget the connection for *session_id* and clean up its session.

        If the stored session is running or waiting for input, a cancel is
        requested first. The session is removed from the store only when,
        after the controller cleanup, it has no executor attached.
        """
        session = self.session_store.get_session(session_id)
        if session and session.status in {SessionStatus.RUNNING, SessionStatus.WAITING_FOR_INPUT}:
            self.workflow_run_service.request_cancel(
                session_id,
                reason="WebSocket disconnected",
            )
        self.active_connections.pop(session_id, None)
        self.connection_timestamps.pop(session_id, None)
        self.session_controller.cleanup_session(session_id)
        # Re-read the session: the steps above may have changed or removed it.
        remaining_session = self.session_store.get_session(session_id)
        if remaining_session and remaining_session.executor is None:
            self.session_store.pop_session(session_id)
        self.attachment_service.cleanup_session(session_id)
        logging.info("WebSocket disconnected: %s", session_id)

    async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
        """Send *message* to the socket registered for *session_id*.

        Does nothing when no socket is registered. Send failures are logged
        with their traceback and swallowed; they never propagate to callers.
        """
        websocket = self.active_connections.get(session_id)
        if websocket is None:
            return
        try:
            await websocket.send_text(_encode_ws_message(message))
        except Exception as exc:
            logging.exception("Failed to send message to %s: %s", session_id, exc)
            # NOTE(review): the socket stays registered after a failed send;
            # confirm whether it should be evicted via ``disconnect`` here.

    def send_message_sync(self, session_id: str, message: Dict[str, Any]) -> None:
        """Send *message* from synchronous code.

        * Called on a thread that is running an event loop: the send is
          scheduled as a task and this method returns immediately.
        * Called on a thread without a loop (e.g. a worker thread): the send
          is marshalled onto the loop that accepted the socket and this
          method blocks until it has completed. asyncio objects are not
          thread-safe, so the socket must not be driven from a second loop.
        * No owning loop is running: fall back to a temporary loop via
          ``asyncio.run``.
        """
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        owner_loop = self._owner_loop
        owner_usable = owner_loop is not None and owner_loop.is_running()

        if running_loop is not None and (running_loop is owner_loop or not owner_usable):
            task = running_loop.create_task(self.send_message(session_id, message))
            # Hold a reference until completion so the task cannot be
            # garbage-collected while it is still pending.
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
            return

        if owner_usable:
            future = asyncio.run_coroutine_threadsafe(
                self.send_message(session_id, message),
                owner_loop,
            )
            if running_loop is None:
                # Plain worker thread: keep the blocking, in-order behaviour
                # callers of a ``*_sync`` method expect.
                future.result()
            return

        asyncio.run(self.send_message(session_id, message))

    async def broadcast(self, message: Dict[str, Any]) -> None:
        """Send *message* to every currently registered connection, in turn."""
        # Iterate over a snapshot: connections may be added or removed while
        # a send is awaited.
        for session_id in list(self.active_connections):
            await self.send_message(session_id, message)

    async def handle_heartbeat(self, session_id: str) -> None:
        """Reply to a heartbeat with a ``pong`` carrying the current time.

        A heartbeat for an unregistered session is only logged as a warning.
        """
        if session_id in self.active_connections:
            await self.send_message(
                session_id,
                {"type": "pong", "data": {"timestamp": time.time()}},
            )
        else:
            logging.warning("Heartbeat request from disconnected session: %s", session_id)

    async def handle_message(self, session_id: str, message: str) -> None:
        """Decode a raw text frame and dispatch it to the message handler.

        Malformed JSON is answered with an ``"Invalid JSON format"`` error
        frame. Any other failure is logged with its traceback and answered
        with an error frame carrying ``str(exc)``.
        """
        try:
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                # Only the frame itself is decoded under this guard, so a
                # JSONDecodeError raised later inside the handler is not
                # misreported as a malformed client frame.
                await self.send_message(
                    session_id,
                    {"type": "error", "data": {"message": "Invalid JSON format"}},
                )
                return
            await self.message_handler.handle_message(session_id, data, self)
        except Exception as exc:
            logging.exception("Error handling message from %s: %s", session_id, exc)
            # NOTE(review): the raw exception text is sent to the client;
            # confirm that this cannot expose internal details.
            await self.send_message(
                session_id,
                {"type": "error", "data": {"message": str(exc)}},
            )