add: batch execution API

This commit is contained in:
Shu Yao 2026-01-10 22:35:32 +08:00
parent 5cf51c768f
commit 7bb72f56e6
7 changed files with 536 additions and 2 deletions

View File

@ -30,6 +30,7 @@ dependencies = [
"networkx",
"cartopy",
"pandas>=2.3.3",
"openpyxl>=3.1.2",
"numpy>=2.3.5",
"seaborn>=0.13.2",
"google-genai>=1.52.0",

View File

@ -1,6 +1,6 @@
"""Aggregates API routers."""
from . import artifacts, execute, health, sessions, uploads, vuegraphs, workflows, websocket
from . import artifacts, batch, execute, health, sessions, uploads, vuegraphs, workflows, websocket
ALL_ROUTERS = [
health.router,
@ -9,6 +9,7 @@ ALL_ROUTERS = [
uploads.router,
artifacts.router,
sessions.router,
batch.router,
execute.router,
websocket.router,
]

61
server/routes/batch.py Normal file
View File

@ -0,0 +1,61 @@
import asyncio
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from entity.enums import LogLevel
from server.services.batch_parser import parse_batch_file
from server.services.batch_run_service import BatchRunService
from server.state import ensure_known_session
from utils.exceptions import ValidationError
router = APIRouter()

# Strong references to in-flight batch tasks. asyncio.create_task() keeps only
# a weak reference to the task, so a fire-and-forget task with no other
# reference can be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


@router.post("/api/workflows/batch")
async def execute_batch(
    file: UploadFile = File(...),
    session_id: str = Form(...),
    yaml_file: str = Form(...),
    max_parallel: int = Form(5),
    log_level: str | None = Form(None),
):
    """Accept a CSV/Excel batch file and launch its tasks in the background.

    Validates the session, the uploaded file, and the options, then schedules
    BatchRunService.run_batch as a background task and returns immediately;
    progress is reported over the session's WebSocket.

    Raises:
        HTTPException: 400 on an unknown session, invalid batch file,
            non-positive max_parallel, or unrecognized log_level.
    """
    try:
        manager = ensure_known_session(session_id, require_connection=True)
    except ValidationError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    if max_parallel < 1:
        raise HTTPException(status_code=400, detail="max_parallel must be >= 1")
    try:
        content = await file.read()
        tasks, file_base = parse_batch_file(content, file.filename or "batch.csv")
    except ValidationError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    resolved_level = None
    if log_level:
        try:
            resolved_level = LogLevel(log_level)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail="log_level must be either DEBUG or INFO") from exc
    service = BatchRunService()
    task = asyncio.create_task(
        service.run_batch(
            session_id,
            yaml_file,
            tasks,
            manager,
            max_parallel=max_parallel,
            file_base=file_base,
            log_level=resolved_level,
        )
    )
    # Keep the task alive until it completes, then drop the reference.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return {
        "status": "accepted",
        "session_id": session_id,
        "batch_id": session_id,
        "task_count": len(tasks),
    }

View File

@ -0,0 +1,192 @@
"""Parse batch task files (CSV/Excel) into runnable tasks."""
import json
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import pandas as pd
from utils.exceptions import ValidationError
@dataclass(frozen=True)
class BatchTask:
    """One parsed row of a batch file, ready to be executed as a workflow run."""

    # 1-based position of the data row in the source file (header excluded).
    row_index: int
    # User-supplied identifier from the ID column; None when the cell was empty.
    task_id: Optional[str]
    # Free-text prompt; may be empty when attachment_paths is non-empty.
    task_prompt: str
    # File paths parsed from the Attachments column's JSON list.
    attachment_paths: List[str]
    # Per-task variable overrides parsed from the Vars column's JSON object.
    vars_override: Dict[str, Any]
def parse_batch_file(content: bytes, filename: str) -> Tuple[List[BatchTask], str]:
    """Parse a CSV/Excel batch file and return tasks plus file base name."""
    extension = Path(filename or "").suffix.lower()
    # Dispatch on extension; anything unknown is rejected up front.
    readers = {".csv": _read_csv, ".xlsx": _read_excel, ".xls": _read_excel}
    reader = readers.get(extension)
    if reader is None:
        raise ValidationError("Unsupported file type; must be .csv or .xlsx/.xls", field="file")
    frame = reader(content)
    parsed_tasks = _parse_dataframe(frame)
    if not parsed_tasks:
        raise ValidationError("Batch file contains no tasks", field="file")
    return parsed_tasks, Path(filename).stem or "batch"
