202 lines
9.0 KiB
Python

"""Unified extensions configuration for MCP servers and skills."""
import json
import os
from pathlib import Path
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class McpOAuthConfig(BaseModel):
    """OAuth configuration for an MCP server (HTTP/SSE transports)."""

    # extra="allow": keys not declared below are kept on the model instead of
    # being rejected. frozen=True: instances cannot be mutated after creation.
    model_config = ConfigDict(extra="allow", frozen=True)

    # --- Activation and token endpoint -------------------------------------
    enabled: bool = Field(
        default=True,
        description="Whether OAuth token injection is enabled",
    )
    # The only required field: there is no default token endpoint.
    token_url: str = Field(
        description="OAuth token endpoint URL",
    )
    grant_type: Literal["client_credentials", "refresh_token"] = Field(
        default="client_credentials",
        description="OAuth grant type",
    )

    # --- Credentials and request parameters (all optional) -----------------
    client_id: str | None = Field(
        default=None,
        description="OAuth client ID",
    )
    client_secret: str | None = Field(
        default=None,
        description="OAuth client secret",
    )
    refresh_token: str | None = Field(
        default=None,
        description="OAuth refresh token (for refresh_token grant)",
    )
    scope: str | None = Field(
        default=None,
        description="OAuth scope",
    )
    audience: str | None = Field(
        default=None,
        description="OAuth audience (provider-specific)",
    )

    # --- How to read the token endpoint's response -------------------------
    token_field: str = Field(
        default="access_token",
        description="Field name containing access token in token response",
    )
    token_type_field: str = Field(
        default="token_type",
        description="Field name containing token type in token response",
    )
    expires_in_field: str = Field(
        default="expires_in",
        description="Field name containing expiry (seconds) in token response",
    )
    default_token_type: str = Field(
        default="Bearer",
        description="Default token type when missing in token response",
    )

    # --- Refresh behaviour and provider-specific extras --------------------
    refresh_skew_seconds: int = Field(
        default=60,
        description="Refresh token this many seconds before expiry",
    )
    extra_token_params: dict[str, str] = Field(
        default_factory=dict,
        description="Additional form params sent to token endpoint",
    )
class McpServerConfig(BaseModel):
    """Configuration for a single MCP server."""

    # extra="allow": keys not declared below are kept on the model instead of
    # being rejected. frozen=True: instances cannot be mutated after creation.
    model_config = ConfigDict(extra="allow", frozen=True)

    enabled: bool = Field(
        default=True,
        description="Whether this MCP server is enabled",
    )
    # Declared as a free-form string; the values named in the description are
    # not enforced by this model.
    type: str = Field(
        default="stdio",
        description="Transport type: 'stdio', 'sse', or 'http'",
    )

    # --- stdio transport ----------------------------------------------------
    command: str | None = Field(
        default=None,
        description="Command to execute to start the MCP server (for stdio type)",
    )
    args: list[str] = Field(
        default_factory=list,
        description="Arguments to pass to the command (for stdio type)",
    )
    env: dict[str, str] = Field(
        default_factory=dict,
        description="Environment variables for the MCP server",
    )

    # --- sse / http transports ---------------------------------------------
    url: str | None = Field(
        default=None,
        description="URL of the MCP server (for sse or http type)",
    )
    headers: dict[str, str] = Field(
        default_factory=dict,
        description="HTTP headers to send (for sse or http type)",
    )
    oauth: McpOAuthConfig | None = Field(
        default=None,
        description="OAuth configuration (for sse or http type)",
    )

    description: str = Field(
        default="",
        description="Human-readable description of what this MCP server provides",
    )
class SkillStateConfig(BaseModel):
    """Configuration for a single skill's state."""

    # frozen=True: instances cannot be mutated after creation. Unlike the MCP
    # models in this file, no `extra` policy is set here.
    model_config = ConfigDict(frozen=True)

    enabled: bool = Field(
        default=True,
        description="Whether this skill is enabled",
    )
