"""Batch workflow execution helpers."""

import asyncio
import csv
import json
import logging
import re
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

from check.check import load_config
from entity.enums import LogLevel
from entity.graph_config import GraphConfig
from utils.exceptions import ValidationError
from utils.task_input import TaskInputBuilder
from workflow.graph import GraphExecutor
from workflow.graph_context import GraphContext

from server.services.batch_parser import BatchTask
from server.services.workflow_storage import validate_workflow_filename
from server.settings import WARE_HOUSE_DIR, YAML_DIR


class BatchRunService:
    """Runs batch workflows and reports progress over WebSocket."""

    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)

    async def run_batch(
        self,
        session_id: str,
        yaml_file: str,
        tasks: List[BatchTask],
        websocket_manager,
        *,
        max_parallel: int = 5,
        file_base: str = "batch",
        log_level: Optional[LogLevel] = None,
    ) -> None:
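        """Run every task in ``tasks`` against ``yaml_file`` and stream progress.

        Progress is pushed to the client through ``websocket_manager`` as
        ``batch_started``, ``batch_task_started``, ``batch_task_completed``,
        ``batch_task_failed``, and ``batch_completed`` messages; at most
        ``max_parallel`` tasks execute concurrently.
        """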
        batch_id = session_id
        total = len(tasks)

        await websocket_manager.send_message(
            session_id,
            {"type": "batch_started", "data": {"batch_id": batch_id, "total": total}},
        )

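        # Bound concurrency with a semaphore; the lock guards the shared
        # result list while per-task coroutines append to it.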
        semaphore = asyncio.Semaphore(max_parallel)
        success_count = 0
        failure_count = 0
        result_rows: List[Dict[str, Any]] = []
        result_lock = asyncio.Lock()

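        # Per-task worker: runs one workflow, records its result row, and
        # reports success or failure back over the WebSocket.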
        async def run_task(task: BatchTask) -> None:
            nonlocal success_count, failure_count
            task_id = task.task_id or str(uuid.uuid4())
            task_dir = self._sanitize_label(f"{file_base}-{task_id}")

            await websocket_manager.send_message(
                session_id,
                {
                    "type": "batch_task_started",
                    "data": {
                        "row_index": task.row_index,
                        "task_id": task_id,
                        "task_dir": task_dir,
                    },
                },
            )

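            # The graph executor is synchronous, so it is dispatched to a
            # worker thread; the event loop stays free to push updates.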
            try:
                result = await asyncio.to_thread(
                    self._run_single_task,
                    session_id,
                    yaml_file,
                    task,
                    task_dir,
                    log_level,
                )
                success_count += 1
                async with result_lock:
                    result_rows.append(
                        {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "status": "success",
                            "duration_ms": result["duration_ms"],
                            "token_usage": result["token_usage"],
                            "graph_output": result["graph_output"],
                            "results": result["results"],
                            "error": "",
                        }
                    )
                await websocket_manager.send_message(
                    session_id,
                    {
                        "type": "batch_task_completed",
                        "data": {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "results": result["results"],
                            "token_usage": result["token_usage"],
                            "duration_ms": result["duration_ms"],
                        },
                    },
                )
            except Exception as exc:
                failure_count += 1
                async with result_lock:
                    result_rows.append(
                        {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "status": "failed",
                            "duration_ms": None,
                            "token_usage": None,
                            "graph_output": "",
                            "results": None,
                            "error": str(exc),
                        }
                    )
                await websocket_manager.send_message(
                    session_id,
                    {
                        "type": "batch_task_failed",
                        "data": {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "error": str(exc),
                        },
                    },
                )

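        # Acquire the semaphore so no more than max_parallel tasks run at once.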
        async def run_with_limit(task: BatchTask) -> None:
            async with semaphore:
                await run_task(task)

        await asyncio.gather(*(run_with_limit(task) for task in tasks))

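        # Persist per-task results (CSV plus JSON manifest) to the session's
        # warehouse directory before reporting overall completion.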
        self._write_batch_outputs(session_id, result_rows)

        await websocket_manager.send_message(
            session_id,
            {
                "type": "batch_completed",
                "data": {
                    "batch_id": batch_id,
                    "total": total,
                    "succeeded": success_count,
                    "failed": failure_count,
                },
            },
        )

    def _write_batch_outputs(self, session_id: str, result_rows: List[Dict[str, Any]]) -> None:
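        """Write batch_results.csv and batch_manifest.json for this session."""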
        output_root = WARE_HOUSE_DIR / f"session_{session_id}"
        output_root.mkdir(parents=True, exist_ok=True)

        csv_path = output_root / "batch_results.csv"
        json_path = output_root / "batch_manifest.json"

        fieldnames = [
            "row_index",
            "task_id",
            "task_dir",
            "status",
            "duration_ms",
            "token_usage",
            "results",
            "error",
        ]

        with csv_path.open("w", newline="", encoding="utf-8") as handle:
            writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore")
            writer.writeheader()
            for row in result_rows:
                row_copy = dict(row)
                row_copy["token_usage"] = json.dumps(row_copy.get("token_usage"))
                row_copy["results"] = row_copy.get("graph_output", "")
                writer.writerow(row_copy)

        with json_path.open("w", encoding="utf-8") as handle:
            json.dump(result_rows, handle, ensure_ascii=True, indent=2)

    def _run_single_task(
        self,
        session_id: str,
        yaml_file: str,
        task: BatchTask,
        task_dir: str,
        log_level: Optional[LogLevel],
    ) -> Dict[str, Any]:
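        """Run one workflow synchronously and return its outputs and metrics.

        This is a blocking call; ``run_batch`` dispatches it via
        ``asyncio.to_thread``. Workflows containing human nodes are rejected,
        since a batch run has no way to collect interactive input.
        """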
        yaml_path = self._resolve_yaml_path(yaml_file)
        design = load_config(yaml_path, vars_override=task.vars_override or None)
        if any(node.type == "human" for node in design.graph.nodes):
            raise ValidationError(
                "Batch execution does not support human nodes",
                details={"yaml_file": yaml_file},
            )

        output_root = WARE_HOUSE_DIR / f"session_{session_id}"
        graph_config = GraphConfig.from_definition(
            design.graph,
            name=task_dir,
            output_root=output_root,
            source_path=str(yaml_path),
            vars=design.vars,
        )
        graph_config.metadata["fixed_output_dir"] = True

        if log_level:
            graph_config.log_level = log_level
            graph_config.definition.log_level = log_level

        graph_context = GraphContext(config=graph_config)

        start_time = time.perf_counter()
        executor = GraphExecutor(graph_context, session_id=session_id)
        task_input = self._build_task_input(executor.attachment_store, task)
        executor._execute(task_input)
        duration_ms = int((time.perf_counter() - start_time) * 1000)

        return {
            "results": executor.outputs,
            "token_usage": executor.token_tracker.get_token_usage(),
            "duration_ms": duration_ms,
            "graph_output": executor.get_final_output(),
        }

    @staticmethod
    def _build_task_input(attachment_store, task: BatchTask):
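        """Build the executor input, attaching any per-task files."""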
        if task.attachment_paths:
            builder = TaskInputBuilder(attachment_store)
            return builder.build_from_file_paths(task.task_prompt, task.attachment_paths)
        return task.task_prompt

    @staticmethod
    def _sanitize_label(value: str) -> str:
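        """Reduce ``value`` to a filesystem-safe directory label."""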
        cleaned = re.sub(r"[^a-zA-Z0-9._-]+", "_", value)
        return cleaned.strip("_") or "task"

    @staticmethod
    def _resolve_yaml_path(yaml_filename: str) -> Path:
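        """Validate ``yaml_filename`` and resolve it under ``YAML_DIR``."""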
        safe_name = validate_workflow_filename(yaml_filename, require_yaml_extension=True)
        yaml_path = YAML_DIR / safe_name
        if not yaml_path.exists():
            raise ValidationError("YAML file not found", details={"yaml_file": safe_name})
        return yaml_path