def _read_csv(content: bytes) -> pd.DataFrame:
try:
import chardet
except Exception:
chardet = None
encoding = "utf-8"
if chardet:
detected = chardet.detect(content)
encoding = detected.get("encoding") or encoding
try:
return pd.read_csv(BytesIO(content), encoding=encoding)
except Exception as exc:
raise ValidationError(f"Failed to read CSV: {exc}", field="file")
def _read_excel(content: bytes) -> pd.DataFrame:
    """Parse Excel workbook bytes into a DataFrame (pandas default sheet).

    Raises:
        ValidationError: if pandas (or its Excel engine) cannot read the file.
    """
    try:
        return pd.read_excel(BytesIO(content))
    except Exception as exc:
        # Chain the cause so the engine-level error (corrupt file, missing
        # engine) survives in the traceback.
        raise ValidationError(f"Failed to read Excel file: {exc}", field="file") from exc
def _parse_dataframe(df: pd.DataFrame) -> List[BatchTask]:
    """Convert a raw DataFrame into validated BatchTask records."""
    # Column lookup is case-insensitive: normalized header -> actual header.
    headers = {str(name).strip().lower(): name for name in df.columns}
    id_column = headers.get("id")
    task_column = headers.get("task")
    attachments_column = headers.get("attachments")
    vars_column = headers.get("vars")

    parsed: List[BatchTask] = []
    used_ids: set[str] = set()
    for index, record in enumerate(df.to_dict(orient="records"), start=1):
        prompt = _get_cell_text(record, task_column)
        attachments = _parse_json_list(record, attachments_column, index)
        overrides = _parse_json_dict(record, vars_column, index)
        # A row must carry at least a prompt or attachments to be runnable.
        if not (prompt or attachments):
            raise ValidationError(
                "Task and attachments cannot both be empty",
                details={"row_index": index},
            )
        explicit_id = _get_cell_text(record, id_column)
        if explicit_id:
            if explicit_id in used_ids:
                raise ValidationError(
                    "Duplicate ID in batch file",
                    details={"row_index": index, "task_id": explicit_id},
                )
            used_ids.add(explicit_id)
        parsed.append(
            BatchTask(
                row_index=index,
                task_id=explicit_id or None,
                task_prompt=prompt,
                attachment_paths=attachments,
                vars_override=overrides,
            )
        )
    return parsed
def _get_cell_text(row: Dict[str, Any], column: Optional[str]) -> str:
if not column:
return ""
value = row.get(column)
if value is None:
return ""
if isinstance(value, float) and pd.isna(value):
return ""
if pd.isna(value):
return ""
return str(value).strip()
def _parse_json_list(
row: Dict[str, Any],
column: Optional[str],
row_index: int,
) -> List[str]:
if not column:
return []
raw_value = row.get(column)
if raw_value is None or (isinstance(raw_value, float) and pd.isna(raw_value)):
return []
if isinstance(raw_value, list):
return _ensure_string_list(raw_value, row_index, "Attachments")
if isinstance(raw_value, str):
if not raw_value.strip():
return []
try:
parsed = json.loads(raw_value)
except json.JSONDecodeError as exc:
raise ValidationError(
f"Invalid JSON in Attachments: {exc}",
details={"row_index": row_index},
)
return _ensure_string_list(parsed, row_index, "Attachments")
raise ValidationError(
"Attachments must be a JSON list",
details={"row_index": row_index},
)
def _parse_json_dict(
row: Dict[str, Any],
column: Optional[str],
row_index: int,
) -> Dict[str, Any]:
if not column:
return {}
raw_value = row.get(column)
if raw_value is None or (isinstance(raw_value, float) and pd.isna(raw_value)):
return {}
if isinstance(raw_value, dict):
return raw_value
if isinstance(raw_value, str):
if not raw_value.strip():
return {}
try:
parsed = json.loads(raw_value)
except json.JSONDecodeError as exc:
raise ValidationError(
f"Invalid JSON in Vars: {exc}",
details={"row_index": row_index},
)
if not isinstance(parsed, dict):
raise ValidationError(
"Vars must be a JSON object",
details={"row_index": row_index},
)
return parsed
raise ValidationError(
"Vars must be a JSON object",
details={"row_index": row_index},
)
def _ensure_string_list(value: Any, row_index: int, field: str) -> List[str]:
if not isinstance(value, list):
raise ValidationError(
f"{field} must be a JSON list",
details={"row_index": row_index},
)
result: List[str] = []
for item in value:
if item is None or (isinstance(item, float) and pd.isna(item)):
continue
result.append(str(item))
return result

