mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2026-04-25 19:28:09 +00:00
[feat] add server endpoint
This commit is contained in:
parent
c7bb2df9e5
commit
9381abd96f
@ -312,6 +312,8 @@ class McpRemoteConfig(BaseConfig):
|
||||
server: str
|
||||
headers: Dict[str, str] = field(default_factory=dict)
|
||||
timeout: float | None = None
|
||||
cache_ttl: float = 0.0
|
||||
tool_sources: List[str] | None = None
|
||||
|
||||
FIELD_SPECS = {
|
||||
"server": ConfigFieldSpec(
|
||||
@ -337,6 +339,22 @@ class McpRemoteConfig(BaseConfig):
|
||||
description="Per-request timeout in seconds",
|
||||
advance=True,
|
||||
),
|
||||
"cache_ttl": ConfigFieldSpec(
|
||||
name="cache_ttl",
|
||||
display_name="Tool Cache TTL",
|
||||
type_hint="float",
|
||||
required=False,
|
||||
description="Seconds to cache MCP tool list; 0 disables cache for hot updates",
|
||||
advance=True,
|
||||
),
|
||||
"tool_sources": ConfigFieldSpec(
|
||||
name="tool_sources",
|
||||
display_name="Tool Sources Filter",
|
||||
type_hint="list[str]",
|
||||
required=False,
|
||||
description="Only include MCP tools whose meta.source is in this list; omit to default to ['mcp_tools'].",
|
||||
advance=True,
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -360,7 +378,40 @@ class McpRemoteConfig(BaseConfig):
|
||||
else:
|
||||
raise ConfigError("timeout must be numeric", extend_path(path, "timeout"))
|
||||
|
||||
return cls(server=server, headers=headers, timeout=timeout, path=path)
|
||||
cache_ttl_value = mapping.get("cache_ttl", 0.0)
|
||||
if cache_ttl_value is None:
|
||||
cache_ttl = 0.0
|
||||
elif isinstance(cache_ttl_value, (int, float)):
|
||||
cache_ttl = float(cache_ttl_value)
|
||||
else:
|
||||
raise ConfigError("cache_ttl must be numeric", extend_path(path, "cache_ttl"))
|
||||
|
||||
tool_sources_raw = mapping.get("tool_sources")
|
||||
tool_sources: List[str] | None = None
|
||||
if tool_sources_raw is not None:
|
||||
entries = ensure_list(tool_sources_raw)
|
||||
normalized: List[str] = []
|
||||
for idx, entry in enumerate(entries):
|
||||
if not isinstance(entry, str):
|
||||
raise ConfigError(
|
||||
"tool_sources must be a list of strings",
|
||||
extend_path(path, f"tool_sources[{idx}]"),
|
||||
)
|
||||
value = entry.strip()
|
||||
if value:
|
||||
normalized.append(value)
|
||||
tool_sources = normalized
|
||||
else:
|
||||
tool_sources = ["mcp_tools"]
|
||||
|
||||
return cls(
|
||||
server=server,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
cache_ttl=cache_ttl,
|
||||
tool_sources=tool_sources,
|
||||
path=path,
|
||||
)
|
||||
|
||||
def cache_key(self) -> str:
|
||||
payload = (
|
||||
@ -380,6 +431,7 @@ class McpLocalConfig(BaseConfig):
|
||||
inherit_env: bool = True
|
||||
startup_timeout: float = 10.0
|
||||
wait_for_log: str | None = None
|
||||
cache_ttl: float = 0.0
|
||||
|
||||
FIELD_SPECS = {
|
||||
"command": ConfigFieldSpec(
|
||||
@ -438,6 +490,14 @@ class McpLocalConfig(BaseConfig):
|
||||
description="Regex that marks readiness when matched against stdout",
|
||||
advance=True,
|
||||
),
|
||||
"cache_ttl": ConfigFieldSpec(
|
||||
name="cache_ttl",
|
||||
display_name="Tool Cache TTL",
|
||||
type_hint="float",
|
||||
required=False,
|
||||
description="Seconds to cache MCP tool list; 0 disables cache for hot updates",
|
||||
advance=True,
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -474,6 +534,13 @@ class McpLocalConfig(BaseConfig):
|
||||
raise ConfigError("startup_timeout must be numeric", extend_path(path, "startup_timeout"))
|
||||
|
||||
wait_for_log = optional_str(mapping, "wait_for_log", path)
|
||||
cache_ttl_value = mapping.get("cache_ttl", 0.0)
|
||||
if cache_ttl_value is None:
|
||||
cache_ttl = 0.0
|
||||
elif isinstance(cache_ttl_value, (int, float)):
|
||||
cache_ttl = float(cache_ttl_value)
|
||||
else:
|
||||
raise ConfigError("cache_ttl must be numeric", extend_path(path, "cache_ttl"))
|
||||
return cls(
|
||||
command=command,
|
||||
args=normalized_args,
|
||||
@ -482,6 +549,7 @@ class McpLocalConfig(BaseConfig):
|
||||
inherit_env=bool(inherit_env),
|
||||
startup_timeout=startup_timeout,
|
||||
wait_for_log=wait_for_log,
|
||||
cache_ttl=cache_ttl,
|
||||
path=path,
|
||||
)
|
||||
|
||||
|
||||
@ -69,7 +69,7 @@ export async function postYaml(filename, content) {
|
||||
export async function updateYaml(filename, content) {
|
||||
try {
|
||||
const yamlFilename = addYamlSuffix(filename)
|
||||
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}`), {
|
||||
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}/update`), {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@ -197,7 +197,7 @@ export async function fetchWorkflowsWithDesc() {
|
||||
const filesWithDesc = await Promise.all(
|
||||
data.workflows.map(async (filename) => {
|
||||
try {
|
||||
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}`))
|
||||
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}/get`))
|
||||
const fileData = await response.json()
|
||||
return {
|
||||
name: filename,
|
||||
@ -234,7 +234,7 @@ export async function fetchWorkflowsWithDesc() {
|
||||
// Fetch YAML file content
|
||||
export async function fetchWorkflowYAML(filename) {
|
||||
try {
|
||||
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}`))
|
||||
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(filename)}/get`))
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to load YAML file: ${filename}, status: ${response.status}`)
|
||||
}
|
||||
@ -250,7 +250,7 @@ export async function fetchWorkflowYAML(filename) {
|
||||
export async function fetchYaml(filename) {
|
||||
try {
|
||||
const yamlFilename = addYamlSuffix(filename)
|
||||
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}`))
|
||||
const response = await fetch(apiUrl(`/api/workflows/${encodeURIComponent(yamlFilename)}/get`))
|
||||
|
||||
const data = await response.json().catch(() => ({}))
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Mapping, Sequence
|
||||
|
||||
@ -35,13 +36,19 @@ class _FunctionManagerCacheEntry:
|
||||
auto_loaded: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class _McpToolCacheEntry:
|
||||
tools: List[Any]
|
||||
fetched_at: float
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""Manage function tools for agent nodes."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._functions_dir: Path = FUNCTION_CALLING_DIR
|
||||
self._function_managers: Dict[Path, _FunctionManagerCacheEntry] = {}
|
||||
self._mcp_tool_cache: Dict[str, List[Any]] = {}
|
||||
self._mcp_tool_cache: Dict[str, _McpToolCacheEntry] = {}
|
||||
self._mcp_stdio_clients: Dict[str, "_StdioClientWrapper"] = {}
|
||||
|
||||
def _get_function_manager(self) -> FunctionManager:
|
||||
@ -192,19 +199,25 @@ class ToolManager:
|
||||
|
||||
def _build_mcp_remote_specs(self, config: McpRemoteConfig) -> List[ToolSpec]:
|
||||
cache_key = f"remote:{config.cache_key()}"
|
||||
tools = self._mcp_tool_cache.get(cache_key)
|
||||
if tools is None:
|
||||
tools = asyncio.run(
|
||||
tools = self._get_mcp_tools(
|
||||
cache_key,
|
||||
config.cache_ttl,
|
||||
lambda: asyncio.run(
|
||||
self._fetch_mcp_tools_http(
|
||||
config.server,
|
||||
headers=config.headers,
|
||||
timeout=config.timeout,
|
||||
)
|
||||
)
|
||||
self._mcp_tool_cache[cache_key] = tools
|
||||
),
|
||||
)
|
||||
|
||||
specs: List[ToolSpec] = []
|
||||
allowed_sources = {source for source in (config.tool_sources or []) if source}
|
||||
for tool in tools:
|
||||
meta = getattr(tool, "meta", None)
|
||||
source = meta.get("source") if isinstance(meta, Mapping) else None
|
||||
if allowed_sources and source not in allowed_sources:
|
||||
continue
|
||||
specs.append(
|
||||
ToolSpec(
|
||||
name=tool.name,
|
||||
@ -221,10 +234,11 @@ class ToolManager:
|
||||
raise ValueError("MCP local configuration missing launch key")
|
||||
|
||||
cache_key = f"stdio:{launch_key}"
|
||||
tools = self._mcp_tool_cache.get(cache_key)
|
||||
if tools is None:
|
||||
tools = asyncio.run(self._fetch_mcp_tools_stdio(config, launch_key))
|
||||
self._mcp_tool_cache[cache_key] = tools
|
||||
tools = self._get_mcp_tools(
|
||||
cache_key,
|
||||
config.cache_ttl,
|
||||
lambda: asyncio.run(self._fetch_mcp_tools_stdio(config, launch_key)),
|
||||
)
|
||||
|
||||
specs: List[ToolSpec] = []
|
||||
for tool in tools:
|
||||
@ -238,28 +252,25 @@ class ToolManager:
|
||||
)
|
||||
return specs
|
||||
|
||||
def _execute_function_tool(
|
||||
def _get_mcp_tools(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
config: FunctionToolConfig,
|
||||
tool_context: Dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
mgr = self._get_function_manager()
|
||||
if config.auto_load:
|
||||
mgr.load_functions()
|
||||
func = mgr.get_function(tool_name)
|
||||
if func is None:
|
||||
raise ValueError(f"Tool {tool_name} not found in {self._functions_dir}")
|
||||
cache_key: str,
|
||||
cache_ttl: float,
|
||||
fetcher,
|
||||
) -> List[Any]:
|
||||
entry = self._mcp_tool_cache.get(cache_key)
|
||||
if entry and self._is_cache_fresh(entry.fetched_at, cache_ttl):
|
||||
return entry.tools
|
||||
tools = fetcher()
|
||||
self._mcp_tool_cache[cache_key] = _McpToolCacheEntry(tools=tools, fetched_at=time.time())
|
||||
return tools
|
||||
|
||||
@staticmethod
|
||||
def _is_cache_fresh(fetched_at: float, cache_ttl: float) -> bool:
|
||||
if cache_ttl <= 0:
|
||||
return False
|
||||
return (time.time() - fetched_at) < cache_ttl
|
||||
|
||||
call_args = dict(arguments or {})
|
||||
if (
|
||||
tool_context is not None
|
||||
# and "_context" not in call_args
|
||||
and self._function_accepts_context(func)
|
||||
):
|
||||
call_args["_context"] = tool_context
|
||||
return func(**call_args)
|
||||
|
||||
def _function_accepts_context(self, func: Any) -> bool:
|
||||
try:
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from server.bootstrap import init_app
|
||||
from utils.env_loader import load_dotenv_file
|
||||
|
||||
load_dotenv_file()
|
||||
|
||||
app = FastAPI(title="DevAll Workflow Server", version="1.0.0")
|
||||
init_app(app)
|
||||
|
||||
@ -1,23 +1,18 @@
|
||||
"""Application bootstrap helpers for the FastAPI server."""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from server import state
|
||||
from server.config_schema_router import router as config_schema_router
|
||||
from server.routes import ALL_ROUTERS
|
||||
from utils.error_handler import add_exception_handlers
|
||||
from utils.middleware import add_middleware
|
||||
|
||||
|
||||
def init_app(app: FastAPI) -> None:
|
||||
"""Apply shared middleware, routers, and global state to ``app``."""
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
add_exception_handlers(app)
|
||||
add_middleware(app)
|
||||
|
||||
state.init_state()
|
||||
|
||||
for router in ALL_ROUTERS:
|
||||
app.include_router(router)
|
||||
|
||||
app.include_router(config_schema_router)
|
||||
add_middleware(app)
|
||||
189
server/mcp_server.py
Normal file
189
server/mcp_server.py
Normal file
@ -0,0 +1,189 @@
|
||||
import importlib.util
|
||||
import inspect
|
||||
import random
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.tools import FunctionTool
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
# Repo root (for loading built-in function tools).
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
_FUNCTION_CALLING_DIR = _REPO_ROOT / "functions" / "function_calling"
|
||||
|
||||
# MCP server for production use (supports hot-updated tools).
|
||||
mcp = FastMCP(
|
||||
"DevAll MCP Server",
|
||||
debug=True,
|
||||
)
|
||||
|
||||
_DYNAMIC_TOOLS_DIR = Path(__file__).parent / "dynamic_tools"
|
||||
_DYNAMIC_TOOLS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
|
||||
|
||||
def _safe_tool_filename(filename: str) -> str:
|
||||
name = Path(filename).name
|
||||
if not name.endswith(".py"):
|
||||
raise ValueError("filename must end with .py")
|
||||
return name
|
||||
|
||||
|
||||
def _load_module_from_path(path: Path) -> Any:
|
||||
module_name = f"_dynamic_mcp_{path.stem}_{uuid.uuid4().hex}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ValueError("failed to create module spec")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _select_functions(module: Any, allowlist: Iterable[str] | None) -> list[Any]:
|
||||
allowset = {name.strip() for name in allowlist} if allowlist else None
|
||||
selected = []
|
||||
for name, obj in inspect.getmembers(module, inspect.isfunction):
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
if getattr(obj, "__module__", None) != module.__name__:
|
||||
continue
|
||||
if allowset is not None and name not in allowset:
|
||||
continue
|
||||
selected.append(obj)
|
||||
if allowset:
|
||||
missing = sorted(allowset - {fn.__name__ for fn in selected})
|
||||
if missing:
|
||||
raise ValueError(f"functions not found in module: {', '.join(missing)}")
|
||||
return selected
|
||||
|
||||
|
||||
def _register_functions(functions: Iterable[Any], *, replace: bool, source: str) -> list[str]:
|
||||
added: list[str] = []
|
||||
for fn in functions:
|
||||
tool = FunctionTool.from_function(fn, meta={"source": source})
|
||||
if replace:
|
||||
try:
|
||||
mcp.remove_tool(tool.name)
|
||||
except Exception:
|
||||
pass
|
||||
mcp.add_tool(tool)
|
||||
added.append(tool.name)
|
||||
return added
|
||||
|
||||
|
||||
def register_tools_from_payload(payload: dict) -> dict:
|
||||
filename = payload.get("filename")
|
||||
content = payload.get("content")
|
||||
function_names = payload.get("functions")
|
||||
replace = payload.get("replace", True)
|
||||
|
||||
if not isinstance(filename, str) or not filename.strip():
|
||||
raise ValueError("filename is required")
|
||||
if not isinstance(content, str) or not content.strip():
|
||||
raise ValueError("content is required")
|
||||
if function_names is not None and not isinstance(function_names, list):
|
||||
raise ValueError("functions must be a list of strings")
|
||||
|
||||
safe_name = _safe_tool_filename(filename.strip())
|
||||
|
||||
file_path = _DYNAMIC_TOOLS_DIR / safe_name
|
||||
file_path.write_text(content, encoding="utf-8")
|
||||
|
||||
module = _load_module_from_path(file_path)
|
||||
functions = _select_functions(module, function_names)
|
||||
if not functions:
|
||||
raise ValueError("no functions found to register")
|
||||
|
||||
added = _register_functions(functions, replace=bool(replace), source="dynamic_tools")
|
||||
|
||||
return {"status": "ok", "added": added, "file": str(file_path)}
|
||||
|
||||
|
||||
@mcp.custom_route("/admin/tools/upload", methods=["POST"])
|
||||
async def upload_tool(request: Request) -> JSONResponse:
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = register_tools_from_payload(payload)
|
||||
except ValueError as exc:
|
||||
return JSONResponse({"error": str(exc)}, status_code=400)
|
||||
except Exception as exc:
|
||||
return JSONResponse({"error": f"failed to load tools: {exc}"}, status_code=400)
|
||||
|
||||
return JSONResponse(result)
|
||||
|
||||
|
||||
@mcp.custom_route("/admin/tools/reload", methods=["POST"])
|
||||
async def reload_tools(request: Request) -> JSONResponse:
|
||||
replace = True
|
||||
try:
|
||||
payload = await request.json()
|
||||
if isinstance(payload, dict) and "replace" in payload:
|
||||
replace = bool(payload["replace"])
|
||||
except Exception:
|
||||
replace = True
|
||||
result = load_tools_from_directory(replace=replace)
|
||||
return JSONResponse(result)
|
||||
|
||||
|
||||
@mcp.custom_route("/admin/tools/list", methods=["GET"])
|
||||
async def list_tools(request: Request) -> JSONResponse:
|
||||
tools = await mcp.get_tools()
|
||||
items = []
|
||||
for key in sorted(tools.keys()):
|
||||
tool = tools[key]
|
||||
meta = tool.meta or {}
|
||||
source = meta.get("source", "unknown")
|
||||
call_methods = ["mcp_remote"]
|
||||
if source == "function_calling":
|
||||
call_methods.append("function")
|
||||
payload = {
|
||||
"key": key,
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
"enabled": tool.enabled,
|
||||
"source": source,
|
||||
"call_methods": call_methods,
|
||||
}
|
||||
if meta:
|
||||
payload["meta"] = meta
|
||||
if tool.output_schema is not None:
|
||||
payload["output_schema"] = tool.output_schema
|
||||
items.append(payload)
|
||||
return JSONResponse({"tools": items})
|
||||
|
||||
|
||||
def load_tools_from_directory(*, replace: bool = True) -> dict:
|
||||
added: list[str] = []
|
||||
errors: list[str] = []
|
||||
for directory in (_FUNCTION_CALLING_DIR, _DYNAMIC_TOOLS_DIR):
|
||||
if not directory.exists():
|
||||
continue
|
||||
source = "local_tools" if directory == _FUNCTION_CALLING_DIR else "mcp_tools"
|
||||
for path in sorted(directory.glob("*.py")):
|
||||
if path.name.startswith("_") or path.name == "__init__.py":
|
||||
continue
|
||||
try:
|
||||
module = _load_module_from_path(path)
|
||||
functions = _select_functions(module, None)
|
||||
if not functions:
|
||||
continue
|
||||
added.extend(_register_functions(functions, replace=replace, source=source))
|
||||
except Exception as exc:
|
||||
errors.append(f"{path.name}: {exc}")
|
||||
return {"added": added, "errors": errors}
|
||||
|
||||
|
||||
_bootstrap = load_tools_from_directory(replace=True)
|
||||
if _bootstrap["errors"]:
|
||||
print(f"Dynamic tool load errors: {_bootstrap['errors']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting DevAll MCP server...")
|
||||
print("Run standalone with: fastmcp run server/mcp_server.py --transport streamable-http --port 8010")
|
||||
mcp.run()
|
||||
@ -1,6 +1,6 @@
|
||||
"""Pydantic models shared across server routes."""
|
||||
|
||||
from typing import List, Literal, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, constr
|
||||
|
||||
@ -13,6 +13,15 @@ class WorkflowRequest(BaseModel):
|
||||
log_level: Literal["INFO", "DEBUG"] = "INFO"
|
||||
|
||||
|
||||
class WorkflowRunRequest(BaseModel):
|
||||
yaml_file: str
|
||||
task_prompt: str
|
||||
attachments: Optional[List[str]] = None
|
||||
session_name: Optional[str] = None
|
||||
variables: Optional[Dict[str, Any]] = None
|
||||
log_level: Optional[Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]] = None
|
||||
|
||||
|
||||
class WorkflowUploadContentRequest(BaseModel):
|
||||
filename: str
|
||||
content: str
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Aggregates API routers."""
|
||||
|
||||
from . import artifacts, batch, execute, health, sessions, uploads, vuegraphs, workflows, websocket
|
||||
from . import artifacts, execute, execute_sync, health, sessions, uploads, vuegraphs, workflows, websocket, batch
|
||||
|
||||
ALL_ROUTERS = [
|
||||
health.router,
|
||||
@ -11,7 +11,8 @@ ALL_ROUTERS = [
|
||||
sessions.router,
|
||||
batch.router,
|
||||
execute.router,
|
||||
execute_sync.router,
|
||||
websocket.router,
|
||||
]
|
||||
|
||||
__all__ = ["ALL_ROUTERS"]
|
||||
__all__ = ["ALL_ROUTERS"]
|
||||
253
server/routes/execute_sync.py
Normal file
253
server/routes/execute_sync.py
Normal file
@ -0,0 +1,253 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from check.check import load_config
|
||||
from entity.enums import LogLevel
|
||||
from entity.graph_config import GraphConfig
|
||||
from entity.messages import Message
|
||||
from runtime.bootstrap.schema import ensure_schema_registry_populated
|
||||
from runtime.sdk import OUTPUT_ROOT, run_workflow
|
||||
from server.models import WorkflowRunRequest
|
||||
from server.settings import YAML_DIR
|
||||
from utils.attachments import AttachmentStore
|
||||
from utils.exceptions import ValidationError, WorkflowExecutionError
|
||||
from utils.logger import WorkflowLogger
|
||||
from utils.structured_logger import get_server_logger, LogType
|
||||
from utils.task_input import TaskInputBuilder
|
||||
from workflow.graph import GraphExecutor
|
||||
from workflow.graph_context import GraphContext
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_SSE_CONTENT_TYPE = "text/event-stream"
|
||||
|
||||
|
||||
def _normalize_session_name(yaml_path: Path, session_name: Optional[str]) -> str:
|
||||
if session_name and session_name.strip():
|
||||
return session_name.strip()
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
return f"sdk_{yaml_path.stem}_{timestamp}"
|
||||
|
||||
|
||||
def _resolve_yaml_path(yaml_file: Union[str, Path]) -> Path:
|
||||
candidate = Path(yaml_file).expanduser()
|
||||
if candidate.is_absolute():
|
||||
return candidate
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
yaml_root = YAML_DIR if YAML_DIR.is_absolute() else (repo_root / YAML_DIR)
|
||||
return (yaml_root / candidate).expanduser()
|
||||
|
||||
|
||||
def _build_task_input(
|
||||
graph_context: GraphContext,
|
||||
prompt: str,
|
||||
attachments: Sequence[Union[str, Path]],
|
||||
) -> Union[str, list[Message]]:
|
||||
if not attachments:
|
||||
return prompt
|
||||
|
||||
attachments_dir = graph_context.directory / "code_workspace" / "attachments"
|
||||
attachments_dir.mkdir(parents=True, exist_ok=True)
|
||||
store = AttachmentStore(attachments_dir)
|
||||
builder = TaskInputBuilder(store)
|
||||
normalized_paths = [str(Path(path).expanduser()) for path in attachments]
|
||||
return builder.build_from_file_paths(prompt, normalized_paths)
|
||||
|
||||
|
||||
def _run_workflow_with_logger(
|
||||
*,
|
||||
yaml_file: Union[str, Path],
|
||||
task_prompt: str,
|
||||
attachments: Optional[Sequence[Union[str, Path]]],
|
||||
session_name: Optional[str],
|
||||
variables: Optional[dict],
|
||||
log_level: Optional[LogLevel],
|
||||
log_callback,
|
||||
) -> tuple[Optional[Message], dict[str, Any]]:
|
||||
ensure_schema_registry_populated()
|
||||
|
||||
yaml_path = _resolve_yaml_path(yaml_file)
|
||||
if not yaml_path.exists():
|
||||
raise FileNotFoundError(f"YAML file not found: {yaml_path}")
|
||||
|
||||
attachments = attachments or []
|
||||
if (not task_prompt or not task_prompt.strip()) and not attachments:
|
||||
raise ValidationError(
|
||||
"Task prompt cannot be empty",
|
||||
details={"task_prompt_provided": bool(task_prompt)},
|
||||
)
|
||||
|
||||
design = load_config(yaml_path, vars_override=variables)
|
||||
normalized_session = _normalize_session_name(yaml_path, session_name)
|
||||
|
||||
graph_config = GraphConfig.from_definition(
|
||||
design.graph,
|
||||
name=normalized_session,
|
||||
output_root=OUTPUT_ROOT,
|
||||
source_path=str(yaml_path),
|
||||
vars=design.vars,
|
||||
)
|
||||
|
||||
if log_level:
|
||||
graph_config.log_level = log_level
|
||||
graph_config.definition.log_level = log_level
|
||||
|
||||
graph_context = GraphContext(config=graph_config)
|
||||
task_input = _build_task_input(graph_context, task_prompt, attachments)
|
||||
|
||||
class _StreamingWorkflowLogger(WorkflowLogger):
|
||||
def add_log(self, *args, **kwargs):
|
||||
entry = super().add_log(*args, **kwargs)
|
||||
if entry:
|
||||
payload = entry.to_dict()
|
||||
payload.pop("details", None)
|
||||
log_callback("log", payload)
|
||||
return entry
|
||||
|
||||
class _StreamingExecutor(GraphExecutor):
|
||||
def _create_logger(self) -> WorkflowLogger:
|
||||
level = log_level or self.graph.log_level
|
||||
return _StreamingWorkflowLogger(
|
||||
self.graph.name,
|
||||
level,
|
||||
use_structured_logging=True,
|
||||
log_to_console=False,
|
||||
)
|
||||
|
||||
executor = _StreamingExecutor(graph_context, session_id=normalized_session)
|
||||
executor._execute(task_input)
|
||||
final_message = executor.get_final_output_message()
|
||||
|
||||
logger = executor.log_manager.get_logger() if executor.log_manager else None
|
||||
log_id = logger.workflow_id if logger else None
|
||||
token_usage = executor.token_tracker.get_token_usage() if executor.token_tracker else None
|
||||
|
||||
meta = {
|
||||
"session_name": normalized_session,
|
||||
"yaml_file": str(yaml_path),
|
||||
"log_id": log_id,
|
||||
"token_usage": token_usage,
|
||||
"output_dir": graph_context.directory,
|
||||
}
|
||||
return final_message, meta
|
||||
|
||||
|
||||
def _sse_event(event_type: str, data: Any) -> str:
|
||||
payload = json.dumps(data, ensure_ascii=False, default=str)
|
||||
return f"event: {event_type}\ndata: {payload}\n\n"
|
||||
|
||||
|
||||
@router.post("/api/workflow/run")
|
||||
async def run_workflow_sync(request: WorkflowRunRequest, http_request: Request):
|
||||
try:
|
||||
resolved_log_level: Optional[LogLevel] = None
|
||||
if request.log_level:
|
||||
resolved_log_level = LogLevel(request.log_level)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="log_level must be one of DEBUG, INFO, WARNING, ERROR, CRITICAL",
|
||||
)
|
||||
|
||||
accepts_stream = _SSE_CONTENT_TYPE in (http_request.headers.get("accept") or "")
|
||||
if not accepts_stream:
|
||||
try:
|
||||
result = await run_in_threadpool(
|
||||
run_workflow,
|
||||
request.yaml_file,
|
||||
task_prompt=request.task_prompt,
|
||||
attachments=request.attachments,
|
||||
session_name=request.session_name,
|
||||
variables=request.variables,
|
||||
log_level=resolved_log_level,
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc))
|
||||
except ValidationError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
except Exception as exc:
|
||||
logger = get_server_logger()
|
||||
logger.log_exception(exc, "Failed to run workflow via sync API")
|
||||
raise WorkflowExecutionError(f"Failed to run workflow: {exc}")
|
||||
|
||||
final_message = result.final_message.text_content() if result.final_message else ""
|
||||
meta = result.meta_info
|
||||
|
||||
logger = get_server_logger()
|
||||
logger.info(
|
||||
"Workflow execution completed via sync API",
|
||||
log_type=LogType.WORKFLOW,
|
||||
session_id=meta.session_name,
|
||||
yaml_path=meta.yaml_file,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"final_message": final_message,
|
||||
"token_usage": meta.token_usage,
|
||||
"output_dir": str(meta.output_dir.resolve()),
|
||||
}
|
||||
|
||||
event_queue: queue.Queue[tuple[str, Any]] = queue.Queue()
|
||||
done_event = threading.Event()
|
||||
|
||||
def enqueue(event_type: str, data: Any) -> None:
|
||||
event_queue.put((event_type, data))
|
||||
|
||||
def worker() -> None:
|
||||
try:
|
||||
enqueue(
|
||||
"started",
|
||||
{"yaml_file": request.yaml_file, "task_prompt": request.task_prompt},
|
||||
)
|
||||
final_message, meta = _run_workflow_with_logger(
|
||||
yaml_file=request.yaml_file,
|
||||
task_prompt=request.task_prompt,
|
||||
attachments=request.attachments,
|
||||
session_name=request.session_name,
|
||||
variables=request.variables,
|
||||
log_level=resolved_log_level,
|
||||
log_callback=enqueue,
|
||||
)
|
||||
enqueue(
|
||||
"completed",
|
||||
{
|
||||
"status": "completed",
|
||||
"final_message": final_message.text_content() if final_message else "",
|
||||
"token_usage": meta["token_usage"],
|
||||
"output_dir": str(meta["output_dir"].resolve()),
|
||||
},
|
||||
)
|
||||
except (FileNotFoundError, ValidationError) as exc:
|
||||
enqueue("error", {"message": str(exc)})
|
||||
except Exception as exc:
|
||||
logger = get_server_logger()
|
||||
logger.log_exception(exc, "Failed to run workflow via streaming API")
|
||||
enqueue("error", {"message": f"Failed to run workflow: {exc}"})
|
||||
finally:
|
||||
done_event.set()
|
||||
|
||||
threading.Thread(target=worker, daemon=True).start()
|
||||
|
||||
async def stream():
|
||||
while True:
|
||||
try:
|
||||
event_type, data = event_queue.get(timeout=0.1)
|
||||
yield _sse_event(event_type, data)
|
||||
except queue.Empty:
|
||||
if done_event.is_set():
|
||||
break
|
||||
|
||||
return StreamingResponse(stream(), media_type=_SSE_CONTENT_TYPE)
|
||||
@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Any
|
||||
|
||||
from server.models import (
|
||||
WorkflowCopyRequest,
|
||||
@ -67,6 +68,91 @@ async def list_workflows():
|
||||
return {"workflows": [file.name for file in YAML_DIR.glob("*.yaml")]}
|
||||
|
||||
|
||||
@router.get("/api/workflows/{filename}/args")
|
||||
async def get_workflow_args(filename: str):
|
||||
print(str)
|
||||
try:
|
||||
safe_filename = validate_workflow_filename(filename, require_yaml_extension=True)
|
||||
print(safe_filename)
|
||||
file_path = YAML_DIR / safe_filename
|
||||
|
||||
if not file_path.exists() or not file_path.is_file():
|
||||
raise ResourceNotFoundError(
|
||||
"Workflow file not found",
|
||||
resource_type="workflow",
|
||||
resource_id=safe_filename,
|
||||
)
|
||||
|
||||
# Load and validate YAML content
|
||||
raw_content = file_path.read_text(encoding="utf-8")
|
||||
_, yaml_content = validate_workflow_content(safe_filename, raw_content)
|
||||
|
||||
args: list[dict[str, Any]] = []
|
||||
if isinstance(yaml_content, dict):
|
||||
graph = yaml_content.get("graph") or {}
|
||||
if isinstance(graph, dict):
|
||||
raw_args = graph.get("args") or []
|
||||
if isinstance(raw_args, list):
|
||||
if len(raw_args) == 0:
|
||||
raise ResourceNotFoundError(
|
||||
"Workflow file does not have args",
|
||||
resource_type="workflow",
|
||||
resource_id=safe_filename,
|
||||
)
|
||||
for item in raw_args:
|
||||
# Each item is expected to be like: { arg_name: [ {key: value}, ... ] }
|
||||
if not isinstance(item, dict) or len(item) != 1:
|
||||
continue
|
||||
(arg_name, spec_list), = item.items()
|
||||
if not isinstance(arg_name, str):
|
||||
continue
|
||||
|
||||
arg_info: dict[str, Any] = {"name": arg_name}
|
||||
if isinstance(spec_list, list):
|
||||
for spec in spec_list:
|
||||
if isinstance(spec, dict):
|
||||
for key, value in spec.items():
|
||||
# Later entries override earlier ones if duplicated
|
||||
arg_info[str(key)] = value
|
||||
args.append(arg_info)
|
||||
|
||||
logger = get_server_logger()
|
||||
logger.info(
|
||||
"Workflow args retrieved",
|
||||
log_type=LogType.WORKFLOW,
|
||||
filename=safe_filename,
|
||||
args_count=len(args),
|
||||
)
|
||||
|
||||
return {"args": args}
|
||||
except ValidationError as exc:
|
||||
# 参数或文件名等校验错误
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(exc)},
|
||||
)
|
||||
except SecurityError as exc:
|
||||
# 安全相关错误(例如路径遍历)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(exc)},
|
||||
)
|
||||
except ResourceNotFoundError as exc:
|
||||
# 文件不存在
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"message": str(exc)},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger = get_server_logger()
|
||||
logger.log_exception(exc, f"Unexpected error retrieving workflow args: {filename}")
|
||||
# 兜底错误
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": f"Failed to retrieve workflow args: {exc}"},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/workflows/upload/content")
|
||||
async def upload_workflow_content(request: WorkflowUploadContentRequest):
|
||||
return _persist_workflow_from_content(
|
||||
@ -78,7 +164,7 @@ async def upload_workflow_content(request: WorkflowUploadContentRequest):
|
||||
)
|
||||
|
||||
|
||||
@router.put("/api/workflows/{filename}")
|
||||
@router.put("/api/workflows/{filename}/update")
|
||||
async def update_workflow_content(filename: str, request: WorkflowUpdateContentRequest):
|
||||
return _persist_workflow_from_content(
|
||||
filename,
|
||||
@ -89,7 +175,7 @@ async def update_workflow_content(filename: str, request: WorkflowUpdateContentR
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/api/workflows/{filename}")
|
||||
@router.delete("/api/workflows/{filename}/delete")
|
||||
async def delete_workflow(filename: str):
|
||||
try:
|
||||
safe_filename = validate_workflow_filename(filename, require_yaml_extension=True)
|
||||
@ -180,7 +266,7 @@ async def copy_workflow_file(filename: str, request: WorkflowCopyRequest):
|
||||
raise WorkflowExecutionError(f"Failed to copy workflow: {exc}")
|
||||
|
||||
|
||||
@router.get("/api/workflows/{filename}")
|
||||
@router.get("/api/workflows/{filename}/get")
|
||||
async def get_workflow_raw_content(filename: str):
|
||||
try:
|
||||
safe_filename = validate_workflow_filename(filename, require_yaml_extension=True)
|
||||
@ -208,4 +294,4 @@ async def get_workflow_raw_content(filename: str):
|
||||
except Exception as exc:
|
||||
logger = get_server_logger()
|
||||
logger.log_exception(exc, f"Unexpected error retrieving workflow: {filename}")
|
||||
raise WorkflowExecutionError(f"Failed to retrieve workflow: {exc}")
|
||||
raise WorkflowExecutionError(f"Failed to retrieve workflow: {exc}")
|
||||
@ -49,8 +49,6 @@ class WebSocketManager:
|
||||
):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.connection_timestamps: Dict[str, float] = {}
|
||||
self.send_locks: Dict[str, asyncio.Lock] = {}
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
self.session_store = session_store or WorkflowSessionStore()
|
||||
self.session_controller = session_controller or SessionExecutionController(self.session_store)
|
||||
self.attachment_service = attachment_service or AttachmentService()
|
||||
@ -67,16 +65,10 @@ class WebSocketManager:
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) -> str:
|
||||
await websocket.accept()
|
||||
if self.loop is None:
|
||||
try:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self.loop = None
|
||||
if not session_id:
|
||||
session_id = str(uuid.uuid4())
|
||||
self.active_connections[session_id] = websocket
|
||||
self.connection_timestamps[session_id] = time.time()
|
||||
self.send_locks[session_id] = asyncio.Lock()
|
||||
logging.info("WebSocket connected: %s", session_id)
|
||||
await self.send_message(
|
||||
session_id,
|
||||
@ -98,8 +90,6 @@ class WebSocketManager:
|
||||
del self.active_connections[session_id]
|
||||
if session_id in self.connection_timestamps:
|
||||
del self.connection_timestamps[session_id]
|
||||
if session_id in self.send_locks:
|
||||
del self.send_locks[session_id]
|
||||
self.session_controller.cleanup_session(session_id)
|
||||
remaining_session = self.session_store.get_session(session_id)
|
||||
if remaining_session and remaining_session.executor is None:
|
||||
@ -111,12 +101,7 @@ class WebSocketManager:
|
||||
if session_id in self.active_connections:
|
||||
websocket = self.active_connections[session_id]
|
||||
try:
|
||||
lock = self.send_locks.get(session_id)
|
||||
if lock is None:
|
||||
await websocket.send_text(_encode_ws_message(message))
|
||||
else:
|
||||
async with lock:
|
||||
await websocket.send_text(_encode_ws_message(message))
|
||||
await websocket.send_text(_encode_ws_message(message))
|
||||
except Exception as exc:
|
||||
traceback.print_exc()
|
||||
logging.error("Failed to send message to %s: %s", session_id, exc)
|
||||
@ -130,13 +115,7 @@ class WebSocketManager:
|
||||
else:
|
||||
asyncio.run(self.send_message(session_id, message))
|
||||
except RuntimeError:
|
||||
if self.loop and self.loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.send_message(session_id, message),
|
||||
self.loop,
|
||||
)
|
||||
else:
|
||||
asyncio.run(self.send_message(session_id, message))
|
||||
asyncio.run(self.send_message(session_id, message))
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]) -> None:
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user