class ExtensionsConfig(BaseModel):
    """Unified configuration for MCP servers and skills."""

    # extra="allow" keeps unknown top-level keys, frozen=True makes instances
    # immutable, and populate_by_name=True lets `mcp_servers` be supplied either
    # under its field name or under its `mcpServers` alias.
    model_config = ConfigDict(extra="allow", frozen=True, populate_by_name=True)

    mcp_servers: dict[str, McpServerConfig] = Field(
        default_factory=dict,
        description="Map of MCP server name to configuration",
        alias="mcpServers",
    )
    skills: dict[str, SkillStateConfig] = Field(
        default_factory=dict,
        description="Map of skill name to state configuration",
    )

    @classmethod
    def resolve_config_path(cls, config_path: str | None = None) -> Path | None:
        """Resolve the extensions config file path.

        Resolution order:
        1. If the `config_path` argument is provided, use it.
        2. Otherwise, if the `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment
           variable is set (and non-empty), use it.
        3. Otherwise, return the first existing file among the defaults:
           `extensions_config.json` in the backend directory, then in the
           repository root, then the legacy `mcp_config.json` in the same two
           directories (kept for backward compatibility).
        4. If none exists, return None (extensions are optional).

        Args:
            config_path: Optional path to extensions config file.

        Returns:
            Path to the extensions config file if found, otherwise None.

        Raises:
            FileNotFoundError: If a path given explicitly (argument or
                environment variable) does not exist.
        """
        if config_path:
            path = Path(config_path)
            if not path.exists():
                raise FileNotFoundError(f"Extensions config file specified by param `config_path` not found at {path}")
            return path

        # Read the variable once so the check and the use cannot disagree.
        env_path = os.getenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH")
        if env_path:
            path = Path(env_path)
            if not path.exists():
                raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}")
            return path

        # NOTE(review): `parents[4]` is taken to be the backend directory based
        # on this module's depth in the package tree; re-check if the file moves.
        backend_dir = Path(__file__).resolve().parents[4]
        repo_root = backend_dir.parent
        candidates = (
            backend_dir / "extensions_config.json",
            repo_root / "extensions_config.json",
            backend_dir / "mcp_config.json",
            repo_root / "mcp_config.json",
        )
        for candidate in candidates:
            if candidate.exists():
                return candidate

        # Extensions are optional, so return None if not found
        return None

    @classmethod
    def from_file(cls, config_path: str | None = None) -> "ExtensionsConfig":
        """Load extensions config from JSON file.

        See `resolve_config_path` for how the file is located.

        Args:
            config_path: Path to the extensions config file.

        Returns:
            ExtensionsConfig: The loaded config, or empty config if file not found.

        Raises:
            FileNotFoundError: If an explicitly specified config path does not
                exist (propagated from `resolve_config_path`).
            ValueError: If the file is not valid JSON.
            RuntimeError: If reading, environment resolution, or model
                validation fails for any other reason.
        """
        resolved_path = cls.resolve_config_path(config_path)
        if resolved_path is None:
            # Return empty config if extensions config file is not found
            return cls(mcp_servers={}, skills={})

        try:
            with open(resolved_path, encoding="utf-8") as f:
                config_data = json.load(f)
            # Validate the returned mapping rather than relying solely on the
            # in-place mutation performed by `resolve_env_variables`.
            resolved_data = cls.resolve_env_variables(config_data)
            return cls.model_validate(resolved_data)
        except json.JSONDecodeError as e:
            raise ValueError(f"Extensions config file at {resolved_path} is not valid JSON: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to load extensions config from {resolved_path}: {e}") from e

    @classmethod
    def resolve_env_variables(cls, config: dict[str, Any]) -> dict[str, Any]:
        """Recursively resolve environment variables in the config.

        A string value that starts with `$` is treated as a placeholder and
        replaced by the value of the named environment variable (looked up with
        `os.getenv`). Example: `$OPENAI_API_KEY`. Placeholders are resolved in
        dict values at any depth and in list items, including strings that sit
        directly inside a list (e.g. a server's `args`) and nested lists.

        An unset variable resolves to an empty string so that downstream
        consumers never receive the literal `$VAR` token.

        Args:
            config: The config to resolve environment variables in. It is
                modified in place.

        Returns:
            The same `config` dict, with environment variables resolved.
        """
        for key, value in config.items():
            # Only values are reassigned, never keys, so mutating during
            # iteration is safe.
            config[key] = cls._resolve_env_value(value)
        return config

    @classmethod
    def _resolve_env_value(cls, value: Any) -> Any:
        """Resolve `$VAR` placeholders in a single config value of any shape."""
        if isinstance(value, str):
            if not value.startswith("$"):
                return value
            # Unresolved placeholder -> empty string, so consumers (e.g. MCP
            # servers) don't receive the literal "$VAR" token as a real value.
            return os.getenv(value[1:], "")
        if isinstance(value, dict):
            return cls.resolve_env_variables(value)
        if isinstance(value, list):
            return [cls._resolve_env_value(item) for item in value]
        # Numbers, booleans, None, etc. carry no placeholders.
        return value

    def get_enabled_mcp_servers(self) -> dict[str, McpServerConfig]:
        """Get only the enabled MCP servers.

        Returns:
            Dictionary of enabled MCP servers, keyed by server name.
        """
        return {name: server for name, server in self.mcp_servers.items() if server.enabled}

    def is_skill_enabled(self, skill_name: str, skill_category: str) -> bool:
        """Check if a skill is enabled.

        Args:
            skill_name: Name of the skill
            skill_category: Category of the skill

        Returns:
            The configured `enabled` flag when the skill has an entry in
            `skills`; otherwise True only for the "public" and "custom"
            categories.
        """
        skill_config = self.skills.get(skill_name)
        if skill_config is None:
            # Default to enable for public & custom skill
            return skill_category in ("public", "custom")
        return skill_config.enabled