View File

@ -0,0 +1,255 @@
"""Batch workflow execution helpers."""
import asyncio
import csv
import json
import logging
import re
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional
from check.check import load_config
from entity.enums import LogLevel
from entity.graph_config import GraphConfig
from utils.exceptions import ValidationError
from utils.task_input import TaskInputBuilder
from workflow.graph import GraphExecutor
from workflow.graph_context import GraphContext
from server.services.batch_parser import BatchTask
from server.services.workflow_storage import validate_workflow_filename
from server.settings import WARE_HOUSE_DIR, YAML_DIR
class BatchRunService:
    """Runs batch workflows and reports progress over WebSocket."""

    def __init__(self) -> None:
        # Module-scoped logger; no other state is kept on the service.
        self.logger = logging.getLogger(__name__)

    async def run_batch(
        self,
        session_id: str,
        yaml_file: str,
        tasks: List[BatchTask],
        websocket_manager,
        *,
        max_parallel: int = 5,
        file_base: str = "batch",
        log_level: Optional[LogLevel] = None,
    ) -> None:
        """Execute every task with at most ``max_parallel`` concurrent runs.

        Emits batch_started / batch_task_* / batch_completed messages over the
        session's WebSocket. Per-task failures are recorded and reported but
        never abort the rest of the batch. Results are persisted to the
        session warehouse directory before batch_completed is sent.
        """
        # The batch is identified by the session it runs in.
        batch_id = session_id
        total = len(tasks)
        await websocket_manager.send_message(
            session_id,
            {"type": "batch_started", "data": {"batch_id": batch_id, "total": total}},
        )
        # Bounds how many workflow executions run at once.
        semaphore = asyncio.Semaphore(max_parallel)
        success_count = 0
        failure_count = 0
        result_rows: List[Dict[str, Any]] = []
        # Serializes appends to result_rows from concurrently finishing tasks.
        result_lock = asyncio.Lock()

        async def run_task(task: BatchTask) -> None:
            # Run one row end-to-end: announce start, execute the workflow in
            # a worker thread, then report success or failure.
            nonlocal success_count, failure_count
            # Rows without an explicit ID get a random one for traceability.
            task_id = task.task_id or str(uuid.uuid4())
            task_dir = self._sanitize_label(f"{file_base}-{task_id}")
            await websocket_manager.send_message(
                session_id,
                {
                    "type": "batch_task_started",
                    "data": {
                        "row_index": task.row_index,
                        "task_id": task_id,
                        "task_dir": task_dir,
                    },
                },
            )
            try:
                # The synchronous workflow execution is offloaded so it does
                # not block the event loop.
                result = await asyncio.to_thread(
                    self._run_single_task,
                    session_id,
                    yaml_file,
                    task,
                    task_dir,
                    log_level,
                )
                # Safe without the lock: no await between here and the append
                # guard, and asyncio is single-threaded.
                success_count += 1
                async with result_lock:
                    result_rows.append(
                        {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "status": "success",
                            "duration_ms": result["duration_ms"],
                            "token_usage": result["token_usage"],
                            "graph_output": result["graph_output"],
                            "results": result["results"],
                            "error": "",
                        }
                    )
                await websocket_manager.send_message(
                    session_id,
                    {
                        "type": "batch_task_completed",
                        "data": {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "results": result["results"],
                            "token_usage": result["token_usage"],
                            "duration_ms": result["duration_ms"],
                        },
                    },
                )
            except Exception as exc:
                # Broad catch is deliberate: any per-task failure is recorded
                # and reported without aborting the batch.
                failure_count += 1
                async with result_lock:
                    result_rows.append(
                        {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "status": "failed",
                            "duration_ms": None,
                            "token_usage": None,
                            "graph_output": "",
                            "results": None,
                            "error": str(exc),
                        }
                    )
                await websocket_manager.send_message(
                    session_id,
                    {
                        "type": "batch_task_failed",
                        "data": {
                            "row_index": task.row_index,
                            "task_id": task_id,
                            "task_dir": task_dir,
                            "error": str(exc),
                        },
                    },
                )

        async def run_with_limit(task: BatchTask) -> None:
            # Gate each task on the semaphore to enforce max_parallel.
            async with semaphore:
                await run_task(task)

        await asyncio.gather(*(run_with_limit(task) for task in tasks))
        # NOTE(review): if writing outputs raises, batch_completed is never
        # sent and the error vanishes in the fire-and-forget task — confirm
        # this is acceptable or add a guard upstream.
        self._write_batch_outputs(session_id, result_rows)
        await websocket_manager.send_message(
            session_id,
            {
                "type": "batch_completed",
                "data": {
                    "batch_id": batch_id,
                    "total": total,
                    "succeeded": success_count,
                    "failed": failure_count,
                },
            },
        )

    def _write_batch_outputs(self, session_id: str, result_rows: List[Dict[str, Any]]) -> None:
        """Persist batch results as CSV + JSON under the session warehouse dir.

        NOTE(review): result_rows is in task-completion order, not row_index
        order — confirm downstream consumers sort if they need input order.
        """
        output_root = WARE_HOUSE_DIR / f"session_{session_id}"
        output_root.mkdir(parents=True, exist_ok=True)
        csv_path = output_root / "batch_results.csv"
        json_path = output_root / "batch_manifest.json"
        fieldnames = [
            "row_index",
            "task_id",
            "task_dir",
            "status",
            "duration_ms",
            "token_usage",
            "results",
            "error",
        ]
        with csv_path.open("w", newline="", encoding="utf-8") as handle:
            # extrasaction="ignore" drops graph_output and any other keys not
            # listed in fieldnames.
            writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore")
            writer.writeheader()
            for row in result_rows:
                row_copy = dict(row)
                # CSV cells must be flat text: token_usage is JSON-encoded and
                # the "results" column carries the final graph output string.
                row_copy["token_usage"] = json.dumps(row_copy.get("token_usage"))
                row_copy["results"] = row_copy.get("graph_output", "")
                writer.writerow(row_copy)
        with json_path.open("w", encoding="utf-8") as handle:
            # NOTE(review): assumes token_usage/results are JSON-serializable —
            # confirm against the executor's return types.
            json.dump(result_rows, handle, ensure_ascii=True, indent=2)

    def _run_single_task(
        self,
        session_id: str,
        yaml_file: str,
        task: BatchTask,
        task_dir: str,
        log_level: Optional[LogLevel],
    ) -> Dict[str, Any]:
        """Synchronously execute one workflow run; called via asyncio.to_thread.

        Returns a dict with results, token_usage, duration_ms and graph_output.

        Raises:
            ValidationError: if the YAML is missing or contains human nodes.
        """
        yaml_path = self._resolve_yaml_path(yaml_file)
        design = load_config(yaml_path, vars_override=task.vars_override or None)
        # Human-in-the-loop nodes would block forever in an unattended batch.
        if any(node.type == "human" for node in design.graph.nodes):
            raise ValidationError(
                "Batch execution does not support human nodes",
                details={"yaml_file": yaml_file},
            )
        output_root = WARE_HOUSE_DIR / f"session_{session_id}"
        graph_config = GraphConfig.from_definition(
            design.graph,
            name=task_dir,
            output_root=output_root,
            source_path=str(yaml_path),
            vars=design.vars,
        )
        # Tells GraphContext to use the task_dir name verbatim instead of
        # appending a timestamp suffix.
        graph_config.metadata["fixed_output_dir"] = True
        if log_level:
            graph_config.log_level = log_level
            graph_config.definition.log_level = log_level
        graph_context = GraphContext(config=graph_config)
        start_time = time.perf_counter()
        executor = GraphExecutor(graph_context, session_id=session_id)
        task_input = self._build_task_input(executor.attachment_store, task)
        # NOTE(review): relies on the executor's private _execute — consider
        # exposing a public entry point for batch use.
        executor._execute(task_input)
        duration_ms = int((time.perf_counter() - start_time) * 1000)
        return {
            "results": executor.outputs,
            "token_usage": executor.token_tracker.get_token_usage(),
            "duration_ms": duration_ms,
            "graph_output": executor.get_final_output(),
        }

    @staticmethod
    def _build_task_input(attachment_store, task: BatchTask):
        """Build the executor input: prompt plus attachments, or the bare prompt."""
        if task.attachment_paths:
            builder = TaskInputBuilder(attachment_store)
            return builder.build_from_file_paths(task.task_prompt, task.attachment_paths)
        return task.task_prompt

    @staticmethod
    def _sanitize_label(value: str) -> str:
        """Make a value filesystem-safe: keep [a-zA-Z0-9._-], collapse the rest to _."""
        cleaned = re.sub(r"[^a-zA-Z0-9._-]+", "_", value)
        return cleaned.strip("_") or "task"

    @staticmethod
    def _resolve_yaml_path(yaml_filename: str) -> Path:
        """Validate the workflow filename and resolve it under YAML_DIR.

        Raises:
            ValidationError: if the resolved file does not exist.
        """
        safe_name = validate_workflow_filename(yaml_filename, require_yaml_extension=True)
        yaml_path = YAML_DIR / safe_name
        if not yaml_path.exists():
            raise ValidationError("YAML file not found", details={"yaml_file": safe_name})
        return yaml_path

