ChatDev/server/services/batch_run_service.py
2026-01-10 22:35:32 +08:00

256 lines
8.9 KiB
Python

"""Batch workflow execution helpers."""
import asyncio
import csv
import json
import logging
import re
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional
from check.check import load_config
from entity.enums import LogLevel
from entity.graph_config import GraphConfig
from utils.exceptions import ValidationError
from utils.task_input import TaskInputBuilder
from workflow.graph import GraphExecutor
from workflow.graph_context import GraphContext
from server.services.batch_parser import BatchTask
from server.services.workflow_storage import validate_workflow_filename
from server.settings import WARE_HOUSE_DIR, YAML_DIR
class BatchRunService:
    """Runs batch workflows and reports progress over WebSocket."""

    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)

    async def run_batch(
        self,
        session_id: str,
        yaml_file: str,
        tasks: List[BatchTask],
        websocket_manager,
        *,
        max_parallel: int = 5,
        file_base: str = "batch",
        log_level: Optional[LogLevel] = None,
    ) -> None:
        """Run every task in *tasks* and stream progress over *websocket_manager*.

        Emits ``batch_started``, then one ``batch_task_started`` followed by
        either ``batch_task_completed`` or ``batch_task_failed`` per task, and
        finally a ``batch_completed`` summary. At most *max_parallel* tasks run
        concurrently. Aggregated results are persisted to the session's
        warehouse directory via :meth:`_write_batch_outputs`.

        Args:
            session_id: WebSocket session receiving progress events; also used
                as the batch id and as the output-directory key.
            yaml_file: Workflow YAML filename, resolved under ``YAML_DIR``.
            tasks: Parsed batch rows to execute.
            websocket_manager: Object exposing ``send_message(session_id, payload)``.
            max_parallel: Upper bound on concurrently executing tasks.
            file_base: Prefix for each task's sanitized output directory name.
            log_level: Optional log-level override applied to every task.
        """
        batch_id = session_id  # one batch per session, so the session id doubles as batch id
        total = len(tasks)
        await websocket_manager.send_message(
            session_id,
            {"type": "batch_started", "data": {"batch_id": batch_id, "total": total}},
        )
        semaphore = asyncio.Semaphore(max_parallel)
        success_count = 0
        failure_count = 0
        result_rows: List[Dict[str, Any]] = []
        # Tasks finish concurrently; serialize appends to result_rows.
        result_lock = asyncio.Lock()

        async def run_task(task: BatchTask) -> None:
            nonlocal success_count, failure_count
            task_id = task.task_id or str(uuid.uuid4())
            task_dir = self._sanitize_label(f"{file_base}-{task_id}")
            await websocket_manager.send_message(
                session_id,
                {
                    "type": "batch_task_started",
                    "data": {
                        "row_index": task.row_index,
                        "task_id": task_id,
                        "task_dir": task_dir,
                    },
                },
            )
            try:
                # Graph execution is blocking; run it in a worker thread so the
                # event loop stays free for WebSocket traffic and other tasks.
                result = await asyncio.to_thread(
                    self._run_single_task,
                    session_id,
                    yaml_file,
                    task,
                    task_dir,
                    log_level,
                )
            except Exception as exc:
                # Broad catch is deliberate: one failing row must not abort the
                # whole batch. Log the traceback for server-side diagnosis.
                failure_count += 1
                self.logger.exception(
                    "Batch task %s (row %s) failed", task_id, task.row_index
                )
                async with result_lock:
                    result_rows.append(
                        {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "status": "failed",
                            "duration_ms": None,
                            "token_usage": None,
                            "graph_output": "",
                            "results": None,
                            "error": str(exc),
                        }
                    )
                await websocket_manager.send_message(
                    session_id,
                    {
                        "type": "batch_task_failed",
                        "data": {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "error": str(exc),
                        },
                    },
                )
            else:
                # Success bookkeeping lives in the else branch so a failure
                # while *reporting* success cannot be double-counted as a
                # task failure (it also appended two result rows before).
                success_count += 1
                async with result_lock:
                    result_rows.append(
                        {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "status": "success",
                            "duration_ms": result["duration_ms"],
                            "token_usage": result["token_usage"],
                            "graph_output": result["graph_output"],
                            "results": result["results"],
                            "error": "",
                        }
                    )
                await websocket_manager.send_message(
                    session_id,
                    {
                        "type": "batch_task_completed",
                        "data": {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "results": result["results"],
                            "token_usage": result["token_usage"],
                            "duration_ms": result["duration_ms"],
                        },
                    },
                )

        async def run_with_limit(task: BatchTask) -> None:
            async with semaphore:
                await run_task(task)

        await asyncio.gather(*(run_with_limit(task) for task in tasks))
        self._write_batch_outputs(session_id, result_rows)
        await websocket_manager.send_message(
            session_id,
            {
                "type": "batch_completed",
                "data": {
                    "batch_id": batch_id,
                    "total": total,
                    "succeeded": success_count,
                    "failed": failure_count,
                },
            },
        )

    def _write_batch_outputs(self, session_id: str, result_rows: List[Dict[str, Any]]) -> None:
        """Persist *result_rows* as ``batch_results.csv`` and
        ``batch_manifest.json`` under the session's warehouse directory."""
        output_root = WARE_HOUSE_DIR / f"session_{session_id}"
        output_root.mkdir(parents=True, exist_ok=True)
        csv_path = output_root / "batch_results.csv"
        json_path = output_root / "batch_manifest.json"
        fieldnames = [
            "row_index",
            "task_id",
            "task_dir",
            "status",
            "duration_ms",
            "token_usage",
            "results",
            "error",
        ]
        with csv_path.open("w", newline="", encoding="utf-8") as handle:
            writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore")
            writer.writeheader()
            for row in result_rows:
                row_copy = dict(row)  # shallow copy: keep manifest rows untouched
                # CSV cells must be scalar: JSON-encode the token-usage dict and
                # put the flattened final graph output in the "results" column.
                row_copy["token_usage"] = json.dumps(row_copy.get("token_usage"))
                row_copy["results"] = row_copy.get("graph_output", "")
                writer.writerow(row_copy)
        with json_path.open("w", encoding="utf-8") as handle:
            json.dump(result_rows, handle, ensure_ascii=True, indent=2)

    def _run_single_task(
        self,
        session_id: str,
        yaml_file: str,
        task: BatchTask,
        task_dir: str,
        log_level: Optional[LogLevel],
    ) -> Dict[str, Any]:
        """Synchronously execute one workflow row (runs inside a worker thread).

        Returns:
            Dict with ``results``, ``token_usage``, ``duration_ms`` and
            ``graph_output``.

        Raises:
            ValidationError: if the workflow contains human nodes, which cannot
                be driven unattended in batch mode, or the YAML cannot be found.
        """
        yaml_path = self._resolve_yaml_path(yaml_file)
        design = load_config(yaml_path, vars_override=task.vars_override or None)
        if any(node.type == "human" for node in design.graph.nodes):
            raise ValidationError(
                "Batch execution does not support human nodes",
                details={"yaml_file": yaml_file},
            )
        output_root = WARE_HOUSE_DIR / f"session_{session_id}"
        graph_config = GraphConfig.from_definition(
            design.graph,
            name=task_dir,
            output_root=output_root,
            source_path=str(yaml_path),
            vars=design.vars,
        )
        # Pin the per-task output directory name derived from the task id.
        graph_config.metadata["fixed_output_dir"] = True
        # Explicit None check: a falsy-valued LogLevel member would otherwise
        # be silently dropped by a bare truthiness test.
        if log_level is not None:
            graph_config.log_level = log_level
            graph_config.definition.log_level = log_level
        graph_context = GraphContext(config=graph_config)
        start_time = time.perf_counter()
        executor = GraphExecutor(graph_context, session_id=session_id)
        task_input = self._build_task_input(executor.attachment_store, task)
        # NOTE(review): relies on GraphExecutor's private _execute; prefer a
        # public synchronous entry point if the executor grows one.
        executor._execute(task_input)
        duration_ms = int((time.perf_counter() - start_time) * 1000)
        return {
            "results": executor.outputs,
            "token_usage": executor.token_tracker.get_token_usage(),
            "duration_ms": duration_ms,
            "graph_output": executor.get_final_output(),
        }

    @staticmethod
    def _build_task_input(attachment_store, task: BatchTask):
        """Return the executor input: the bare prompt, or a composite input
        bundling the prompt with the task's file attachments."""
        if task.attachment_paths:
            builder = TaskInputBuilder(attachment_store)
            return builder.build_from_file_paths(task.task_prompt, task.attachment_paths)
        return task.task_prompt

    @staticmethod
    def _sanitize_label(value: str) -> str:
        """Collapse runs of characters unsafe for directory names into a single
        underscore; fall back to "task" when nothing printable survives."""
        cleaned = re.sub(r"[^a-zA-Z0-9._-]+", "_", value)
        return cleaned.strip("_") or "task"

    @staticmethod
    def _resolve_yaml_path(yaml_filename: str) -> Path:
        """Validate *yaml_filename* and resolve it under ``YAML_DIR``.

        Raises:
            ValidationError: if the name fails validation (delegated to
                ``validate_workflow_filename``) or the file does not exist.
        """
        safe_name = validate_workflow_filename(yaml_filename, require_yaml_extension=True)
        yaml_path = YAML_DIR / safe_name
        if not yaml_path.exists():
            raise ValidationError("YAML file not found", details={"yaml_file": safe_name})
        return yaml_path