mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-05-31 20:58:24 +00:00
Merge pull request #516 from zxrys/feature/batch-run
feature: batch run
This commit is contained in:
commit
94f557c391
800
frontend/package-lock.json
generated
800
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@ -5,8 +5,8 @@ import Sidebar from './components/Sidebar.vue'
|
|||||||
|
|
||||||
const route = useRoute()
|
const route = useRoute()
|
||||||
|
|
||||||
// Hide the sidebar on LaunchView
|
// Hide the sidebar on LaunchView, BatchRunView and WorkflowWorkbench
|
||||||
const showSidebar = computed(() => route.path !== '/launch')
|
const showSidebar = computed(() => route.path !== '/launch' && route.path !== '/batch-run')
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
:class="{ active: isWorkflowsActive }"
|
:class="{ active: isWorkflowsActive }"
|
||||||
>Workflows</router-link>
|
>Workflows</router-link>
|
||||||
<router-link to="/launch" target="_blank" rel="noopener">Launch</router-link>
|
<router-link to="/launch" target="_blank" rel="noopener">Launch</router-link>
|
||||||
|
<router-link to="/batch-run" target="_blank" rel="noopener">Labaratory</router-link>
|
||||||
</nav>
|
</nav>
|
||||||
<div class="sidebar-actions">
|
<div class="sidebar-actions">
|
||||||
<button class="settings-nav-btn" @click="showSettingsModal = true" title="Settings">
|
<button class="settings-nav-btn" @click="showSettingsModal = true" title="Settings">
|
||||||
|
|||||||
2738
frontend/src/pages/BatchRunView.vue
Normal file
2738
frontend/src/pages/BatchRunView.vue
Normal file
File diff suppressed because it is too large
Load Diff
@ -800,6 +800,9 @@ const selectWorkflow = (fileName) => {
|
|||||||
isFileSearchDirty.value = false
|
isFileSearchDirty.value = false
|
||||||
closeFileDropdown()
|
closeFileDropdown()
|
||||||
|
|
||||||
|
// Avoid focusing on element after selection
|
||||||
|
fileSelectorInputRef.value?.blur()
|
||||||
|
|
||||||
router.push({
|
router.push({
|
||||||
query: {
|
query: {
|
||||||
...route.query,
|
...route.query,
|
||||||
|
|||||||
@ -13,6 +13,10 @@ const routes = [
|
|||||||
path: '/launch',
|
path: '/launch',
|
||||||
component: () => import('../pages/LaunchView.vue')
|
component: () => import('../pages/LaunchView.vue')
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
path: '/batch-run',
|
||||||
|
component: () => import('../pages/BatchRunView.vue')
|
||||||
|
},
|
||||||
{
|
{
|
||||||
path: '/workflows/:name?',
|
path: '/workflows/:name?',
|
||||||
component: () => import('../pages/WorkflowWorkbench.vue')
|
component: () => import('../pages/WorkflowWorkbench.vue')
|
||||||
|
|||||||
@ -424,6 +424,58 @@ export async function getAttachment(sessionId, attachmentId) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Batch workflow execution
|
||||||
|
export async function postBatchWorkflow({ file, sessionId, yamlFile, maxParallel, logLevel }) {
|
||||||
|
try {
|
||||||
|
const formData = new FormData()
|
||||||
|
|
||||||
|
// Append required parameters
|
||||||
|
if (file) {
|
||||||
|
formData.append('file', file)
|
||||||
|
}
|
||||||
|
if (sessionId) {
|
||||||
|
formData.append('session_id', sessionId)
|
||||||
|
}
|
||||||
|
if (yamlFile) {
|
||||||
|
formData.append('yaml_file', yamlFile)
|
||||||
|
}
|
||||||
|
if (maxParallel !== undefined) {
|
||||||
|
formData.append('max_parallel', maxParallel.toString())
|
||||||
|
}
|
||||||
|
if (logLevel) {
|
||||||
|
formData.append('log_level', logLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await fetch(apiUrl('/api/workflows/batch'), {
|
||||||
|
method: 'POST',
|
||||||
|
body: formData
|
||||||
|
})
|
||||||
|
|
||||||
|
const data = await response.json().catch(() => ({}))
|
||||||
|
|
||||||
|
if (response.ok) {
|
||||||
|
return {
|
||||||
|
success: true,
|
||||||
|
message: data?.message || 'Batch workflow executed successfully',
|
||||||
|
...data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
detail: data?.detail,
|
||||||
|
message: data?.message || 'Failed to execute batch workflow',
|
||||||
|
status: response.status
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error executing batch workflow:', error)
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
message: 'API error'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Upload a binary file
|
// Upload a binary file
|
||||||
export async function postFile(sessionId, file) {
|
export async function postFile(sessionId, file) {
|
||||||
try {
|
try {
|
||||||
|
|||||||
@ -30,6 +30,7 @@ dependencies = [
|
|||||||
"networkx",
|
"networkx",
|
||||||
"cartopy",
|
"cartopy",
|
||||||
"pandas>=2.3.3",
|
"pandas>=2.3.3",
|
||||||
|
"openpyxl>=3.1.2",
|
||||||
"numpy>=2.3.5",
|
"numpy>=2.3.5",
|
||||||
"seaborn>=0.13.2",
|
"seaborn>=0.13.2",
|
||||||
"google-genai>=1.52.0",
|
"google-genai>=1.52.0",
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Aggregates API routers."""
|
"""Aggregates API routers."""
|
||||||
|
|
||||||
from . import artifacts, execute, health, sessions, uploads, vuegraphs, workflows, websocket
|
from . import artifacts, batch, execute, health, sessions, uploads, vuegraphs, workflows, websocket
|
||||||
|
|
||||||
ALL_ROUTERS = [
|
ALL_ROUTERS = [
|
||||||
health.router,
|
health.router,
|
||||||
@ -9,6 +9,7 @@ ALL_ROUTERS = [
|
|||||||
uploads.router,
|
uploads.router,
|
||||||
artifacts.router,
|
artifacts.router,
|
||||||
sessions.router,
|
sessions.router,
|
||||||
|
batch.router,
|
||||||
execute.router,
|
execute.router,
|
||||||
websocket.router,
|
websocket.router,
|
||||||
]
|
]
|
||||||
|
|||||||
61
server/routes/batch.py
Normal file
61
server/routes/batch.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||||
|
|
||||||
|
from entity.enums import LogLevel
|
||||||
|
from server.services.batch_parser import parse_batch_file
|
||||||
|
from server.services.batch_run_service import BatchRunService
|
||||||
|
from server.state import ensure_known_session
|
||||||
|
from utils.exceptions import ValidationError
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/workflows/batch")
|
||||||
|
async def execute_batch(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
session_id: str = Form(...),
|
||||||
|
yaml_file: str = Form(...),
|
||||||
|
max_parallel: int = Form(5),
|
||||||
|
log_level: str | None = Form(None),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
manager = ensure_known_session(session_id, require_connection=True)
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc))
|
||||||
|
|
||||||
|
if max_parallel < 1:
|
||||||
|
raise HTTPException(status_code=400, detail="max_parallel must be >= 1")
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = await file.read()
|
||||||
|
tasks, file_base = parse_batch_file(content, file.filename or "batch.csv")
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc))
|
||||||
|
|
||||||
|
resolved_level = None
|
||||||
|
if log_level:
|
||||||
|
try:
|
||||||
|
resolved_level = LogLevel(log_level)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=400, detail="log_level must be either DEBUG or INFO")
|
||||||
|
|
||||||
|
service = BatchRunService()
|
||||||
|
asyncio.create_task(
|
||||||
|
service.run_batch(
|
||||||
|
session_id,
|
||||||
|
yaml_file,
|
||||||
|
tasks,
|
||||||
|
manager,
|
||||||
|
max_parallel=max_parallel,
|
||||||
|
file_base=file_base,
|
||||||
|
log_level=resolved_level,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "accepted",
|
||||||
|
"session_id": session_id,
|
||||||
|
"batch_id": session_id,
|
||||||
|
"task_count": len(tasks),
|
||||||
|
}
|
||||||
192
server/services/batch_parser.py
Normal file
192
server/services/batch_parser.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""Parse batch task files (CSV/Excel) into runnable tasks."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from utils.exceptions import ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class BatchTask:
|
||||||
|
row_index: int
|
||||||
|
task_id: Optional[str]
|
||||||
|
task_prompt: str
|
||||||
|
attachment_paths: List[str]
|
||||||
|
vars_override: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_batch_file(content: bytes, filename: str) -> Tuple[List[BatchTask], str]:
|
||||||
|
"""Parse a CSV/Excel batch file and return tasks plus file base name."""
|
||||||
|
suffix = Path(filename or "").suffix.lower()
|
||||||
|
if suffix not in {".csv", ".xlsx", ".xls"}:
|
||||||
|
raise ValidationError("Unsupported file type; must be .csv or .xlsx/.xls", field="file")
|
||||||
|
|
||||||
|
if suffix == ".csv":
|
||||||
|
df = _read_csv(content)
|
||||||
|
else:
|
||||||
|
df = _read_excel(content)
|
||||||
|
|
||||||
|
file_base = Path(filename).stem or "batch"
|
||||||
|
tasks = _parse_dataframe(df)
|
||||||
|
if not tasks:
|
||||||
|
raise ValidationError("Batch file contains no tasks", field="file")
|
||||||
|
return tasks, file_base
|
||||||
|
|
||||||
|
|
||||||
|
def _read_csv(content: bytes) -> pd.DataFrame:
|
||||||
|
try:
|
||||||
|
import chardet
|
||||||
|
except Exception:
|
||||||
|
chardet = None
|
||||||
|
encoding = "utf-8"
|
||||||
|
if chardet:
|
||||||
|
detected = chardet.detect(content)
|
||||||
|
encoding = detected.get("encoding") or encoding
|
||||||
|
try:
|
||||||
|
return pd.read_csv(BytesIO(content), encoding=encoding)
|
||||||
|
except Exception as exc:
|
||||||
|
raise ValidationError(f"Failed to read CSV: {exc}", field="file")
|
||||||
|
|
||||||
|
|
||||||
|
def _read_excel(content: bytes) -> pd.DataFrame:
|
||||||
|
try:
|
||||||
|
return pd.read_excel(BytesIO(content))
|
||||||
|
except Exception as exc:
|
||||||
|
raise ValidationError(f"Failed to read Excel file: {exc}", field="file")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_dataframe(df: pd.DataFrame) -> List[BatchTask]:
|
||||||
|
column_map = {str(col).strip().lower(): col for col in df.columns}
|
||||||
|
id_col = column_map.get("id")
|
||||||
|
task_col = column_map.get("task")
|
||||||
|
attachments_col = column_map.get("attachments")
|
||||||
|
vars_col = column_map.get("vars")
|
||||||
|
|
||||||
|
tasks: List[BatchTask] = []
|
||||||
|
seen_ids: set[str] = set()
|
||||||
|
|
||||||
|
for row_index, row in enumerate(df.to_dict(orient="records"), start=1):
|
||||||
|
task_prompt = _get_cell_text(row, task_col)
|
||||||
|
attachment_paths = _parse_json_list(row, attachments_col, row_index)
|
||||||
|
vars_override = _parse_json_dict(row, vars_col, row_index)
|
||||||
|
|
||||||
|
if not task_prompt and not attachment_paths:
|
||||||
|
raise ValidationError(
|
||||||
|
"Task and attachments cannot both be empty",
|
||||||
|
details={"row_index": row_index},
|
||||||
|
)
|
||||||
|
|
||||||
|
task_id = _get_cell_text(row, id_col)
|
||||||
|
if task_id:
|
||||||
|
if task_id in seen_ids:
|
||||||
|
raise ValidationError(
|
||||||
|
"Duplicate ID in batch file",
|
||||||
|
details={"row_index": row_index, "task_id": task_id},
|
||||||
|
)
|
||||||
|
seen_ids.add(task_id)
|
||||||
|
|
||||||
|
tasks.append(
|
||||||
|
BatchTask(
|
||||||
|
row_index=row_index,
|
||||||
|
task_id=task_id or None,
|
||||||
|
task_prompt=task_prompt,
|
||||||
|
attachment_paths=attachment_paths,
|
||||||
|
vars_override=vars_override,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cell_text(row: Dict[str, Any], column: Optional[str]) -> str:
|
||||||
|
if not column:
|
||||||
|
return ""
|
||||||
|
value = row.get(column)
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(value, float) and pd.isna(value):
|
||||||
|
return ""
|
||||||
|
if pd.isna(value):
|
||||||
|
return ""
|
||||||
|
return str(value).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_list(
|
||||||
|
row: Dict[str, Any],
|
||||||
|
column: Optional[str],
|
||||||
|
row_index: int,
|
||||||
|
) -> List[str]:
|
||||||
|
if not column:
|
||||||
|
return []
|
||||||
|
raw_value = row.get(column)
|
||||||
|
if raw_value is None or (isinstance(raw_value, float) and pd.isna(raw_value)):
|
||||||
|
return []
|
||||||
|
if isinstance(raw_value, list):
|
||||||
|
return _ensure_string_list(raw_value, row_index, "Attachments")
|
||||||
|
if isinstance(raw_value, str):
|
||||||
|
if not raw_value.strip():
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw_value)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ValidationError(
|
||||||
|
f"Invalid JSON in Attachments: {exc}",
|
||||||
|
details={"row_index": row_index},
|
||||||
|
)
|
||||||
|
return _ensure_string_list(parsed, row_index, "Attachments")
|
||||||
|
raise ValidationError(
|
||||||
|
"Attachments must be a JSON list",
|
||||||
|
details={"row_index": row_index},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_dict(
|
||||||
|
row: Dict[str, Any],
|
||||||
|
column: Optional[str],
|
||||||
|
row_index: int,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
if not column:
|
||||||
|
return {}
|
||||||
|
raw_value = row.get(column)
|
||||||
|
if raw_value is None or (isinstance(raw_value, float) and pd.isna(raw_value)):
|
||||||
|
return {}
|
||||||
|
if isinstance(raw_value, dict):
|
||||||
|
return raw_value
|
||||||
|
if isinstance(raw_value, str):
|
||||||
|
if not raw_value.strip():
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw_value)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ValidationError(
|
||||||
|
f"Invalid JSON in Vars: {exc}",
|
||||||
|
details={"row_index": row_index},
|
||||||
|
)
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
raise ValidationError(
|
||||||
|
"Vars must be a JSON object",
|
||||||
|
details={"row_index": row_index},
|
||||||
|
)
|
||||||
|
return parsed
|
||||||
|
raise ValidationError(
|
||||||
|
"Vars must be a JSON object",
|
||||||
|
details={"row_index": row_index},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_string_list(value: Any, row_index: int, field: str) -> List[str]:
|
||||||
|
if not isinstance(value, list):
|
||||||
|
raise ValidationError(
|
||||||
|
f"{field} must be a JSON list",
|
||||||
|
details={"row_index": row_index},
|
||||||
|
)
|
||||||
|
result: List[str] = []
|
||||||
|
for item in value:
|
||||||
|
if item is None or (isinstance(item, float) and pd.isna(item)):
|
||||||
|
continue
|
||||||
|
result.append(str(item))
|
||||||
|
return result
|
||||||
255
server/services/batch_run_service.py
Normal file
255
server/services/batch_run_service.py
Normal file
@ -0,0 +1,255 @@
|
|||||||
|
"""Batch workflow execution helpers."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from check.check import load_config
|
||||||
|
from entity.enums import LogLevel
|
||||||
|
from entity.graph_config import GraphConfig
|
||||||
|
from utils.exceptions import ValidationError
|
||||||
|
from utils.task_input import TaskInputBuilder
|
||||||
|
from workflow.graph import GraphExecutor
|
||||||
|
from workflow.graph_context import GraphContext
|
||||||
|
|
||||||
|
from server.services.batch_parser import BatchTask
|
||||||
|
from server.services.workflow_storage import validate_workflow_filename
|
||||||
|
from server.settings import WARE_HOUSE_DIR, YAML_DIR
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRunService:
|
||||||
|
"""Runs batch workflows and reports progress over WebSocket."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def run_batch(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
yaml_file: str,
|
||||||
|
tasks: List[BatchTask],
|
||||||
|
websocket_manager,
|
||||||
|
*,
|
||||||
|
max_parallel: int = 5,
|
||||||
|
file_base: str = "batch",
|
||||||
|
log_level: Optional[LogLevel] = None,
|
||||||
|
) -> None:
|
||||||
|
batch_id = session_id
|
||||||
|
total = len(tasks)
|
||||||
|
|
||||||
|
await websocket_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{"type": "batch_started", "data": {"batch_id": batch_id, "total": total}},
|
||||||
|
)
|
||||||
|
|
||||||
|
semaphore = asyncio.Semaphore(max_parallel)
|
||||||
|
success_count = 0
|
||||||
|
failure_count = 0
|
||||||
|
result_rows: List[Dict[str, Any]] = []
|
||||||
|
result_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def run_task(task: BatchTask) -> None:
|
||||||
|
nonlocal success_count, failure_count
|
||||||
|
task_id = task.task_id or str(uuid.uuid4())
|
||||||
|
task_dir = self._sanitize_label(f"{file_base}-{task_id}")
|
||||||
|
|
||||||
|
await websocket_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"type": "batch_task_started",
|
||||||
|
"data": {
|
||||||
|
"row_index": task.row_index,
|
||||||
|
"task_id": task_id,
|
||||||
|
"task_dir": task_dir,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await asyncio.to_thread(
|
||||||
|
self._run_single_task,
|
||||||
|
session_id,
|
||||||
|
yaml_file,
|
||||||
|
task,
|
||||||
|
task_dir,
|
||||||
|
log_level,
|
||||||
|
)
|
||||||
|
success_count += 1
|
||||||
|
async with result_lock:
|
||||||
|
result_rows.append(
|
||||||
|
{
|
||||||
|
"row_index": task.row_index,
|
||||||
|
"task_id": task_id,
|
||||||
|
"task_dir": task_dir,
|
||||||
|
"status": "success",
|
||||||
|
"duration_ms": result["duration_ms"],
|
||||||
|
"token_usage": result["token_usage"],
|
||||||
|
"graph_output": result["graph_output"],
|
||||||
|
"results": result["results"],
|
||||||
|
"error": "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await websocket_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"type": "batch_task_completed",
|
||||||
|
"data": {
|
||||||
|
"row_index": task.row_index,
|
||||||
|
"task_id": task_id,
|
||||||
|
"task_dir": task_dir,
|
||||||
|
"results": result["results"],
|
||||||
|
"token_usage": result["token_usage"],
|
||||||
|
"duration_ms": result["duration_ms"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
failure_count += 1
|
||||||
|
async with result_lock:
|
||||||
|
result_rows.append(
|
||||||
|
{
|
||||||
|
"row_index": task.row_index,
|
||||||
|
"task_id": task_id,
|
||||||
|
"task_dir": task_dir,
|
||||||
|
"status": "failed",
|
||||||
|
"duration_ms": None,
|
||||||
|
"token_usage": None,
|
||||||
|
"graph_output": "",
|
||||||
|
"results": None,
|
||||||
|
"error": str(exc),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await websocket_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"type": "batch_task_failed",
|
||||||
|
"data": {
|
||||||
|
"row_index": task.row_index,
|
||||||
|
"task_id": task_id,
|
||||||
|
"task_dir": task_dir,
|
||||||
|
"error": str(exc),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run_with_limit(task: BatchTask) -> None:
|
||||||
|
async with semaphore:
|
||||||
|
await run_task(task)
|
||||||
|
|
||||||
|
await asyncio.gather(*(run_with_limit(task) for task in tasks))
|
||||||
|
|
||||||
|
self._write_batch_outputs(session_id, result_rows)
|
||||||
|
|
||||||
|
await websocket_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"type": "batch_completed",
|
||||||
|
"data": {
|
||||||
|
"batch_id": batch_id,
|
||||||
|
"total": total,
|
||||||
|
"succeeded": success_count,
|
||||||
|
"failed": failure_count,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _write_batch_outputs(self, session_id: str, result_rows: List[Dict[str, Any]]) -> None:
|
||||||
|
output_root = WARE_HOUSE_DIR / f"session_{session_id}"
|
||||||
|
output_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
csv_path = output_root / "batch_results.csv"
|
||||||
|
json_path = output_root / "batch_manifest.json"
|
||||||
|
|
||||||
|
fieldnames = [
|
||||||
|
"row_index",
|
||||||
|
"task_id",
|
||||||
|
"task_dir",
|
||||||
|
"status",
|
||||||
|
"duration_ms",
|
||||||
|
"token_usage",
|
||||||
|
"results",
|
||||||
|
"error",
|
||||||
|
]
|
||||||
|
|
||||||
|
with csv_path.open("w", newline="", encoding="utf-8") as handle:
|
||||||
|
writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore")
|
||||||
|
writer.writeheader()
|
||||||
|
for row in result_rows:
|
||||||
|
row_copy = dict(row)
|
||||||
|
row_copy["token_usage"] = json.dumps(row_copy.get("token_usage"))
|
||||||
|
row_copy["results"] = row_copy.get("graph_output", "")
|
||||||
|
writer.writerow(row_copy)
|
||||||
|
|
||||||
|
with json_path.open("w", encoding="utf-8") as handle:
|
||||||
|
json.dump(result_rows, handle, ensure_ascii=True, indent=2)
|
||||||
|
|
||||||
|
def _run_single_task(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
yaml_file: str,
|
||||||
|
task: BatchTask,
|
||||||
|
task_dir: str,
|
||||||
|
log_level: Optional[LogLevel],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
yaml_path = self._resolve_yaml_path(yaml_file)
|
||||||
|
design = load_config(yaml_path, vars_override=task.vars_override or None)
|
||||||
|
if any(node.type == "human" for node in design.graph.nodes):
|
||||||
|
raise ValidationError(
|
||||||
|
"Batch execution does not support human nodes",
|
||||||
|
details={"yaml_file": yaml_file},
|
||||||
|
)
|
||||||
|
|
||||||
|
output_root = WARE_HOUSE_DIR / f"session_{session_id}"
|
||||||
|
graph_config = GraphConfig.from_definition(
|
||||||
|
design.graph,
|
||||||
|
name=task_dir,
|
||||||
|
output_root=output_root,
|
||||||
|
source_path=str(yaml_path),
|
||||||
|
vars=design.vars,
|
||||||
|
)
|
||||||
|
graph_config.metadata["fixed_output_dir"] = True
|
||||||
|
|
||||||
|
if log_level:
|
||||||
|
graph_config.log_level = log_level
|
||||||
|
graph_config.definition.log_level = log_level
|
||||||
|
|
||||||
|
graph_context = GraphContext(config=graph_config)
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
executor = GraphExecutor(graph_context, session_id=session_id)
|
||||||
|
task_input = self._build_task_input(executor.attachment_store, task)
|
||||||
|
executor._execute(task_input)
|
||||||
|
duration_ms = int((time.perf_counter() - start_time) * 1000)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"results": executor.outputs,
|
||||||
|
"token_usage": executor.token_tracker.get_token_usage(),
|
||||||
|
"duration_ms": duration_ms,
|
||||||
|
"graph_output": executor.get_final_output(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_task_input(attachment_store, task: BatchTask):
|
||||||
|
if task.attachment_paths:
|
||||||
|
builder = TaskInputBuilder(attachment_store)
|
||||||
|
return builder.build_from_file_paths(task.task_prompt, task.attachment_paths)
|
||||||
|
return task.task_prompt
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitize_label(value: str) -> str:
|
||||||
|
cleaned = re.sub(r"[^a-zA-Z0-9._-]+", "_", value)
|
||||||
|
return cleaned.strip("_") or "task"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_yaml_path(yaml_filename: str) -> Path:
|
||||||
|
safe_name = validate_workflow_filename(yaml_filename, require_yaml_extension=True)
|
||||||
|
yaml_path = YAML_DIR / safe_name
|
||||||
|
if not yaml_path.exists():
|
||||||
|
raise ValidationError("YAML file not found", details={"yaml_file": safe_name})
|
||||||
|
return yaml_path
|
||||||
@ -18,16 +18,31 @@ from utils.structured_logger import get_server_logger, LogType
|
|||||||
|
|
||||||
|
|
||||||
def _update_workflow_id(content: str, workflow_id: str) -> str:
|
def _update_workflow_id(content: str, workflow_id: str) -> str:
|
||||||
pattern = re.compile(r"^(id:\\s*).*$", re.MULTILINE)
|
# Pattern to match graph:\n id: <value>
|
||||||
|
pattern = re.compile(r"(graph:\s*\n\s*id:\s*).*$", re.MULTILINE)
|
||||||
match = pattern.search(content)
|
match = pattern.search(content)
|
||||||
if match:
|
if match:
|
||||||
return pattern.sub(rf"\\1{workflow_id}", content, count=1)
|
# Replace the value after "graph:\n id: "
|
||||||
|
return pattern.sub(rf"\1{workflow_id}", content, count=1)
|
||||||
|
|
||||||
|
# If no graph.id found, look for standalone id: at root level (legacy support)
|
||||||
|
root_id_pattern = re.compile(r"^(id:\s*).*$", re.MULTILINE)
|
||||||
|
root_match = root_id_pattern.search(content)
|
||||||
|
if root_match:
|
||||||
|
return root_id_pattern.sub(rf"\1{workflow_id}", content, count=1)
|
||||||
|
|
||||||
|
# If neither found, add graph.id after graph: section if it exists
|
||||||
|
graph_pattern = re.compile(r"(graph:\s*\n)")
|
||||||
|
graph_match = graph_pattern.search(content)
|
||||||
|
if graph_match:
|
||||||
|
return graph_pattern.sub(rf"\1 id: {workflow_id}\n", content, count=1)
|
||||||
|
|
||||||
|
# Fallback (is invalid)
|
||||||
lines = content.splitlines()
|
lines = content.splitlines()
|
||||||
insert_index = 0
|
insert_index = 0
|
||||||
if lines and lines[0].strip() == "---":
|
if lines and lines[0].strip() == "---":
|
||||||
insert_index = 1
|
insert_index = 1
|
||||||
lines.insert(insert_index, f"id: {workflow_id}")
|
lines.insert(insert_index, f"graph:\n id: {workflow_id}")
|
||||||
updated = "\n".join(lines)
|
updated = "\n".join(lines)
|
||||||
if content.endswith("\n"):
|
if content.endswith("\n"):
|
||||||
updated += "\n"
|
updated += "\n"
|
||||||
|
|||||||
23
uv.lock
generated
23
uv.lock
generated
@ -390,6 +390,7 @@ dependencies = [
|
|||||||
{ name = "networkx" },
|
{ name = "networkx" },
|
||||||
{ name = "numpy" },
|
{ name = "numpy" },
|
||||||
{ name = "openai" },
|
{ name = "openai" },
|
||||||
|
{ name = "openpyxl" },
|
||||||
{ name = "pandas" },
|
{ name = "pandas" },
|
||||||
{ name = "pydantic" },
|
{ name = "pydantic" },
|
||||||
{ name = "pygame" },
|
{ name = "pygame" },
|
||||||
@ -422,6 +423,7 @@ requires-dist = [
|
|||||||
{ name = "networkx" },
|
{ name = "networkx" },
|
||||||
{ name = "numpy", specifier = ">=2.3.5" },
|
{ name = "numpy", specifier = ">=2.3.5" },
|
||||||
{ name = "openai" },
|
{ name = "openai" },
|
||||||
|
{ name = "openpyxl", specifier = ">=3.1.2" },
|
||||||
{ name = "pandas", specifier = ">=2.3.3" },
|
{ name = "pandas", specifier = ">=2.3.3" },
|
||||||
{ name = "pydantic", specifier = "==2.12.5" },
|
{ name = "pydantic", specifier = "==2.12.5" },
|
||||||
{ name = "pygame", specifier = ">=2.6.1" },
|
{ name = "pygame", specifier = ">=2.6.1" },
|
||||||
@ -494,6 +496,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
|
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "et-xmlfile"
|
||||||
|
version = "2.0.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/d3/38/af70d7ab1ae9d4da450eeec1fa3918940a5fafb9055e934af8d6eb0c2313/et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54", size = 17234, upload-time = "2024-10-25T17:25:40.039Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c1/8b/5fe2cc11fee489817272089c4203e679c63b570a5aaeb18d852ae3cbba6a/et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa", size = 18059, upload-time = "2024-10-25T17:25:39.051Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "exceptiongroup"
|
name = "exceptiongroup"
|
||||||
version = "1.3.1"
|
version = "1.3.1"
|
||||||
@ -1147,6 +1158,18 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/12/cf/03675d8bd8ecbf4445504d8071adab19f5f993676795708e36402ab38263/openapi_pydantic-0.5.1-py3-none-any.whl", hash = "sha256:a3a09ef4586f5bd760a8df7f43028b60cafb6d9f61de2acba9574766255ab146", size = 96381, upload-time = "2025-01-08T19:29:25.275Z" },
|
{ url = "https://files.pythonhosted.org/packages/12/cf/03675d8bd8ecbf4445504d8071adab19f5f993676795708e36402ab38263/openapi_pydantic-0.5.1-py3-none-any.whl", hash = "sha256:a3a09ef4586f5bd760a8df7f43028b60cafb6d9f61de2acba9574766255ab146", size = 96381, upload-time = "2025-01-08T19:29:25.275Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "openpyxl"
|
||||||
|
version = "3.1.5"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "et-xmlfile" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/3d/f9/88d94a75de065ea32619465d2f77b29a0469500e99012523b91cc4141cd1/openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050", size = 186464, upload-time = "2024-06-28T14:03:44.161Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2", size = 250910, upload-time = "2024-06-28T14:03:41.161Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "opentelemetry-api"
|
name = "opentelemetry-api"
|
||||||
version = "1.39.1"
|
version = "1.39.1"
|
||||||
|
|||||||
@ -61,7 +61,8 @@ class GraphContext:
|
|||||||
|
|
||||||
# Output directory
|
# Output directory
|
||||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
if "session_" in config.name:
|
fixed_output_dir = bool(config.metadata.get("fixed_output_dir"))
|
||||||
|
if fixed_output_dir or "session_" in config.name:
|
||||||
self.directory = config.output_root / config.name
|
self.directory = config.output_root / config.name
|
||||||
else:
|
else:
|
||||||
self.directory = config.output_root / f"{config.name}_{timestamp}"
|
self.directory = config.output_root / f"{config.name}_{timestamp}"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user