diff --git a/entity/configs/node/tooling.py b/entity/configs/node/tooling.py
index 7a836c04..29bd87b4 100755
--- a/entity/configs/node/tooling.py
+++ b/entity/configs/node/tooling.py
@@ -312,6 +312,8 @@ class McpRemoteConfig(BaseConfig):
     server: str
     headers: Dict[str, str] = field(default_factory=dict)
     timeout: float | None = None
+    cache_ttl: float = 0.0
+    tool_sources: List[str] | None = None
 
     FIELD_SPECS = {
         "server": ConfigFieldSpec(
@@ -337,6 +339,22 @@ class McpRemoteConfig(BaseConfig):
             description="Per-request timeout in seconds",
             advance=True,
         ),
+        "cache_ttl": ConfigFieldSpec(
+            name="cache_ttl",
+            display_name="Tool Cache TTL",
+            type_hint="float",
+            required=False,
+            description="Seconds to cache MCP tool list; 0 disables cache for hot updates",
+            advance=True,
+        ),
+        "tool_sources": ConfigFieldSpec(
+            name="tool_sources",
+            display_name="Tool Sources Filter",
+            type_hint="list[str]",
+            required=False,
+            description="Only include MCP tools whose meta.source is in this list; omit to default to ['mcp_tools']",
+            advance=True,
+        ),
     }
 
     @classmethod
@@ -360,7 +378,40 @@ class McpRemoteConfig(BaseConfig):
         else:
             raise ConfigError("timeout must be numeric", extend_path(path, "timeout"))
 
-        return cls(server=server, headers=headers, timeout=timeout, path=path)
+        cache_ttl_value = mapping.get("cache_ttl", 0.0)
+        if cache_ttl_value is None:
+            cache_ttl = 0.0
+        elif isinstance(cache_ttl_value, (int, float)):
+            cache_ttl = float(cache_ttl_value)
+        else:
+            raise ConfigError("cache_ttl must be numeric", extend_path(path, "cache_ttl"))
+
+        tool_sources_raw = mapping.get("tool_sources")
+        tool_sources: List[str] | None = None
+        if tool_sources_raw is not None:
+            entries = ensure_list(tool_sources_raw)
+            normalized: List[str] = []
+            for idx, entry in enumerate(entries):
+                if not isinstance(entry, str):
+                    raise ConfigError(
+                        "tool_sources must be a list of strings",
+                        extend_path(path, f"tool_sources[{idx}]"),
+                    )
+                value = entry.strip()
+                if value:
+                    normalized.append(value)
+            tool_sources = normalized
+        else:
+            tool_sources = ["mcp_tools"]
+
+        return cls(
+            server=server,
+            headers=headers,
+            timeout=timeout,
+            cache_ttl=cache_ttl,
+            tool_sources=tool_sources,
+            path=path,
+        )
 
     def cache_key(self) -> str:
         payload = (
@@ -380,6 +431,7 @@ class McpLocalConfig(BaseConfig):
     inherit_env: bool = True
     startup_timeout: float = 10.0
     wait_for_log: str | None = None
+    cache_ttl: float = 0.0
 
     FIELD_SPECS = {
         "command": ConfigFieldSpec(
@@ -438,6 +490,14 @@ class McpLocalConfig(BaseConfig):
             description="Regex that marks readiness when matched against stdout",
             advance=True,
         ),
+        "cache_ttl": ConfigFieldSpec(
+            name="cache_ttl",
+            display_name="Tool Cache TTL",
+            type_hint="float",
+            required=False,
+            description="Seconds to cache MCP tool list; 0 disables cache for hot updates",
+            advance=True,
+        ),
     }
 
     @classmethod
@@ -474,6 +534,13 @@ class McpLocalConfig(BaseConfig):
             raise ConfigError("startup_timeout must be numeric", extend_path(path, "startup_timeout"))
 
         wait_for_log = optional_str(mapping, "wait_for_log", path)
+        cache_ttl_value = mapping.get("cache_ttl", 0.0)
+        if cache_ttl_value is None:
+            cache_ttl = 0.0
+        elif isinstance(cache_ttl_value, (int, float)):
+            cache_ttl = float(cache_ttl_value)
+        else:
+            raise ConfigError("cache_ttl must be numeric", extend_path(path, "cache_ttl"))
         return cls(
             command=command,
             args=normalized_args,
@@ -482,6 +549,7 @@ class McpLocalConfig(BaseConfig):
             inherit_env=bool(inherit_env),
             startup_timeout=startup_timeout,
             wait_for_log=wait_for_log,
+            cache_ttl=cache_ttl,
             path=path,
         )
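Note: a minimal usage sketch for the two new fields. The exact `from_mapping` call shape and the import path are inferred from the hunk above, and the server URL is a placeholder:

```python
# Hypothetical parse of an MCP remote tool config with the new fields.
from entity.configs.node.tooling import McpRemoteConfig

cfg = McpRemoteConfig.from_mapping(
    {
        "server": "http://localhost:8010/mcp",   # placeholder URL
        "cache_ttl": 0,                          # 0 -> refetch the tool list every time (hot updates)
        "tool_sources": ["mcp_tools"],           # keep only tools tagged meta.source == "mcp_tools"
    },
    ("tools", 0),  # config path for error reporting; shape inferred from extend_path usage
)
assert cfg.cache_ttl == 0.0
assert cfg.tool_sources == ["mcp_tools"]  # also the default when the key is omitted
```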
diff --git a/frontend/src/utils/apiFunctions.js b/frontend/src/utils/apiFunctions.js
index 47a977b9..fe536853 100755
--- a/frontend/src/utils/apiFunctions.js
+++ b/frontend/src/utils/apiFunctions.js
@@ -69,7 +69,7 @@ export async function postYaml(filename, content) {
 export async function updateYaml(filename, content) {
   try {
     const yamlFilename = addYamlSuffix(filename)
-    const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}`), {
+    const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}/update`), {
       method: 'PUT',
       headers: {
         'Content-Type': 'application/json',
@@ -197,7 +197,7 @@ export async function fetchWorkflowsWithDesc() {
     const filesWithDesc = await Promise.all(
       data.workflows.map(async (filename) => {
         try {
-          const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}`))
+          const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}/get`))
           const fileData = await response.json()
           return {
             name: filename,
@@ -234,7 +234,7 @@ export async function fetchWorkflowsWithDesc() {
 // Fetch YAML file content
 export async function fetchWorkflowYAML(filename) {
   try {
-    const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}`))
+    const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}/get`))
     if (!response.ok) {
       throw new Error(`Failed to load YAML file: ${filename}, status: ${response.status}`)
     }
@@ -250,7 +250,7 @@ export async function fetchYaml(filename) {
   try {
     const yamlFilename = addYamlSuffix(filename)
-    const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}`))
+    const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}/get`))
     const data = await response.json().catch(() => ({}))
diff --git a/runtime/node/agent/tool/tool_manager.py b/runtime/node/agent/tool/tool_manager.py
index bf7276a8..864d6c6e 100755
--- a/runtime/node/agent/tool/tool_manager.py
+++ b/runtime/node/agent/tool/tool_manager.py
@@ -9,6 +9,7 @@ import logging
 import mimetypes
 import os
 import threading
+import time
 from pathlib import Path
 from typing import Any, Dict, List, Mapping, Sequence
@@ -35,13 +36,19 @@ class _FunctionManagerCacheEntry:
     auto_loaded: bool = False
 
 
+@dataclass
+class _McpToolCacheEntry:
+    tools: List[Any]
+    fetched_at: float
+
+
 class ToolManager:
     """Manage function tools for agent nodes."""
 
     def __init__(self) -> None:
         self._functions_dir: Path = FUNCTION_CALLING_DIR
         self._function_managers: Dict[Path, _FunctionManagerCacheEntry] = {}
-        self._mcp_tool_cache: Dict[str, List[Any]] = {}
+        self._mcp_tool_cache: Dict[str, _McpToolCacheEntry] = {}
         self._mcp_stdio_clients: Dict[str, "_StdioClientWrapper"] = {}
 
     def _get_function_manager(self) -> FunctionManager:
@@ -192,19 +199,25 @@ class ToolManager:
     def _build_mcp_remote_specs(self, config: McpRemoteConfig) -> List[ToolSpec]:
         cache_key = f"remote:{config.cache_key()}"
-        tools = self._mcp_tool_cache.get(cache_key)
-        if tools is None:
-            tools = asyncio.run(
+        tools = self._get_mcp_tools(
+            cache_key,
+            config.cache_ttl,
+            lambda: asyncio.run(
                 self._fetch_mcp_tools_http(
                     config.server,
                     headers=config.headers,
                     timeout=config.timeout,
                 )
-            )
-            self._mcp_tool_cache[cache_key] = tools
+            ),
+        )
 
         specs: List[ToolSpec] = []
+        allowed_sources = {source for source in (config.tool_sources or []) if source}
         for tool in tools:
+            meta = getattr(tool, "meta", None)
+            source = meta.get("source") if isinstance(meta, Mapping) else None
+            if allowed_sources and source not in allowed_sources:
+                continue
             specs.append(
                 ToolSpec(
                     name=tool.name,
@@ -221,10 +234,11 @@ class ToolManager:
             raise ValueError("MCP local configuration missing launch key")
 
         cache_key = f"stdio:{launch_key}"
-        tools = self._mcp_tool_cache.get(cache_key)
-        if tools is None:
-            tools = asyncio.run(self._fetch_mcp_tools_stdio(config, launch_key))
-            self._mcp_tool_cache[cache_key] = tools
+        tools = self._get_mcp_tools(
+            cache_key,
+            config.cache_ttl,
+            lambda: asyncio.run(self._fetch_mcp_tools_stdio(config, launch_key)),
+        )
 
         specs: List[ToolSpec] = []
         for tool in tools:
@@ -238,28 +252,25 @@ class ToolManager:
         )
         return specs
 
-    def _execute_function_tool(
+    def _get_mcp_tools(
         self,
-        tool_name: str,
-        arguments: Dict[str, Any],
-        config: FunctionToolConfig,
-        tool_context: Dict[str, Any] | None = None,
-    ) -> Any:
-        mgr = self._get_function_manager()
-        if config.auto_load:
-            mgr.load_functions()
-        func = mgr.get_function(tool_name)
-        if func is None:
-            raise ValueError(f"Tool {tool_name} not found in {self._functions_dir}")
+        cache_key: str,
+        cache_ttl: float,
+        fetcher,
+    ) -> List[Any]:
+        entry = self._mcp_tool_cache.get(cache_key)
+        if entry and self._is_cache_fresh(entry.fetched_at, cache_ttl):
+            return entry.tools
+        tools = fetcher()
+        self._mcp_tool_cache[cache_key] = _McpToolCacheEntry(tools=tools, fetched_at=time.time())
+        return tools
+
+    @staticmethod
+    def _is_cache_fresh(fetched_at: float, cache_ttl: float) -> bool:
+        if cache_ttl <= 0:
+            return False
+        return (time.time() - fetched_at) < cache_ttl
 
-        call_args = dict(arguments or {})
-        if (
-            tool_context is not None
-            # and "_context" not in call_args
-            and self._function_accepts_context(func)
-        ):
-            call_args["_context"] = tool_context
-        return func(**call_args)
 
     def _function_accepts_context(self, func: Any) -> bool:
         try:
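The freshness rule these hunks introduce is small enough to restate standalone. This sketch mirrors `_is_cache_fresh` exactly and shows why `cache_ttl: 0` makes hot-updated tools visible on every build:

```python
import time

# Standalone restatement of the TTL check above: ttl <= 0 means "never fresh",
# so every lookup falls through to the fetcher and refetches the tool list.
def is_cache_fresh(fetched_at: float, cache_ttl: float) -> bool:
    if cache_ttl <= 0:
        return False
    return (time.time() - fetched_at) < cache_ttl

now = time.time()
assert is_cache_fresh(now, 30.0) is True        # fetched just now, 30 s TTL
assert is_cache_fresh(now - 31, 30.0) is False  # entry older than the TTL
assert is_cache_fresh(now, 0.0) is False        # ttl 0 disables caching entirely
```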
diff --git a/server/app.py b/server/app.py
index 6d974db4..ca9c7385 100755
--- a/server/app.py
+++ b/server/app.py
@@ -1,6 +1,8 @@
 from fastapi import FastAPI
-
 from server.bootstrap import init_app
+from utils.env_loader import load_dotenv_file
+
+load_dotenv_file()
 
 app = FastAPI(title="DevAll Workflow Server", version="1.0.0")
 init_app(app)
diff --git a/server/bootstrap.py b/server/bootstrap.py
index b0f875e1..25ba6665 100755
--- a/server/bootstrap.py
+++ b/server/bootstrap.py
@@ -1,22 +1,30 @@
 """Application bootstrap helpers for the FastAPI server."""
 
 from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
 
 from server import state
 from server.config_schema_router import router as config_schema_router
 from server.routes import ALL_ROUTERS
 from utils.error_handler import add_exception_handlers
 from utils.middleware import add_middleware
 
 
 def init_app(app: FastAPI) -> None:
     """Apply shared middleware, routers, and global state to ``app``."""
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
     add_exception_handlers(app)
     add_middleware(app)
 
     state.init_state()
 
     for router in ALL_ROUTERS:
         app.include_router(router)
 
     app.include_router(config_schema_router)
diff --git a/server/mcp_server.py b/server/mcp_server.py
new file mode 100644
index 00000000..c7d31fb3
--- /dev/null
+++ b/server/mcp_server.py
@@ -0,0 +1,188 @@
+import importlib.util
+import inspect
+import uuid
+from pathlib import Path
+from typing import Any, Iterable
+
+from fastmcp import FastMCP
+from fastmcp.tools import FunctionTool
+from starlette.requests import Request
+from starlette.responses import JSONResponse
+
+# Repo root (for loading built-in function tools).
+_REPO_ROOT = Path(__file__).resolve().parents[1]
+_FUNCTION_CALLING_DIR = _REPO_ROOT / "functions" / "function_calling"
+
+# MCP server for production use (supports hot-updated tools).
+mcp = FastMCP(
+    "DevAll MCP Server",
+    debug=True,
+)
+
+_DYNAMIC_TOOLS_DIR = Path(__file__).parent / "dynamic_tools"
+_DYNAMIC_TOOLS_DIR.mkdir(parents=True, exist_ok=True)
+
+
+def _safe_tool_filename(filename: str) -> str:
+    name = Path(filename).name
+    if not name.endswith(".py"):
+        raise ValueError("filename must end with .py")
+    return name
+
+
+def _load_module_from_path(path: Path) -> Any:
+    module_name = f"_dynamic_mcp_{path.stem}_{uuid.uuid4().hex}"
+    spec = importlib.util.spec_from_file_location(module_name, path)
+    if spec is None or spec.loader is None:
+        raise ValueError("failed to create module spec")
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+def _select_functions(module: Any, allowlist: Iterable[str] | None) -> list[Any]:
+    allowset = {name.strip() for name in allowlist} if allowlist else None
+    selected = []
+    for name, obj in inspect.getmembers(module, inspect.isfunction):
+        if name.startswith("_"):
+            continue
+        if getattr(obj, "__module__", None) != module.__name__:
+            continue
+        if allowset is not None and name not in allowset:
+            continue
+        selected.append(obj)
+    if allowset:
+        missing = sorted(allowset - {fn.__name__ for fn in selected})
+        if missing:
+            raise ValueError(f"functions not found in module: {', '.join(missing)}")
+    return selected
+
+
+def _register_functions(functions: Iterable[Any], *, replace: bool, source: str) -> list[str]:
+    added: list[str] = []
+    for fn in functions:
+        tool = FunctionTool.from_function(fn, meta={"source": source})
+        if replace:
+            try:
+                mcp.remove_tool(tool.name)
+            except Exception:
+                pass
+        mcp.add_tool(tool)
+        added.append(tool.name)
+    return added
+
+
+def register_tools_from_payload(payload: dict) -> dict:
+    filename = payload.get("filename")
+    content = payload.get("content")
+    function_names = payload.get("functions")
+    replace = payload.get("replace", True)
+
+    if not isinstance(filename, str) or not filename.strip():
+        raise ValueError("filename is required")
+    if not isinstance(content, str) or not content.strip():
+        raise ValueError("content is required")
+    if function_names is not None and not isinstance(function_names, list):
+        raise ValueError("functions must be a list of strings")
+
+    safe_name = _safe_tool_filename(filename.strip())
+
+    file_path = _DYNAMIC_TOOLS_DIR / safe_name
+    file_path.write_text(content, encoding="utf-8")
+
+    module = _load_module_from_path(file_path)
+    functions = _select_functions(module, function_names)
+    if not functions:
+        raise ValueError("no functions found to register")
+
+    # Tag uploads like boot-time loads from the dynamic directory so the
+    # default tool_sources filter (["mcp_tools"]) matches them immediately.
+    added = _register_functions(functions, replace=bool(replace), source="mcp_tools")
+
+    return {"status": "ok", "added": added, "file": str(file_path)}
+
+
+@mcp.custom_route("/admin/tools/upload", methods=["POST"])
+async def upload_tool(request: Request) -> JSONResponse:
+    try:
+        payload = await request.json()
+        result = register_tools_from_payload(payload)
+    except ValueError as exc:
+        return JSONResponse({"error": str(exc)}, status_code=400)
+    except Exception as exc:
+        return JSONResponse({"error": f"failed to load tools: {exc}"}, status_code=400)
+
+    return JSONResponse(result)
+
+
+@mcp.custom_route("/admin/tools/reload", methods=["POST"])
+async def reload_tools(request: Request) -> JSONResponse:
+    replace = True
+    try:
+        payload = await request.json()
+        if isinstance(payload, dict) and "replace" in payload:
+            replace = bool(payload["replace"])
+    except Exception:
+        replace = True
+    result = load_tools_from_directory(replace=replace)
+    return JSONResponse(result)
+
+
+@mcp.custom_route("/admin/tools/list", methods=["GET"])
+async def list_tools(request: Request) -> JSONResponse:
+    tools = await mcp.get_tools()
+    items = []
+    for key in sorted(tools.keys()):
+        tool = tools[key]
+        meta = tool.meta or {}
+        source = meta.get("source", "unknown")
+        call_methods = ["mcp_remote"]
+        if source == "local_tools":
+            call_methods.append("function")
+        payload = {
+            "key": key,
+            "name": tool.name,
+            "description": tool.description,
+            "parameters": tool.parameters,
+            "enabled": tool.enabled,
+            "source": source,
+            "call_methods": call_methods,
+        }
+        if meta:
+            payload["meta"] = meta
+        if tool.output_schema is not None:
+            payload["output_schema"] = tool.output_schema
+        items.append(payload)
+    return JSONResponse({"tools": items})
+
+
+def load_tools_from_directory(*, replace: bool = True) -> dict:
+    added: list[str] = []
+    errors: list[str] = []
+    for directory in (_FUNCTION_CALLING_DIR, _DYNAMIC_TOOLS_DIR):
+        if not directory.exists():
+            continue
+        source = "local_tools" if directory == _FUNCTION_CALLING_DIR else "mcp_tools"
+        for path in sorted(directory.glob("*.py")):
+            if path.name.startswith("_") or path.name == "__init__.py":
+                continue
+            try:
+                module = _load_module_from_path(path)
+                functions = _select_functions(module, None)
+                if not functions:
+                    continue
+                added.extend(_register_functions(functions, replace=replace, source=source))
+            except Exception as exc:
+                errors.append(f"{path.name}: {exc}")
+    return {"added": added, "errors": errors}
+
+
+_bootstrap = load_tools_from_directory(replace=True)
+if _bootstrap["errors"]:
+    print(f"Dynamic tool load errors: {_bootstrap['errors']}")
+
+
+if __name__ == "__main__":
+    print("Starting DevAll MCP server...")
+    print("Run standalone with: fastmcp run server/mcp_server.py --transport streamable-http --port 8010")
+    mcp.run()
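A client-side sketch against the admin routes above. The host and port are an assumption taken from the file's own run hint; the route paths and payload keys come straight from `register_tools_from_payload`:

```python
# Hot-register a tool, then confirm it shows up in the listing.
import requests

TOOL_SOURCE = '''
def add_numbers(a: float, b: float) -> float:
    """Add two numbers."""
    return a + b
'''

resp = requests.post(
    "http://localhost:8010/admin/tools/upload",  # assumed host/port
    json={
        "filename": "math_tools.py",
        "content": TOOL_SOURCE,
        "functions": ["add_numbers"],  # optional allowlist; omit to take all public functions
        "replace": True,
    },
    timeout=10,
)
print(resp.json())  # e.g. {"status": "ok", "added": ["add_numbers"], "file": "..."}

# The new tool should now appear with its source tag.
listing = requests.get("http://localhost:8010/admin/tools/list", timeout=10).json()
print([(t["name"], t["source"]) for t in listing["tools"]])
```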
diff --git a/server/models.py b/server/models.py
index 273d1cec..2ac0de44 100755
--- a/server/models.py
+++ b/server/models.py
@@ -1,6 +1,6 @@
 """Pydantic models shared across server routes."""
 
-from typing import List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional
 
 from pydantic import BaseModel, constr
 
@@ -13,6 +13,15 @@ class WorkflowRequest(BaseModel):
     log_level: Literal["INFO", "DEBUG"] = "INFO"
 
 
+class WorkflowRunRequest(BaseModel):
+    yaml_file: str
+    task_prompt: str
+    attachments: Optional[List[str]] = None
+    session_name: Optional[str] = None
+    variables: Optional[Dict[str, Any]] = None
+    log_level: Optional[Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]] = None
+
+
 class WorkflowUploadContentRequest(BaseModel):
     filename: str
     content: str
diff --git a/server/routes/__init__.py b/server/routes/__init__.py
index 32706ea7..c79bb05b 100755
--- a/server/routes/__init__.py
+++ b/server/routes/__init__.py
@@ -1,6 +1,6 @@
 """Aggregates API routers."""
 
-from . import artifacts, batch, execute, health, sessions, uploads, vuegraphs, workflows, websocket
+from . import artifacts, batch, execute, execute_sync, health, sessions, uploads, vuegraphs, workflows, websocket
 
 ALL_ROUTERS = [
     health.router,
@@ -11,7 +11,8 @@ ALL_ROUTERS = [
     sessions.router,
     batch.router,
     execute.router,
+    execute_sync.router,
     websocket.router,
 ]
 
 __all__ = ["ALL_ROUTERS"]
diff --git a/server/routes/execute_sync.py b/server/routes/execute_sync.py
new file mode 100644
index 00000000..8426a1a8
--- /dev/null
+++ b/server/routes/execute_sync.py
@@ -0,0 +1,255 @@
+from __future__ import annotations
+
+import json
+import queue
+import threading
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Optional, Sequence, Union
+
+from fastapi import APIRouter, HTTPException, Request
+from fastapi.responses import StreamingResponse
+from starlette.concurrency import run_in_threadpool
+
+from check.check import load_config
+from entity.enums import LogLevel
+from entity.graph_config import GraphConfig
+from entity.messages import Message
+from runtime.bootstrap.schema import ensure_schema_registry_populated
+from runtime.sdk import OUTPUT_ROOT, run_workflow
+from server.models import WorkflowRunRequest
+from server.settings import YAML_DIR
+from utils.attachments import AttachmentStore
+from utils.exceptions import ValidationError, WorkflowExecutionError
+from utils.logger import WorkflowLogger
+from utils.structured_logger import get_server_logger, LogType
+from utils.task_input import TaskInputBuilder
+from workflow.graph import GraphExecutor
+from workflow.graph_context import GraphContext
+
+router = APIRouter()
+
+_SSE_CONTENT_TYPE = "text/event-stream"
+
+
+def _normalize_session_name(yaml_path: Path, session_name: Optional[str]) -> str:
+    if session_name and session_name.strip():
+        return session_name.strip()
+    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+    return f"sdk_{yaml_path.stem}_{timestamp}"
+
+
+def _resolve_yaml_path(yaml_file: Union[str, Path]) -> Path:
+    candidate = Path(yaml_file).expanduser()
+    if candidate.is_absolute():
+        return candidate
+    if candidate.exists():
+        return candidate
+    repo_root = Path(__file__).resolve().parents[2]
+    yaml_root = YAML_DIR if YAML_DIR.is_absolute() else (repo_root / YAML_DIR)
+    return (yaml_root / candidate).expanduser()
+
+
+def _build_task_input(
+    graph_context: GraphContext,
+    prompt: str,
+    attachments: Sequence[Union[str, Path]],
+) -> Union[str, list[Message]]:
+    if not attachments:
+        return prompt
+
+    attachments_dir = graph_context.directory / "code_workspace" / "attachments"
+    attachments_dir.mkdir(parents=True, exist_ok=True)
+    store = AttachmentStore(attachments_dir)
+    builder = TaskInputBuilder(store)
+    normalized_paths = [str(Path(path).expanduser()) for path in attachments]
+    return builder.build_from_file_paths(prompt, normalized_paths)
+
+
+def _run_workflow_with_logger(
+    *,
+    yaml_file: Union[str, Path],
+    task_prompt: str,
+    attachments: Optional[Sequence[Union[str, Path]]],
+    session_name: Optional[str],
+    variables: Optional[dict],
+    log_level: Optional[LogLevel],
+    log_callback,
+) -> tuple[Optional[Message], dict[str, Any]]:
+    ensure_schema_registry_populated()
+
+    yaml_path = _resolve_yaml_path(yaml_file)
+    if not yaml_path.exists():
+        raise FileNotFoundError(f"YAML file not found: {yaml_path}")
+
+    attachments = attachments or []
+    if (not task_prompt or not task_prompt.strip()) and not attachments:
+        raise ValidationError(
+            "Task prompt cannot be empty",
+            details={"task_prompt_provided": bool(task_prompt)},
+        )
+
+    design = load_config(yaml_path, vars_override=variables)
+    normalized_session = _normalize_session_name(yaml_path, session_name)
+
+    graph_config = GraphConfig.from_definition(
+        design.graph,
+        name=normalized_session,
+        output_root=OUTPUT_ROOT,
+        source_path=str(yaml_path),
+        vars=design.vars,
+    )
+
+    if log_level:
+        graph_config.log_level = log_level
+        graph_config.definition.log_level = log_level
+
+    graph_context = GraphContext(config=graph_config)
+    task_input = _build_task_input(graph_context, task_prompt, attachments)
+
+    class _StreamingWorkflowLogger(WorkflowLogger):
+        def add_log(self, *args, **kwargs):
+            entry = super().add_log(*args, **kwargs)
+            if entry:
+                payload = entry.to_dict()
+                payload.pop("details", None)
+                log_callback("log", payload)
+            return entry
+
+    class _StreamingExecutor(GraphExecutor):
+        def _create_logger(self) -> WorkflowLogger:
+            level = log_level or self.graph.log_level
+            return _StreamingWorkflowLogger(
+                self.graph.name,
+                level,
+                use_structured_logging=True,
+                log_to_console=False,
+            )
+
+    executor = _StreamingExecutor(graph_context, session_id=normalized_session)
+    executor._execute(task_input)
+    final_message = executor.get_final_output_message()
+
+    logger = executor.log_manager.get_logger() if executor.log_manager else None
+    log_id = logger.workflow_id if logger else None
+    token_usage = executor.token_tracker.get_token_usage() if executor.token_tracker else None
+
+    meta = {
+        "session_name": normalized_session,
+        "yaml_file": str(yaml_path),
+        "log_id": log_id,
+        "token_usage": token_usage,
+        "output_dir": graph_context.directory,
+    }
+    return final_message, meta
+
+
+def _sse_event(event_type: str, data: Any) -> str:
+    payload = json.dumps(data, ensure_ascii=False, default=str)
+    return f"event: {event_type}\ndata: {payload}\n\n"
+
+
+@router.post("/api/workflow/run")
+async def run_workflow_sync(request: WorkflowRunRequest, http_request: Request):
+    try:
+        resolved_log_level: Optional[LogLevel] = None
+        if request.log_level:
+            resolved_log_level = LogLevel(request.log_level)
+    except ValueError:
+        raise HTTPException(
+            status_code=400,
+            detail="log_level must be one of DEBUG, INFO, WARNING, ERROR, CRITICAL",
+        )
+
+    accepts_stream = _SSE_CONTENT_TYPE in (http_request.headers.get("accept") or "")
+    if not accepts_stream:
+        try:
+            result = await run_in_threadpool(
+                run_workflow,
+                request.yaml_file,
+                task_prompt=request.task_prompt,
+                attachments=request.attachments,
+                session_name=request.session_name,
+                variables=request.variables,
+                log_level=resolved_log_level,
+            )
+        except FileNotFoundError as exc:
+            raise HTTPException(status_code=404, detail=str(exc))
+        except ValidationError as exc:
+            raise HTTPException(status_code=400, detail=str(exc))
+        except Exception as exc:
+            logger = get_server_logger()
+            logger.log_exception(exc, "Failed to run workflow via sync API")
+            raise WorkflowExecutionError(f"Failed to run workflow: {exc}")
+
+        final_message = result.final_message.text_content() if result.final_message else ""
+        meta = result.meta_info
+
+        logger = get_server_logger()
+        logger.info(
+            "Workflow execution completed via sync API",
+            log_type=LogType.WORKFLOW,
+            session_id=meta.session_name,
+            yaml_path=meta.yaml_file,
+        )
+
+        return {
+            "status": "completed",
+            "final_message": final_message,
+            "token_usage": meta.token_usage,
+            "output_dir": str(meta.output_dir.resolve()),
+        }
+
+    event_queue: queue.Queue[tuple[str, Any]] = queue.Queue()
+    done_event = threading.Event()
+
+    def enqueue(event_type: str, data: Any) -> None:
+        event_queue.put((event_type, data))
+
+    def worker() -> None:
+        try:
+            enqueue(
+                "started",
+                {"yaml_file": request.yaml_file, "task_prompt": request.task_prompt},
+            )
+            final_message, meta = _run_workflow_with_logger(
+                yaml_file=request.yaml_file,
+                task_prompt=request.task_prompt,
+                attachments=request.attachments,
+                session_name=request.session_name,
+                variables=request.variables,
+                log_level=resolved_log_level,
+                log_callback=enqueue,
+            )
+            enqueue(
+                "completed",
+                {
+                    "status": "completed",
+                    "final_message": final_message.text_content() if final_message else "",
+                    "token_usage": meta["token_usage"],
+                    "output_dir": str(meta["output_dir"].resolve()),
+                },
+            )
+        except (FileNotFoundError, ValidationError) as exc:
+            enqueue("error", {"message": str(exc)})
+        except Exception as exc:
+            logger = get_server_logger()
+            logger.log_exception(exc, "Failed to run workflow via streaming API")
+            enqueue("error", {"message": f"Failed to run workflow: {exc}"})
+        finally:
+            done_event.set()
+
+    threading.Thread(target=worker, daemon=True).start()
+
+    async def stream():
+        while True:
+            try:
+                # Poll the queue off the event loop so the blocking get() cannot
+                # stall other requests while the worker thread is busy.
+                event_type, data = await run_in_threadpool(event_queue.get, True, 0.1)
+                yield _sse_event(event_type, data)
+            except queue.Empty:
+                if done_event.is_set():
+                    break
+
+    return StreamingResponse(stream(), media_type=_SSE_CONTENT_TYPE)
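A sketch of consuming the streaming mode of this route. The base URL is an assumption; sending `Accept: text/event-stream` is what flips the handler into SSE mode, per the `accepts_stream` check above. Without that header, the same endpoint returns a single JSON body:

```python
# Consume /api/workflow/run as a server-sent-event stream with httpx.
import json
import httpx

payload = {"yaml_file": "demo.yaml", "task_prompt": "Summarize the attached report"}

with httpx.stream(
    "POST",
    "http://localhost:8000/api/workflow/run",  # assumed base URL
    json=payload,
    headers={"Accept": "text/event-stream"},
    timeout=None,  # workflows can run for a long time
) as response:
    event_type = None
    for line in response.iter_lines():
        if line.startswith("event: "):
            event_type = line[len("event: "):]
        elif line.startswith("data: "):
            data = json.loads(line[len("data: "):])
            print(event_type, data)
            if event_type in ("completed", "error"):
                break
```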
diff --git a/server/routes/workflows.py b/server/routes/workflows.py
index 597ea5ef..f0ff4dd1 100755
--- a/server/routes/workflows.py
+++ b/server/routes/workflows.py
@@ -1,4 +1,6 @@
+from typing import Any
+
 from fastapi import APIRouter, HTTPException
 
 from server.models import (
     WorkflowCopyRequest,
@@ -67,6 +68,89 @@ async def list_workflows():
     return {"workflows": [file.name for file in YAML_DIR.glob("*.yaml")]}
 
 
+@router.get("/api/workflows/{filename}/args")
+async def get_workflow_args(filename: str):
+    try:
+        safe_filename = validate_workflow_filename(filename, require_yaml_extension=True)
+        file_path = YAML_DIR / safe_filename
+
+        if not file_path.exists() or not file_path.is_file():
+            raise ResourceNotFoundError(
+                "Workflow file not found",
+                resource_type="workflow",
+                resource_id=safe_filename,
+            )
+
+        # Load and validate YAML content
+        raw_content = file_path.read_text(encoding="utf-8")
+        _, yaml_content = validate_workflow_content(safe_filename, raw_content)
+
+        args: list[dict[str, Any]] = []
+        if isinstance(yaml_content, dict):
+            graph = yaml_content.get("graph") or {}
+            if isinstance(graph, dict):
+                raw_args = graph.get("args") or []
+                if isinstance(raw_args, list):
+                    if len(raw_args) == 0:
+                        raise ResourceNotFoundError(
+                            "Workflow file does not declare any args",
+                            resource_type="workflow",
+                            resource_id=safe_filename,
+                        )
+                    for item in raw_args:
+                        # Each item is expected to look like: { arg_name: [ {key: value}, ... ] }
+                        if not isinstance(item, dict) or len(item) != 1:
+                            continue
+                        (arg_name, spec_list), = item.items()
+                        if not isinstance(arg_name, str):
+                            continue
+
+                        arg_info: dict[str, Any] = {"name": arg_name}
+                        if isinstance(spec_list, list):
+                            for spec in spec_list:
+                                if isinstance(spec, dict):
+                                    for key, value in spec.items():
+                                        # Later entries override earlier ones if duplicated
+                                        arg_info[str(key)] = value
+                        args.append(arg_info)
+
+        logger = get_server_logger()
+        logger.info(
+            "Workflow args retrieved",
+            log_type=LogType.WORKFLOW,
+            filename=safe_filename,
+            args_count=len(args),
+        )
+
+        return {"args": args}
+    except ValidationError as exc:
+        # Validation errors (e.g. bad arguments or filename)
+        raise HTTPException(
+            status_code=400,
+            detail={"message": str(exc)},
+        )
+    except SecurityError as exc:
+        # Security errors (e.g. path traversal)
+        raise HTTPException(
+            status_code=400,
+            detail={"message": str(exc)},
+        )
+    except ResourceNotFoundError as exc:
+        # File does not exist
+        raise HTTPException(
+            status_code=404,
+            detail={"message": str(exc)},
+        )
+    except Exception as exc:
+        logger = get_server_logger()
+        logger.log_exception(exc, f"Unexpected error retrieving workflow args: {filename}")
+        # Catch-all error
+        raise HTTPException(
+            status_code=500,
+            detail={"message": f"Failed to retrieve workflow args: {exc}"},
+        )
+
+
 @router.post("/api/workflows/upload/content")
 async def upload_workflow_content(request: WorkflowUploadContentRequest):
     return _persist_workflow_from_content(
@@ -78,7 +164,7 @@ async def upload_workflow_content(request: WorkflowUploadContentRequest):
     )
 
 
-@router.put("/api/workflows/{filename}")
+@router.put("/api/workflows/{filename}/update")
 async def update_workflow_content(filename: str, request: WorkflowUpdateContentRequest):
     return _persist_workflow_from_content(
         filename,
@@ -89,7 +175,7 @@ async def update_workflow_content(filename: str, request: WorkflowUpdateContentRequest):
     )
 
 
-@router.delete("/api/workflows/{filename}")
+@router.delete("/api/workflows/{filename}/delete")
 async def delete_workflow(filename: str):
     try:
         safe_filename = validate_workflow_filename(filename, require_yaml_extension=True)
@@ -180,7 +266,7 @@ async def copy_workflow_file(filename: str, request: WorkflowCopyRequest):
     raise WorkflowExecutionError(f"Failed to copy workflow: {exc}")
 
 
-@router.get("/api/workflows/{filename}")
+@router.get("/api/workflows/{filename}/get")
 async def get_workflow_raw_content(filename: str):
     try:
         safe_filename = validate_workflow_filename(filename, require_yaml_extension=True)
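A worked example of the `graph.args` shape this endpoint parses and the response it yields (hypothetical arg names; the transform mirrors the loop in the hunk above):

```python
# Pure-data illustration of the args transform. Input: graph.args as parsed
# from the workflow YAML -- a list of single-key dicts mapping an arg name to
# a list of {key: value} spec entries.
yaml_graph_args = [
    {"topic": [{"type": "string"}, {"description": "Subject to research"}]},
    {"max_rounds": [{"type": "int"}, {"default": 3}]},
]

# GET /api/workflows/<file>.yaml/args would then return each arg flattened
# into one dict, with later spec entries overriding earlier duplicate keys:
expected = {
    "args": [
        {"name": "topic", "type": "string", "description": "Subject to research"},
        {"name": "max_rounds", "type": "int", "default": 3},
    ]
}
```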
diff --git a/server/services/websocket_manager.py b/server/services/websocket_manager.py
index 78cc5142..049030e6 100755
--- a/server/services/websocket_manager.py
+++ b/server/services/websocket_manager.py
@@ -49,8 +49,6 @@ class WebSocketManager:
     ):
         self.active_connections: Dict[str, WebSocket] = {}
         self.connection_timestamps: Dict[str, float] = {}
-        self.send_locks: Dict[str, asyncio.Lock] = {}
-        self.loop: asyncio.AbstractEventLoop | None = None
         self.session_store = session_store or WorkflowSessionStore()
         self.session_controller = session_controller or SessionExecutionController(self.session_store)
         self.attachment_service = attachment_service or AttachmentService()
@@ -67,16 +65,10 @@ class WebSocketManager:
     async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) -> str:
         await websocket.accept()
-        if self.loop is None:
-            try:
-                self.loop = asyncio.get_running_loop()
-            except RuntimeError:
-                self.loop = None
         if not session_id:
             session_id = str(uuid.uuid4())
         self.active_connections[session_id] = websocket
         self.connection_timestamps[session_id] = time.time()
-        self.send_locks[session_id] = asyncio.Lock()
         logging.info("WebSocket connected: %s", session_id)
         await self.send_message(
             session_id,
@@ -98,8 +90,6 @@ class WebSocketManager:
             del self.active_connections[session_id]
         if session_id in self.connection_timestamps:
             del self.connection_timestamps[session_id]
-        if session_id in self.send_locks:
-            del self.send_locks[session_id]
         self.session_controller.cleanup_session(session_id)
         remaining_session = self.session_store.get_session(session_id)
         if remaining_session and remaining_session.executor is None:
@@ -111,12 +101,7 @@ class WebSocketManager:
         if session_id in self.active_connections:
             websocket = self.active_connections[session_id]
             try:
-                lock = self.send_locks.get(session_id)
-                if lock is None:
-                    await websocket.send_text(_encode_ws_message(message))
-                else:
-                    async with lock:
-                        await websocket.send_text(_encode_ws_message(message))
+                await websocket.send_text(_encode_ws_message(message))
             except Exception as exc:
                 traceback.print_exc()
                 logging.error("Failed to send message to %s: %s", session_id, exc)
@@ -130,13 +115,7 @@ class WebSocketManager:
             else:
                 asyncio.run(self.send_message(session_id, message))
         except RuntimeError:
-            if self.loop and self.loop.is_running():
-                asyncio.run_coroutine_threadsafe(
-                    self.send_message(session_id, message),
-                    self.loop,
-                )
-            else:
-                asyncio.run(self.send_message(session_id, message))
+            asyncio.run(self.send_message(session_id, message))
 
     async def broadcast(self, message: Dict[str, Any]) -> None:
         for session_id in list(self.active_connections.keys()):