23
uv.lock generated
View File

@ -390,6 +390,7 @@ dependencies = [
{ name = "networkx" },
{ name = "numpy" },
{ name = "openai" },
{ name = "openpyxl" },
{ name = "pandas" },
{ name = "pydantic" },
{ name = "pygame" },
@ -422,6 +423,7 @@ requires-dist = [
{ name = "networkx" },
{ name = "numpy", specifier = ">=2.3.5" },
{ name = "openai" },
{ name = "openpyxl", specifier = ">=3.1.2" },
{ name = "pandas", specifier = ">=2.3.3" },
{ name = "pydantic", specifier = "==2.12.5" },
{ name = "pygame", specifier = ">=2.6.1" },
@ -494,6 +496,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
]
[[package]]
name = "et-xmlfile"
version = "2.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d3/38/af70d7ab1ae9d4da450eeec1fa3918940a5fafb9055e934af8d6eb0c2313/et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54", size = 17234, upload-time = "2024-10-25T17:25:40.039Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c1/8b/5fe2cc11fee489817272089c4203e679c63b570a5aaeb18d852ae3cbba6a/et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa", size = 18059, upload-time = "2024-10-25T17:25:39.051Z" },
]
[[package]]
name = "exceptiongroup"
version = "1.3.1"
@ -1147,6 +1158,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/12/cf/03675d8bd8ecbf4445504d8071adab19f5f993676795708e36402ab38263/openapi_pydantic-0.5.1-py3-none-any.whl", hash = "sha256:a3a09ef4586f5bd760a8df7f43028b60cafb6d9f61de2acba9574766255ab146", size = 96381, upload-time = "2025-01-08T19:29:25.275Z" },
]
[[package]]
name = "openpyxl"
version = "3.1.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "et-xmlfile" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3d/f9/88d94a75de065ea32619465d2f77b29a0469500e99012523b91cc4141cd1/openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050", size = 186464, upload-time = "2024-06-28T14:03:44.161Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2", size = 250910, upload-time = "2024-06-28T14:03:41.161Z" },
]
[[package]]
name = "opentelemetry-api"
version = "1.39.1"

View File

@ -61,7 +61,8 @@ class GraphContext:
# Output directory
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
if "session_" in config.name:
fixed_output_dir = bool(config.metadata.get("fixed_output_dir"))
if fixed_output_dir or "session_" in config.name:
self.directory = config.output_root / config.name
else:
self.directory = config.output_root / f"{config.name}_{timestamp}"