diff --git a/pyproject.toml b/pyproject.toml index 4d66c797..e1dad706 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "networkx", "cartopy", "pandas>=2.3.3", + "openpyxl>=3.1.2", "numpy>=2.3.5", "seaborn>=0.13.2", "google-genai>=1.52.0", diff --git a/server/routes/__init__.py b/server/routes/__init__.py index 464fcab5..32706ea7 100755 --- a/server/routes/__init__.py +++ b/server/routes/__init__.py @@ -1,6 +1,6 @@ """Aggregates API routers.""" -from . import artifacts, execute, health, sessions, uploads, vuegraphs, workflows, websocket +from . import artifacts, batch, execute, health, sessions, uploads, vuegraphs, workflows, websocket ALL_ROUTERS = [ health.router, @@ -9,6 +9,7 @@ ALL_ROUTERS = [ uploads.router, artifacts.router, sessions.router, + batch.router, execute.router, websocket.router, ] diff --git a/server/routes/batch.py b/server/routes/batch.py new file mode 100644 index 00000000..819aa165 --- /dev/null +++ b/server/routes/batch.py @@ -0,0 +1,61 @@ +import asyncio + +from fastapi import APIRouter, File, Form, HTTPException, UploadFile + +from entity.enums import LogLevel +from server.services.batch_parser import parse_batch_file +from server.services.batch_run_service import BatchRunService +from server.state import ensure_known_session +from utils.exceptions import ValidationError + +router = APIRouter() + + +@router.post("/api/workflows/batch") +async def execute_batch( + file: UploadFile = File(...), + session_id: str = Form(...), + yaml_file: str = Form(...), + max_parallel: int = Form(5), + log_level: str | None = Form(None), +): + try: + manager = ensure_known_session(session_id, require_connection=True) + except ValidationError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + if max_parallel < 1: + raise HTTPException(status_code=400, detail="max_parallel must be >= 1") + + try: + content = await file.read() + tasks, file_base = parse_batch_file(content, file.filename or "batch.csv") + except ValidationError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + resolved_level = None + if log_level: + try: + resolved_level = LogLevel(log_level) + except ValueError: + raise HTTPException(status_code=400, detail="log_level must be either DEBUG or INFO") + + service = BatchRunService() + asyncio.create_task( + service.run_batch( + session_id, + yaml_file, + tasks, + manager, + max_parallel=max_parallel, + file_base=file_base, + log_level=resolved_level, + ) + ) + + return { + "status": "accepted", + "session_id": session_id, + "batch_id": session_id, + "task_count": len(tasks), + } diff --git a/server/services/batch_parser.py b/server/services/batch_parser.py new file mode 100644 index 00000000..75218ac7 --- /dev/null +++ b/server/services/batch_parser.py @@ -0,0 +1,192 @@ +"""Parse batch task files (CSV/Excel) into runnable tasks.""" + +import json +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd + +from utils.exceptions import ValidationError + + +@dataclass(frozen=True) +class BatchTask: + row_index: int + task_id: Optional[str] + task_prompt: str + attachment_paths: List[str] + vars_override: Dict[str, Any] + + +def parse_batch_file(content: bytes, filename: str) -> Tuple[List[BatchTask], str]: + """Parse a CSV/Excel batch file and return tasks plus file base name.""" + suffix = Path(filename or "").suffix.lower() + if suffix not in {".csv", ".xlsx", ".xls"}: + raise ValidationError("Unsupported file type; must be .csv or .xlsx/.xls", field="file") + + if suffix == ".csv": + df = _read_csv(content) + else: + df = _read_excel(content) + + file_base = Path(filename).stem or "batch" + tasks = _parse_dataframe(df) + if not tasks: + raise ValidationError("Batch file contains no tasks", field="file") + return tasks, file_base + + +def _read_csv(content: bytes) -> pd.DataFrame: + try: + import chardet + except Exception: + chardet = None + encoding = "utf-8" + if chardet: + detected = chardet.detect(content) + encoding = detected.get("encoding") or encoding + try: + return pd.read_csv(BytesIO(content), encoding=encoding) + except Exception as exc: + raise ValidationError(f"Failed to read CSV: {exc}", field="file") + + +def _read_excel(content: bytes) -> pd.DataFrame: + try: + return pd.read_excel(BytesIO(content)) + except Exception as exc: + raise ValidationError(f"Failed to read Excel file: {exc}", field="file") + + +def _parse_dataframe(df: pd.DataFrame) -> List[BatchTask]: + column_map = {str(col).strip().lower(): col for col in df.columns} + id_col = column_map.get("id") + task_col = column_map.get("task") + attachments_col = column_map.get("attachments") + vars_col = column_map.get("vars") + + tasks: List[BatchTask] = [] + seen_ids: set[str] = set() + + for row_index, row in enumerate(df.to_dict(orient="records"), start=1): + task_prompt = _get_cell_text(row, task_col) + attachment_paths = _parse_json_list(row, attachments_col, row_index) + vars_override = _parse_json_dict(row, vars_col, row_index) + + if not task_prompt and not attachment_paths: + raise ValidationError( + "Task and attachments cannot both be empty", + details={"row_index": row_index}, + ) + + task_id = _get_cell_text(row, id_col) + if task_id: + if task_id in seen_ids: + raise ValidationError( + "Duplicate ID in batch file", + details={"row_index": row_index, "task_id": task_id}, + ) + seen_ids.add(task_id) + + tasks.append( + BatchTask( + row_index=row_index, + task_id=task_id or None, + task_prompt=task_prompt, + attachment_paths=attachment_paths, + vars_override=vars_override, + ) + ) + return tasks + + +def _get_cell_text(row: Dict[str, Any], column: Optional[str]) -> str: + if not column: + return "" + value = row.get(column) + if value is None: + return "" + if isinstance(value, float) and pd.isna(value): + return "" + if pd.isna(value): + return "" + return str(value).strip() + + +def _parse_json_list( + row: Dict[str, Any], + column: Optional[str], + row_index: int, +) -> List[str]: + if not column: + return [] + raw_value = row.get(column) + if raw_value is None or (isinstance(raw_value, float) and pd.isna(raw_value)): + return [] + if isinstance(raw_value, list): + return _ensure_string_list(raw_value, row_index, "Attachments") + if isinstance(raw_value, str): + if not raw_value.strip(): + return [] + try: + parsed = json.loads(raw_value) + except json.JSONDecodeError as exc: + raise ValidationError( + f"Invalid JSON in Attachments: {exc}", + details={"row_index": row_index}, + ) + return _ensure_string_list(parsed, row_index, "Attachments") + raise ValidationError( + "Attachments must be a JSON list", + details={"row_index": row_index}, + ) + + +def _parse_json_dict( + row: Dict[str, Any], + column: Optional[str], + row_index: int, +) -> Dict[str, Any]: + if not column: + return {} + raw_value = row.get(column) + if raw_value is None or (isinstance(raw_value, float) and pd.isna(raw_value)): + return {} + if isinstance(raw_value, dict): + return raw_value + if isinstance(raw_value, str): + if not raw_value.strip(): + return {} + try: + parsed = json.loads(raw_value) + except json.JSONDecodeError as exc: + raise ValidationError( + f"Invalid JSON in Vars: {exc}", + details={"row_index": row_index}, + ) + if not isinstance(parsed, dict): + raise ValidationError( + "Vars must be a JSON object", + details={"row_index": row_index}, + ) + return parsed + raise ValidationError( + "Vars must be a JSON object", + details={"row_index": row_index}, + ) + + +def _ensure_string_list(value: Any, row_index: int, field: str) -> List[str]: + if not isinstance(value, list): + raise ValidationError( + f"{field} must be a JSON list", + details={"row_index": row_index}, + ) + result: List[str] = [] + for item in value: + if item is None or (isinstance(item, float) and pd.isna(item)): + continue + result.append(str(item)) + return result diff --git a/server/services/batch_run_service.py b/server/services/batch_run_service.py new file mode 100644 index 00000000..649fda77 --- /dev/null +++ b/server/services/batch_run_service.py @@ -0,0 +1,255 @@ +"""Batch workflow execution helpers.""" + +import asyncio +import csv +import json +import logging +import re +import time +import uuid +from pathlib import Path +from typing import Any, Dict, List, Optional + +from check.check import load_config +from entity.enums import LogLevel +from entity.graph_config import GraphConfig +from utils.exceptions import ValidationError +from utils.task_input import TaskInputBuilder +from workflow.graph import GraphExecutor +from workflow.graph_context import GraphContext + +from server.services.batch_parser import BatchTask +from server.services.workflow_storage import validate_workflow_filename +from server.settings import WARE_HOUSE_DIR, YAML_DIR + + +class BatchRunService: + """Runs batch workflows and reports progress over WebSocket.""" + + def __init__(self) -> None: + self.logger = logging.getLogger(__name__) + + async def run_batch( + self, + session_id: str, + yaml_file: str, + tasks: List[BatchTask], + websocket_manager, + *, + max_parallel: int = 5, + file_base: str = "batch", + log_level: Optional[LogLevel] = None, + ) -> None: + batch_id = session_id + total = len(tasks) + + await websocket_manager.send_message( + session_id, + {"type": "batch_started", "data": {"batch_id": batch_id, "total": total}}, + ) + + semaphore = asyncio.Semaphore(max_parallel) + success_count = 0 + failure_count = 0 + result_rows: List[Dict[str, Any]] = [] + result_lock = asyncio.Lock() + + async def run_task(task: BatchTask) -> None: + nonlocal success_count, failure_count + task_id = task.task_id or str(uuid.uuid4()) + task_dir = self._sanitize_label(f"{file_base}-{task_id}") + + await websocket_manager.send_message( + session_id, + { + "type": "batch_task_started", + "data": { + "row_index": task.row_index, + "task_id": task_id, + "task_dir": task_dir, + }, + }, + ) + + try: + result = await asyncio.to_thread( + self._run_single_task, + session_id, + yaml_file, + task, + task_dir, + log_level, + ) + success_count += 1 + async with result_lock: + result_rows.append( + { + "row_index": task.row_index, + "task_id": task_id, + "task_dir": task_dir, + "status": "success", + "duration_ms": result["duration_ms"], + "token_usage": result["token_usage"], + "graph_output": result["graph_output"], + "results": result["results"], + "error": "", + } + ) + await websocket_manager.send_message( + session_id, + { + "type": "batch_task_completed", + "data": { + "row_index": task.row_index, + "task_id": task_id, + "task_dir": task_dir, + "results": result["results"], + "token_usage": result["token_usage"], + "duration_ms": result["duration_ms"], + }, + }, + ) + except Exception as exc: + failure_count += 1 + async with result_lock: + result_rows.append( + { + "row_index": task.row_index, + "task_id": task_id, + "task_dir": task_dir, + "status": "failed", + "duration_ms": None, + "token_usage": None, + "graph_output": "", + "results": None, + "error": str(exc), + } + ) + await websocket_manager.send_message( + session_id, + { + "type": "batch_task_failed", + "data": { + "row_index": task.row_index, + "task_id": task_id, + "task_dir": task_dir, + "error": str(exc), + }, + }, + ) + + async def run_with_limit(task: BatchTask) -> None: + async with semaphore: + await run_task(task) + + await asyncio.gather(*(run_with_limit(task) for task in tasks)) + + self._write_batch_outputs(session_id, result_rows) + + await websocket_manager.send_message( + session_id, + { + "type": "batch_completed", + "data": { + "batch_id": batch_id, + "total": total, + "succeeded": success_count, + "failed": failure_count, + }, + }, + ) + + def _write_batch_outputs(self, session_id: str, result_rows: List[Dict[str, Any]]) -> None: + output_root = WARE_HOUSE_DIR / f"session_{session_id}" + output_root.mkdir(parents=True, exist_ok=True) + + csv_path = output_root / "batch_results.csv" + json_path = output_root / "batch_manifest.json" + + fieldnames = [ + "row_index", + "task_id", + "task_dir", + "status", + "duration_ms", + "token_usage", + "results", + "error", + ] + + with csv_path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + for row in result_rows: + row_copy = dict(row) + row_copy["token_usage"] = json.dumps(row_copy.get("token_usage")) + row_copy["results"] = row_copy.get("graph_output", "") + writer.writerow(row_copy) + + with json_path.open("w", encoding="utf-8") as handle: + json.dump(result_rows, handle, ensure_ascii=True, indent=2) + + def _run_single_task( + self, + session_id: str, + yaml_file: str, + task: BatchTask, + task_dir: str, + log_level: Optional[LogLevel], + ) -> Dict[str, Any]: + yaml_path = self._resolve_yaml_path(yaml_file) + design = load_config(yaml_path, vars_override=task.vars_override or None) + if any(node.type == "human" for node in design.graph.nodes): + raise ValidationError( + "Batch execution does not support human nodes", + details={"yaml_file": yaml_file}, + ) + + output_root = WARE_HOUSE_DIR / f"session_{session_id}" + graph_config = GraphConfig.from_definition( + design.graph, + name=task_dir, + output_root=output_root, + source_path=str(yaml_path), + vars=design.vars, + ) + graph_config.metadata["fixed_output_dir"] = True + + if log_level: + graph_config.log_level = log_level + graph_config.definition.log_level = log_level + + graph_context = GraphContext(config=graph_config) + + start_time = time.perf_counter() + executor = GraphExecutor(graph_context, session_id=session_id) + task_input = self._build_task_input(executor.attachment_store, task) + executor._execute(task_input) + duration_ms = int((time.perf_counter() - start_time) * 1000) + + return { + "results": executor.outputs, + "token_usage": executor.token_tracker.get_token_usage(), + "duration_ms": duration_ms, + "graph_output": executor.get_final_output(), + } + + @staticmethod + def _build_task_input(attachment_store, task: BatchTask): + if task.attachment_paths: + builder = TaskInputBuilder(attachment_store) + return builder.build_from_file_paths(task.task_prompt, task.attachment_paths) + return task.task_prompt + + @staticmethod + def _sanitize_label(value: str) -> str: + cleaned = re.sub(r"[^a-zA-Z0-9._-]+", "_", value) + return cleaned.strip("_") or "task" + + @staticmethod + def _resolve_yaml_path(yaml_filename: str) -> Path: + safe_name = validate_workflow_filename(yaml_filename, require_yaml_extension=True) + yaml_path = YAML_DIR / safe_name + if not yaml_path.exists(): + raise ValidationError("YAML file not found", details={"yaml_file": safe_name}) + return yaml_path diff --git a/uv.lock b/uv.lock index e2de06ed..d79bc2d9 100755 --- a/uv.lock +++ b/uv.lock @@ -390,6 +390,7 @@ dependencies = [ { name = "networkx" }, { name = "numpy" }, { name = "openai" }, + { name = "openpyxl" }, { name = "pandas" }, { name = "pydantic" }, { name = "pygame" }, @@ -422,6 +423,7 @@ requires-dist = [ { name = "networkx" }, { name = "numpy", specifier = ">=2.3.5" }, { name = "openai" }, + { name = "openpyxl", specifier = ">=3.1.2" }, { name = "pandas", specifier = ">=2.3.3" }, { name = "pydantic", specifier = "==2.12.5" }, { name = "pygame", specifier = ">=2.6.1" }, @@ -494,6 +496,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" }, ] +[[package]] +name = "et-xmlfile" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/38/af70d7ab1ae9d4da450eeec1fa3918940a5fafb9055e934af8d6eb0c2313/et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54", size = 17234, upload-time = "2024-10-25T17:25:40.039Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/8b/5fe2cc11fee489817272089c4203e679c63b570a5aaeb18d852ae3cbba6a/et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa", size = 18059, upload-time = "2024-10-25T17:25:39.051Z" }, +] + [[package]] name = "exceptiongroup" version = "1.3.1" @@ -1147,6 +1158,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/cf/03675d8bd8ecbf4445504d8071adab19f5f993676795708e36402ab38263/openapi_pydantic-0.5.1-py3-none-any.whl", hash = "sha256:a3a09ef4586f5bd760a8df7f43028b60cafb6d9f61de2acba9574766255ab146", size = 96381, upload-time = "2025-01-08T19:29:25.275Z" }, ] +[[package]] +name = "openpyxl" +version = "3.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "et-xmlfile" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/f9/88d94a75de065ea32619465d2f77b29a0469500e99012523b91cc4141cd1/openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050", size = 186464, upload-time = "2024-06-28T14:03:44.161Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2", size = 250910, upload-time = "2024-06-28T14:03:41.161Z" }, +] + [[package]] name = "opentelemetry-api" version = "1.39.1" diff --git a/workflow/graph_context.py b/workflow/graph_context.py index f43c1019..56fcc3be 100755 --- a/workflow/graph_context.py +++ b/workflow/graph_context.py @@ -61,7 +61,8 @@ class GraphContext: # Output directory timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - if "session_" in config.name: + fixed_output_dir = bool(config.metadata.get("fixed_output_dir")) + if fixed_output_dir or "session_" in config.name: self.directory = config.output_root / config.name else: self.directory = config.output_root / f"{config.name}_{timestamp}"