ChatDev/server/services/prompt_channel.py
2026-01-07 16:24:01 +08:00

123 lines
4.3 KiB
Python
Executable File

"""PromptChannel implementation backed by WebSocket sessions."""
import asyncio
from typing import Any, Dict, List, Optional
from entity.messages import MessageBlock
from server.services.attachment_service import AttachmentService
from server.services.session_execution import SessionExecutionController
from utils.attachments import AttachmentStore
from utils.exceptions import TimeoutError
from utils.human_prompt import PromptChannel, PromptResult
from utils.structured_logger import get_server_logger
class WebPromptChannel(PromptChannel):
"""Prompt channel that mediates through the WebSocket session controller."""
def __init__(
self,
*,
session_id: str,
session_controller: SessionExecutionController,
websocket_manager: Any,
attachment_service: AttachmentService,
attachment_store: AttachmentStore,
) -> None:
self.session_id = session_id
self.session_controller = session_controller
self.websocket_manager = websocket_manager
self.attachment_service = attachment_service
self.attachment_store = attachment_store
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = None
def request(
self,
*,
node_id: str,
task: str,
inputs: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> PromptResult:
preview = inputs or ""
payload = {
"input": preview,
"task_description": task,
**(metadata or {}),
}
self.session_controller.set_waiting_for_input(
self.session_id,
node_id,
payload,
)
self._notify_human_prompt(node_id, preview, task)
try:
human_response = self.session_controller.wait_for_human_input(self.session_id)
except TimeoutError:
raise
except Exception as exc: # pragma: no cover - propagated upstream
logger = get_server_logger()
logger.log_exception(exc, "Error waiting for human input", node_id=node_id, session_id=self.session_id)
raise
response_text, attachment_ids = self._extract_response(human_response)
blocks = self._build_blocks(response_text, attachment_ids)
metadata_out = {
"attachment_count": len(attachment_ids),
"input_size": len(preview),
}
return PromptResult(text=response_text, blocks=blocks, metadata=metadata_out)
def _extract_response(self, payload: Any) -> tuple[str, List[str]]:
if isinstance(payload, dict):
response_text = payload.get("text") or ""
attachments = payload.get("attachments") or []
return response_text, attachments
if payload is None:
return "", []
return str(payload), []
def _build_blocks(self, text: str, attachment_ids: List[str]) -> List[MessageBlock]:
blocks: List[MessageBlock] = []
if text:
blocks.append(MessageBlock.text_block(text))
if attachment_ids:
blocks.extend(
self.attachment_service.build_attachment_blocks(
self.session_id,
attachment_ids,
target_store=self.attachment_store,
)
)
if not blocks:
blocks.append(MessageBlock.text_block(""))
return blocks
def _notify_human_prompt(self, node_id: str, preview: str, task: str) -> None:
message = {
"type": "human_input_required",
"data": {
"node_id": node_id,
"input": preview,
"task_description": task,
},
}
if self._loop and self._loop.is_running():
future = asyncio.run_coroutine_threadsafe(
self.websocket_manager.send_message(self.session_id, message),
self._loop,
)
try:
future.result()
except Exception:
# fallback to sync send to surface errors/logging
self.websocket_manager.send_message_sync(self.session_id, message)
else:
self.websocket_manager.send_message_sync(self.session_id, message)