[feat] add server endpoint

This commit is contained in:
NA-Wen 2026-03-11 10:49:52 +08:00
parent c7bb2df9e5
commit 9381abd96f
11 changed files with 674 additions and 81 deletions

View File

@ -312,6 +312,8 @@ class McpRemoteConfig(BaseConfig):
server: str
headers: Dict[str, str] = field(default_factory=dict)
timeout: float | None = None
cache_ttl: float = 0.0
tool_sources: List[str] | None = None
FIELD_SPECS = {
"server": ConfigFieldSpec(
@ -337,6 +339,22 @@ class McpRemoteConfig(BaseConfig):
description="Per-request timeout in seconds",
advance=True,
),
"cache_ttl": ConfigFieldSpec(
name="cache_ttl",
display_name="Tool Cache TTL",
type_hint="float",
required=False,
description="Seconds to cache MCP tool list; 0 disables cache for hot updates",
advance=True,
),
"tool_sources": ConfigFieldSpec(
name="tool_sources",
display_name="Tool Sources Filter",
type_hint="list[str]",
required=False,
description="Only include MCP tools whose meta.source is in this list; omit to default to ['mcp_tools'].",
advance=True,
),
}
@classmethod
@ -360,7 +378,40 @@ class McpRemoteConfig(BaseConfig):
else:
raise ConfigError("timeout must be numeric", extend_path(path, "timeout"))
return cls(server=server, headers=headers, timeout=timeout, path=path)
cache_ttl_value = mapping.get("cache_ttl", 0.0)
if cache_ttl_value is None:
cache_ttl = 0.0
elif isinstance(cache_ttl_value, (int, float)):
cache_ttl = float(cache_ttl_value)
else:
raise ConfigError("cache_ttl must be numeric", extend_path(path, "cache_ttl"))
tool_sources_raw = mapping.get("tool_sources")
tool_sources: List[str] | None = None
if tool_sources_raw is not None:
entries = ensure_list(tool_sources_raw)
normalized: List[str] = []
for idx, entry in enumerate(entries):
if not isinstance(entry, str):
raise ConfigError(
"tool_sources must be a list of strings",
extend_path(path, f"tool_sources[{idx}]"),
)
value = entry.strip()
if value:
normalized.append(value)
tool_sources = normalized
else:
tool_sources = ["mcp_tools"]
return cls(
server=server,
headers=headers,
timeout=timeout,
cache_ttl=cache_ttl,
tool_sources=tool_sources,
path=path,
)
def cache_key(self) -> str:
payload = (
@ -380,6 +431,7 @@ class McpLocalConfig(BaseConfig):
inherit_env: bool = True
startup_timeout: float = 10.0
wait_for_log: str | None = None
cache_ttl: float = 0.0
FIELD_SPECS = {
"command": ConfigFieldSpec(
@ -438,6 +490,14 @@ class McpLocalConfig(BaseConfig):
description="Regex that marks readiness when matched against stdout",
advance=True,
),
"cache_ttl": ConfigFieldSpec(
name="cache_ttl",
display_name="Tool Cache TTL",
type_hint="float",
required=False,
description="Seconds to cache MCP tool list; 0 disables cache for hot updates",
advance=True,
),
}
@classmethod
@ -474,6 +534,13 @@ class McpLocalConfig(BaseConfig):
raise ConfigError("startup_timeout must be numeric", extend_path(path, "startup_timeout"))
wait_for_log = optional_str(mapping, "wait_for_log", path)
cache_ttl_value = mapping.get("cache_ttl", 0.0)
if cache_ttl_value is None:
cache_ttl = 0.0
elif isinstance(cache_ttl_value, (int, float)):
cache_ttl = float(cache_ttl_value)
else:
raise ConfigError("cache_ttl must be numeric", extend_path(path, "cache_ttl"))
return cls(
command=command,
args=normalized_args,
@ -482,6 +549,7 @@ class McpLocalConfig(BaseConfig):
inherit_env=bool(inherit_env),
startup_timeout=startup_timeout,
wait_for_log=wait_for_log,
cache_ttl=cache_ttl,
path=path,
)

View File

@ -69,7 +69,7 @@ export async function postYaml(filename, content) {
export async function updateYaml(filename, content) {
try {
const yamlFilename = addYamlSuffix(filename)
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}`), {
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}/update`), {
method: 'PUT',
headers: {
'Content-Type': 'application/json',
@ -197,7 +197,7 @@ export async function fetchWorkflowsWithDesc() {
const filesWithDesc = await Promise.all(
data.workflows.map(async (filename) => {
try {
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}`))
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}/get`))
const fileData = await response.json()
return {
name: filename,
@ -234,7 +234,7 @@ export async function fetchWorkflowsWithDesc() {
// Fetch YAML file content
export async function fetchWorkflowYAML(filename) {
try {
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}`))
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}/get`))
if (!response.ok) {
throw new Error(`Failed to load YAML file: ${filename}, status: ${response.status}`)
}
@ -250,7 +250,7 @@ export async function fetchWorkflowYAML(filename) {
export async function fetchYaml(filename) {
try {
const yamlFilename = addYamlSuffix(filename)
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}`))
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}/get`))
const data = await response.json().catch(() => ({}))

View File

@ -9,6 +9,7 @@ import logging
import mimetypes
import os
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Mapping, Sequence
@ -35,13 +36,19 @@ class _FunctionManagerCacheEntry:
auto_loaded: bool = False
@dataclass
class _McpToolCacheEntry:
tools: List[Any]
fetched_at: float
class ToolManager:
"""Manage function tools for agent nodes."""
def __init__(self) -> None:
self._functions_dir: Path = FUNCTION_CALLING_DIR
self._function_managers: Dict[Path, _FunctionManagerCacheEntry] = {}
self._mcp_tool_cache: Dict[str, List[Any]] = {}
self._mcp_tool_cache: Dict[str, _McpToolCacheEntry] = {}
self._mcp_stdio_clients: Dict[str, "_StdioClientWrapper"] = {}
def _get_function_manager(self) -> FunctionManager:
@ -192,19 +199,25 @@ class ToolManager:
def _build_mcp_remote_specs(self, config: McpRemoteConfig) -> List[ToolSpec]:
cache_key = f"remote:{config.cache_key()}"
tools = self._mcp_tool_cache.get(cache_key)
if tools is None:
tools = asyncio.run(
tools = self._get_mcp_tools(
cache_key,
config.cache_ttl,
lambda: asyncio.run(
self._fetch_mcp_tools_http(
config.server,
headers=config.headers,
timeout=config.timeout,
)
)
self._mcp_tool_cache[cache_key] = tools
),
)
specs: List[ToolSpec] = []
allowed_sources = {source for source in (config.tool_sources or []) if source}
for tool in tools:
meta = getattr(tool, "meta", None)
source = meta.get("source") if isinstance(meta, Mapping) else None
if allowed_sources and source not in allowed_sources:
continue
specs.append(
ToolSpec(
name=tool.name,
@ -221,10 +234,11 @@ class ToolManager:
raise ValueError("MCP local configuration missing launch key")
cache_key = f"stdio:{launch_key}"
tools = self._mcp_tool_cache.get(cache_key)
if tools is None:
tools = asyncio.run(self._fetch_mcp_tools_stdio(config, launch_key))
self._mcp_tool_cache[cache_key] = tools
tools = self._get_mcp_tools(
cache_key,
config.cache_ttl,
lambda: asyncio.run(self._fetch_mcp_tools_stdio(config, launch_key)),
)
specs: List[ToolSpec] = []
for tool in tools:
@ -238,28 +252,25 @@ class ToolManager:
)
return specs
def _execute_function_tool(
def _get_mcp_tools(
self,
tool_name: str,
arguments: Dict[str, Any],
config: FunctionToolConfig,
tool_context: Dict[str, Any] | None = None,
) -> Any:
mgr = self._get_function_manager()
if config.auto_load:
mgr.load_functions()
func = mgr.get_function(tool_name)
if func is None:
raise ValueError(f"Tool {tool_name} not found in {self._functions_dir}")
cache_key: str,
cache_ttl: float,
fetcher,
) -> List[Any]:
entry = self._mcp_tool_cache.get(cache_key)
if entry and self._is_cache_fresh(entry.fetched_at, cache_ttl):
return entry.tools
tools = fetcher()
self._mcp_tool_cache[cache_key] = _McpToolCacheEntry(tools=tools, fetched_at=time.time())
return tools
@staticmethod
def _is_cache_fresh(fetched_at: float, cache_ttl: float) -> bool:
if cache_ttl <= 0:
return False
return (time.time() - fetched_at) < cache_ttl
call_args = dict(arguments or {})
if (
tool_context is not None
# and "_context" not in call_args
and self._function_accepts_context(func)
):
call_args["_context"] = tool_context
return func(**call_args)
def _function_accepts_context(self, func: Any) -> bool:
try:

View File

@ -1,6 +1,8 @@
from fastapi import FastAPI
from server.bootstrap import init_app
from utils.env_loader import load_dotenv_file
load_dotenv_file()
app = FastAPI(title="DevAll Workflow Server", version="1.0.0")
init_app(app)

View File

@ -1,23 +1,18 @@
"""Application bootstrap helpers for the FastAPI server."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from server import state
from server.config_schema_router import router as config_schema_router
from server.routes import ALL_ROUTERS
from utils.error_handler import add_exception_handlers
from utils.middleware import add_middleware
def init_app(app: FastAPI) -> None:
"""Apply shared middleware, routers, and global state to ``app``."""
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
add_exception_handlers(app)
add_middleware(app)
state.init_state()
for router in ALL_ROUTERS:
app.include_router(router)
app.include_router(config_schema_router)
add_middleware(app)

189
server/mcp_server.py Normal file
View File

@ -0,0 +1,189 @@
import importlib.util
import inspect
import random
import uuid
from pathlib import Path
from typing import Any, Iterable
from fastmcp import FastMCP
from fastmcp.tools import FunctionTool
from starlette.requests import Request
from starlette.responses import JSONResponse
# Repo root (for loading built-in function tools).
_REPO_ROOT = Path(__file__).resolve().parents[1]
_FUNCTION_CALLING_DIR = _REPO_ROOT / "functions" / "function_calling"
# MCP server for production use (supports hot-updated tools).
mcp = FastMCP(
"DevAll MCP Server",
debug=True,
)
_DYNAMIC_TOOLS_DIR = Path(__file__).parent / "dynamic_tools"
_DYNAMIC_TOOLS_DIR.mkdir(parents=True, exist_ok=True)
def _safe_tool_filename(filename: str) -> str:
name = Path(filename).name
if not name.endswith(".py"):
raise ValueError("filename must end with .py")
return name
def _load_module_from_path(path: Path) -> Any:
module_name = f"_dynamic_mcp_{path.stem}_{uuid.uuid4().hex}"
spec = importlib.util.spec_from_file_location(module_name, path)
if spec is None or spec.loader is None:
raise ValueError("failed to create module spec")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def _select_functions(module: Any, allowlist: Iterable[str] | None) -> list[Any]:
allowset = {name.strip() for name in allowlist} if allowlist else None
selected = []
for name, obj in inspect.getmembers(module, inspect.isfunction):
if name.startswith("_"):
continue
if getattr(obj, "__module__", None) != module.__name__:
continue
if allowset is not None and name not in allowset:
continue
selected.append(obj)
if allowset:
missing = sorted(allowset - {fn.__name__ for fn in selected})
if missing:
raise ValueError(f"functions not found in module: {', '.join(missing)}")
return selected
def _register_functions(functions: Iterable[Any], *, replace: bool, source: str) -> list[str]:
added: list[str] = []
for fn in functions:
tool = FunctionTool.from_function(fn, meta={"source": source})
if replace:
try:
mcp.remove_tool(tool.name)
except Exception:
pass
mcp.add_tool(tool)
added.append(tool.name)
return added
def register_tools_from_payload(payload: dict) -> dict:
filename = payload.get("filename")
content = payload.get("content")
function_names = payload.get("functions")
replace = payload.get("replace", True)
if not isinstance(filename, str) or not filename.strip():
raise ValueError("filename is required")
if not isinstance(content, str) or not content.strip():
raise ValueError("content is required")
if function_names is not None and not isinstance(function_names, list):
raise ValueError("functions must be a list of strings")
safe_name = _safe_tool_filename(filename.strip())
file_path = _DYNAMIC_TOOLS_DIR / safe_name
file_path.write_text(content, encoding="utf-8")
module = _load_module_from_path(file_path)
functions = _select_functions(module, function_names)
if not functions:
raise ValueError("no functions found to register")
added = _register_functions(functions, replace=bool(replace), source="dynamic_tools")
return {"status": "ok", "added": added, "file": str(file_path)}
@mcp.custom_route("/admin/tools/upload", methods=["POST"])
async def upload_tool(request: Request) -> JSONResponse:
try:
payload = await request.json()
result = register_tools_from_payload(payload)
except ValueError as exc:
return JSONResponse({"error": str(exc)}, status_code=400)
except Exception as exc:
return JSONResponse({"error": f"failed to load tools: {exc}"}, status_code=400)
return JSONResponse(result)
@mcp.custom_route("/admin/tools/reload", methods=["POST"])
async def reload_tools(request: Request) -> JSONResponse:
replace = True
try:
payload = await request.json()
if isinstance(payload, dict) and "replace" in payload:
replace = bool(payload["replace"])
except Exception:
replace = True
result = load_tools_from_directory(replace=replace)
return JSONResponse(result)
@mcp.custom_route("/admin/tools/list", methods=["GET"])
async def list_tools(request: Request) -> JSONResponse:
tools = await mcp.get_tools()
items = []
for key in sorted(tools.keys()):
tool = tools[key]
meta = tool.meta or {}
source = meta.get("source", "unknown")
call_methods = ["mcp_remote"]
if source == "function_calling":
call_methods.append("function")
payload = {
"key": key,
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
"enabled": tool.enabled,
"source": source,
"call_methods": call_methods,
}
if meta:
payload["meta"] = meta
if tool.output_schema is not None:
payload["output_schema"] = tool.output_schema
items.append(payload)
return JSONResponse({"tools": items})
def load_tools_from_directory(*, replace: bool = True) -> dict:
added: list[str] = []
errors: list[str] = []
for directory in (_FUNCTION_CALLING_DIR, _DYNAMIC_TOOLS_DIR):
if not directory.exists():
continue
source = "local_tools" if directory == _FUNCTION_CALLING_DIR else "mcp_tools"
for path in sorted(directory.glob("*.py")):
if path.name.startswith("_") or path.name == "__init__.py":
continue
try:
module = _load_module_from_path(path)
functions = _select_functions(module, None)
if not functions:
continue
added.extend(_register_functions(functions, replace=replace, source=source))
except Exception as exc:
errors.append(f"{path.name}: {exc}")
return {"added": added, "errors": errors}
_bootstrap = load_tools_from_directory(replace=True)
if _bootstrap["errors"]:
print(f"Dynamic tool load errors: {_bootstrap['errors']}")
if __name__ == "__main__":
print("Starting DevAll MCP server...")
print("Run standalone with: fastmcp run server/mcp_server.py --transport streamable-http --port 8010")
mcp.run()

View File

@ -1,6 +1,6 @@
"""Pydantic models shared across server routes."""
from typing import List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, constr
@ -13,6 +13,15 @@ class WorkflowRequest(BaseModel):
log_level: Literal["INFO", "DEBUG"] = "INFO"
class WorkflowRunRequest(BaseModel):
yaml_file: str
task_prompt: str
attachments: Optional[List[str]] = None
session_name: Optional[str] = None
variables: Optional[Dict[str, Any]] = None
log_level: Optional[Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]] = None
class WorkflowUploadContentRequest(BaseModel):
filename: str
content: str

View File

@ -1,6 +1,6 @@
"""Aggregates API routers."""
from . import artifacts, batch, execute, health, sessions, uploads, vuegraphs, workflows, websocket
from . import artifacts, execute, execute_sync, health, sessions, uploads, vuegraphs, workflows, websocket, batch
ALL_ROUTERS = [
health.router,
@ -11,7 +11,8 @@ ALL_ROUTERS = [
sessions.router,
batch.router,
execute.router,
execute_sync.router,
websocket.router,
]
__all__ = ["ALL_ROUTERS"]
__all__ = ["ALL_ROUTERS"]

View File

@ -0,0 +1,253 @@
from __future__ import annotations
import json
import queue
import threading
from datetime import datetime
from pathlib import Path
from typing import Any, Optional, Sequence, Union
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from starlette.concurrency import run_in_threadpool
from check.check import load_config
from entity.enums import LogLevel
from entity.graph_config import GraphConfig
from entity.messages import Message
from runtime.bootstrap.schema import ensure_schema_registry_populated
from runtime.sdk import OUTPUT_ROOT, run_workflow
from server.models import WorkflowRunRequest
from server.settings import YAML_DIR
from utils.attachments import AttachmentStore
from utils.exceptions import ValidationError, WorkflowExecutionError
from utils.logger import WorkflowLogger
from utils.structured_logger import get_server_logger, LogType
from utils.task_input import TaskInputBuilder
from workflow.graph import GraphExecutor
from workflow.graph_context import GraphContext
router = APIRouter()
_SSE_CONTENT_TYPE = "text/event-stream"
def _normalize_session_name(yaml_path: Path, session_name: Optional[str]) -> str:
if session_name and session_name.strip():
return session_name.strip()
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
return f"sdk_{yaml_path.stem}_{timestamp}"
def _resolve_yaml_path(yaml_file: Union[str, Path]) -> Path:
candidate = Path(yaml_file).expanduser()
if candidate.is_absolute():
return candidate
if candidate.exists():
return candidate
repo_root = Path(__file__).resolve().parents[2]
yaml_root = YAML_DIR if YAML_DIR.is_absolute() else (repo_root / YAML_DIR)
return (yaml_root / candidate).expanduser()
def _build_task_input(
graph_context: GraphContext,
prompt: str,
attachments: Sequence[Union[str, Path]],
) -> Union[str, list[Message]]:
if not attachments:
return prompt
attachments_dir = graph_context.directory / "code_workspace" / "attachments"
attachments_dir.mkdir(parents=True, exist_ok=True)
store = AttachmentStore(attachments_dir)
builder = TaskInputBuilder(store)
normalized_paths = [str(Path(path).expanduser()) for path in attachments]
return builder.build_from_file_paths(prompt, normalized_paths)
def _run_workflow_with_logger(
*,
yaml_file: Union[str, Path],
task_prompt: str,
attachments: Optional[Sequence[Union[str, Path]]],
session_name: Optional[str],
variables: Optional[dict],
log_level: Optional[LogLevel],
log_callback,
) -> tuple[Optional[Message], dict[str, Any]]:
ensure_schema_registry_populated()
yaml_path = _resolve_yaml_path(yaml_file)
if not yaml_path.exists():
raise FileNotFoundError(f"YAML file not found: {yaml_path}")
attachments = attachments or []
if (not task_prompt or not task_prompt.strip()) and not attachments:
raise ValidationError(
"Task prompt cannot be empty",
details={"task_prompt_provided": bool(task_prompt)},
)
design = load_config(yaml_path, vars_override=variables)
normalized_session = _normalize_session_name(yaml_path, session_name)
graph_config = GraphConfig.from_definition(
design.graph,
name=normalized_session,
output_root=OUTPUT_ROOT,
source_path=str(yaml_path),
vars=design.vars,
)
if log_level:
graph_config.log_level = log_level
graph_config.definition.log_level = log_level
graph_context = GraphContext(config=graph_config)
task_input = _build_task_input(graph_context, task_prompt, attachments)
class _StreamingWorkflowLogger(WorkflowLogger):
def add_log(self, *args, **kwargs):
entry = super().add_log(*args, **kwargs)
if entry:
payload = entry.to_dict()
payload.pop("details", None)
log_callback("log", payload)
return entry
class _StreamingExecutor(GraphExecutor):
def _create_logger(self) -> WorkflowLogger:
level = log_level or self.graph.log_level
return _StreamingWorkflowLogger(
self.graph.name,
level,
use_structured_logging=True,
log_to_console=False,
)
executor = _StreamingExecutor(graph_context, session_id=normalized_session)
executor._execute(task_input)
final_message = executor.get_final_output_message()
logger = executor.log_manager.get_logger() if executor.log_manager else None
log_id = logger.workflow_id if logger else None
token_usage = executor.token_tracker.get_token_usage() if executor.token_tracker else None
meta = {
"session_name": normalized_session,
"yaml_file": str(yaml_path),
"log_id": log_id,
"token_usage": token_usage,
"output_dir": graph_context.directory,
}
return final_message, meta
def _sse_event(event_type: str, data: Any) -> str:
payload = json.dumps(data, ensure_ascii=False, default=str)
return f"event: {event_type}\ndata: {payload}\n\n"
@router.post("/api/workflow/run")
async def run_workflow_sync(request: WorkflowRunRequest, http_request: Request):
try:
resolved_log_level: Optional[LogLevel] = None
if request.log_level:
resolved_log_level = LogLevel(request.log_level)
except ValueError:
raise HTTPException(
status_code=400,
detail="log_level must be one of DEBUG, INFO, WARNING, ERROR, CRITICAL",
)
accepts_stream = _SSE_CONTENT_TYPE in (http_request.headers.get("accept") or "")
if not accepts_stream:
try:
result = await run_in_threadpool(
run_workflow,
request.yaml_file,
task_prompt=request.task_prompt,
attachments=request.attachments,
session_name=request.session_name,
variables=request.variables,
log_level=resolved_log_level,
)
except FileNotFoundError as exc:
raise HTTPException(status_code=404, detail=str(exc))
except ValidationError as exc:
raise HTTPException(status_code=400, detail=str(exc))
except Exception as exc:
logger = get_server_logger()
logger.log_exception(exc, "Failed to run workflow via sync API")
raise WorkflowExecutionError(f"Failed to run workflow: {exc}")
final_message = result.final_message.text_content() if result.final_message else ""
meta = result.meta_info
logger = get_server_logger()
logger.info(
"Workflow execution completed via sync API",
log_type=LogType.WORKFLOW,
session_id=meta.session_name,
yaml_path=meta.yaml_file,
)
return {
"status": "completed",
"final_message": final_message,
"token_usage": meta.token_usage,
"output_dir": str(meta.output_dir.resolve()),
}
event_queue: queue.Queue[tuple[str, Any]] = queue.Queue()
done_event = threading.Event()
def enqueue(event_type: str, data: Any) -> None:
event_queue.put((event_type, data))
def worker() -> None:
try:
enqueue(
"started",
{"yaml_file": request.yaml_file, "task_prompt": request.task_prompt},
)
final_message, meta = _run_workflow_with_logger(
yaml_file=request.yaml_file,
task_prompt=request.task_prompt,
attachments=request.attachments,
session_name=request.session_name,
variables=request.variables,
log_level=resolved_log_level,
log_callback=enqueue,
)
enqueue(
"completed",
{
"status": "completed",
"final_message": final_message.text_content() if final_message else "",
"token_usage": meta["token_usage"],
"output_dir": str(meta["output_dir"].resolve()),
},
)
except (FileNotFoundError, ValidationError) as exc:
enqueue("error", {"message": str(exc)})
except Exception as exc:
logger = get_server_logger()
logger.log_exception(exc, "Failed to run workflow via streaming API")
enqueue("error", {"message": f"Failed to run workflow: {exc}"})
finally:
done_event.set()
threading.Thread(target=worker, daemon=True).start()
async def stream():
while True:
try:
event_type, data = event_queue.get(timeout=0.1)
yield _sse_event(event_type, data)
except queue.Empty:
if done_event.is_set():
break
return StreamingResponse(stream(), media_type=_SSE_CONTENT_TYPE)

View File

@ -1,4 +1,5 @@
from fastapi import APIRouter, HTTPException
from typing import Any
from server.models import (
WorkflowCopyRequest,
@ -67,6 +68,91 @@ async def list_workflows():
return {"workflows": [file.name for file in YAML_DIR.glob("*.yaml")]}
@router.get("/api/workflows/{filename}/args")
async def get_workflow_args(filename: str):
print(str)
try:
safe_filename = validate_workflow_filename(filename, require_yaml_extension=True)
print(safe_filename)
file_path = YAML_DIR / safe_filename
if not file_path.exists() or not file_path.is_file():
raise ResourceNotFoundError(
"Workflow file not found",
resource_type="workflow",
resource_id=safe_filename,
)
# Load and validate YAML content
raw_content = file_path.read_text(encoding="utf-8")
_, yaml_content = validate_workflow_content(safe_filename, raw_content)
args: list[dict[str, Any]] = []
if isinstance(yaml_content, dict):
graph = yaml_content.get("graph") or {}
if isinstance(graph, dict):
raw_args = graph.get("args") or []
if isinstance(raw_args, list):
if len(raw_args) == 0:
raise ResourceNotFoundError(
"Workflow file does not have args",
resource_type="workflow",
resource_id=safe_filename,
)
for item in raw_args:
# Each item is expected to be like: { arg_name: [ {key: value}, ... ] }
if not isinstance(item, dict) or len(item) != 1:
continue
(arg_name, spec_list), = item.items()
if not isinstance(arg_name, str):
continue
arg_info: dict[str, Any] = {"name": arg_name}
if isinstance(spec_list, list):
for spec in spec_list:
if isinstance(spec, dict):
for key, value in spec.items():
# Later entries override earlier ones if duplicated
arg_info[str(key)] = value
args.append(arg_info)
logger = get_server_logger()
logger.info(
"Workflow args retrieved",
log_type=LogType.WORKFLOW,
filename=safe_filename,
args_count=len(args),
)
return {"args": args}
except ValidationError as exc:
# 参数或文件名等校验错误
raise HTTPException(
status_code=400,
detail={"message": str(exc)},
)
except SecurityError as exc:
# 安全相关错误(例如路径遍历)
raise HTTPException(
status_code=400,
detail={"message": str(exc)},
)
except ResourceNotFoundError as exc:
# 文件不存在
raise HTTPException(
status_code=404,
detail={"message": str(exc)},
)
except Exception as exc:
logger = get_server_logger()
logger.log_exception(exc, f"Unexpected error retrieving workflow args: {filename}")
# 兜底错误
raise HTTPException(
status_code=500,
detail={"message": f"Failed to retrieve workflow args: {exc}"},
)
@router.post("/api/workflows/upload/content")
async def upload_workflow_content(request: WorkflowUploadContentRequest):
return _persist_workflow_from_content(
@ -78,7 +164,7 @@ async def upload_workflow_content(request: WorkflowUploadContentRequest):
)
@router.put("/api/workflows/{filename}")
@router.put("/api/workflows/{filename}/update")
async def update_workflow_content(filename: str, request: WorkflowUpdateContentRequest):
return _persist_workflow_from_content(
filename,
@ -89,7 +175,7 @@ async def update_workflow_content(filename: str, request: WorkflowUpdateContentR
)
@router.delete("/api/workflows/{filename}")
@router.delete("/api/workflows/{filename}/delete")
async def delete_workflow(filename: str):
try:
safe_filename = validate_workflow_filename(filename, require_yaml_extension=True)
@ -180,7 +266,7 @@ async def copy_workflow_file(filename: str, request: WorkflowCopyRequest):
raise WorkflowExecutionError(f"Failed to copy workflow: {exc}")
@router.get("/api/workflows/{filename}")
@router.get("/api/workflows/{filename}/get")
async def get_workflow_raw_content(filename: str):
try:
safe_filename = validate_workflow_filename(filename, require_yaml_extension=True)
@ -208,4 +294,4 @@ async def get_workflow_raw_content(filename: str):
except Exception as exc:
logger = get_server_logger()
logger.log_exception(exc, f"Unexpected error retrieving workflow: {filename}")
raise WorkflowExecutionError(f"Failed to retrieve workflow: {exc}")
raise WorkflowExecutionError(f"Failed to retrieve workflow: {exc}")

View File

@ -49,8 +49,6 @@ class WebSocketManager:
):
self.active_connections: Dict[str, WebSocket] = {}
self.connection_timestamps: Dict[str, float] = {}
self.send_locks: Dict[str, asyncio.Lock] = {}
self.loop: asyncio.AbstractEventLoop | None = None
self.session_store = session_store or WorkflowSessionStore()
self.session_controller = session_controller or SessionExecutionController(self.session_store)
self.attachment_service = attachment_service or AttachmentService()
@ -67,16 +65,10 @@ class WebSocketManager:
async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) -> str:
await websocket.accept()
if self.loop is None:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = None
if not session_id:
session_id = str(uuid.uuid4())
self.active_connections[session_id] = websocket
self.connection_timestamps[session_id] = time.time()
self.send_locks[session_id] = asyncio.Lock()
logging.info("WebSocket connected: %s", session_id)
await self.send_message(
session_id,
@ -98,8 +90,6 @@ class WebSocketManager:
del self.active_connections[session_id]
if session_id in self.connection_timestamps:
del self.connection_timestamps[session_id]
if session_id in self.send_locks:
del self.send_locks[session_id]
self.session_controller.cleanup_session(session_id)
remaining_session = self.session_store.get_session(session_id)
if remaining_session and remaining_session.executor is None:
@ -111,12 +101,7 @@ class WebSocketManager:
if session_id in self.active_connections:
websocket = self.active_connections[session_id]
try:
lock = self.send_locks.get(session_id)
if lock is None:
await websocket.send_text(_encode_ws_message(message))
else:
async with lock:
await websocket.send_text(_encode_ws_message(message))
await websocket.send_text(_encode_ws_message(message))
except Exception as exc:
traceback.print_exc()
logging.error("Failed to send message to %s: %s", session_id, exc)
@ -130,13 +115,7 @@ class WebSocketManager:
else:
asyncio.run(self.send_message(session_id, message))
except RuntimeError:
if self.loop and self.loop.is_running():
asyncio.run_coroutine_threadsafe(
self.send_message(session_id, message),
self.loop,
)
else:
asyncio.run(self.send_message(session_id, message))
asyncio.run(self.send_message(session_id, message))
async def broadcast(self, message: Dict[str, Any]) -> None:
for session_id in list(self.active_connections.keys()):