mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 11:18:06 +00:00
123 lines
4.3 KiB
Python
Executable File
123 lines
4.3 KiB
Python
Executable File
"""PromptChannel implementation backed by WebSocket sessions."""
|
|
|
|
import asyncio
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from entity.messages import MessageBlock
|
|
from server.services.attachment_service import AttachmentService
|
|
from server.services.session_execution import SessionExecutionController
|
|
from utils.attachments import AttachmentStore
|
|
from utils.exceptions import TimeoutError
|
|
from utils.human_prompt import PromptChannel, PromptResult
|
|
from utils.structured_logger import get_server_logger
|
|
|
|
|
|
class WebPromptChannel(PromptChannel):
|
|
"""Prompt channel that mediates through the WebSocket session controller."""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
session_id: str,
|
|
session_controller: SessionExecutionController,
|
|
websocket_manager: Any,
|
|
attachment_service: AttachmentService,
|
|
attachment_store: AttachmentStore,
|
|
) -> None:
|
|
self.session_id = session_id
|
|
self.session_controller = session_controller
|
|
self.websocket_manager = websocket_manager
|
|
self.attachment_service = attachment_service
|
|
self.attachment_store = attachment_store
|
|
try:
|
|
self._loop = asyncio.get_running_loop()
|
|
except RuntimeError:
|
|
self._loop = None
|
|
|
|
def request(
|
|
self,
|
|
*,
|
|
node_id: str,
|
|
task: str,
|
|
inputs: Optional[str] = None,
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
) -> PromptResult:
|
|
preview = inputs or ""
|
|
payload = {
|
|
"input": preview,
|
|
"task_description": task,
|
|
**(metadata or {}),
|
|
}
|
|
|
|
self.session_controller.set_waiting_for_input(
|
|
self.session_id,
|
|
node_id,
|
|
payload,
|
|
)
|
|
|
|
self._notify_human_prompt(node_id, preview, task)
|
|
|
|
try:
|
|
human_response = self.session_controller.wait_for_human_input(self.session_id)
|
|
except TimeoutError:
|
|
raise
|
|
except Exception as exc: # pragma: no cover - propagated upstream
|
|
logger = get_server_logger()
|
|
logger.log_exception(exc, "Error waiting for human input", node_id=node_id, session_id=self.session_id)
|
|
raise
|
|
|
|
response_text, attachment_ids = self._extract_response(human_response)
|
|
blocks = self._build_blocks(response_text, attachment_ids)
|
|
metadata_out = {
|
|
"attachment_count": len(attachment_ids),
|
|
"input_size": len(preview),
|
|
}
|
|
return PromptResult(text=response_text, blocks=blocks, metadata=metadata_out)
|
|
|
|
def _extract_response(self, payload: Any) -> tuple[str, List[str]]:
|
|
if isinstance(payload, dict):
|
|
response_text = payload.get("text") or ""
|
|
attachments = payload.get("attachments") or []
|
|
return response_text, attachments
|
|
if payload is None:
|
|
return "", []
|
|
return str(payload), []
|
|
|
|
def _build_blocks(self, text: str, attachment_ids: List[str]) -> List[MessageBlock]:
|
|
blocks: List[MessageBlock] = []
|
|
if text:
|
|
blocks.append(MessageBlock.text_block(text))
|
|
if attachment_ids:
|
|
blocks.extend(
|
|
self.attachment_service.build_attachment_blocks(
|
|
self.session_id,
|
|
attachment_ids,
|
|
target_store=self.attachment_store,
|
|
)
|
|
)
|
|
if not blocks:
|
|
blocks.append(MessageBlock.text_block(""))
|
|
return blocks
|
|
|
|
def _notify_human_prompt(self, node_id: str, preview: str, task: str) -> None:
|
|
message = {
|
|
"type": "human_input_required",
|
|
"data": {
|
|
"node_id": node_id,
|
|
"input": preview,
|
|
"task_description": task,
|
|
},
|
|
}
|
|
if self._loop and self._loop.is_running():
|
|
future = asyncio.run_coroutine_threadsafe(
|
|
self.websocket_manager.send_message(self.session_id, message),
|
|
self._loop,
|
|
)
|
|
try:
|
|
future.result()
|
|
except Exception:
|
|
# fallback to sync send to surface errors/logging
|
|
self.websocket_manager.send_message_sync(self.session_id, message)
|
|
else:
|
|
self.websocket_manager.send_message_sync(self.session_id, message)
|