refactor(config): eliminate global mutable state, wire DeerFlowContext into runtime

This commit is contained in:
greatmengqi 2026-04-15 21:28:15 +08:00
parent 7b9d224b3a
commit 9040e49e4a
111 changed files with 4847 additions and 4504 deletions

View File

@ -179,7 +179,9 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
**Config Versioning**: `config.example.yaml` has a `config_version` field. On startup, `AppConfig.from_file()` compares user version vs example version and emits a warning if outdated. Missing `config_version` = version 0. Run `make config-upgrade` to auto-merge missing fields. When changing the config schema, bump `config_version` in `config.example.yaml`.
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
**Config Lifecycle**: All config models are `frozen=True` (immutable after construction). `AppConfig.from_file()` is a pure function — no side effects on sub-module globals. `get_app_config()` is backed by a single `ContextVar`, set once via `init_app_config()` at process startup. To update config at runtime (e.g., Gateway API updates MCP/Skills), construct a new `AppConfig.from_file()` and call `init_app_config()` again. No mtime detection, no auto-reload.
**DeerFlowContext**: Per-invocation typed context for the agent execution path, injected via LangGraph `Runtime[DeerFlowContext]`. Holds `app_config: AppConfig`, `thread_id: str`, `agent_name: str | None`. Gateway runtime and `DeerFlowClient` construct full `DeerFlowContext` at invoke time; LangGraph Server path uses a fallback via `resolve_context()`. Middleware and tools access context through `resolve_context(runtime)` which returns a typed `DeerFlowContext` regardless of entry point. Mutable runtime state (`sandbox_id`) flows through `ThreadState.sandbox`, not context.
Configuration priority:
1. Explicit `config_path` argument

View File

@ -66,9 +66,9 @@ class ChannelService:
@classmethod
def from_app_config(cls) -> ChannelService:
"""Create a ChannelService from the application config."""
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
config = get_app_config()
config = AppConfig.current()
channels_config = {}
# extra fields are allowed by AppConfig (extra="allow")
extra = config.model_extra or {}

View File

@ -27,7 +27,7 @@ from app.gateway.routers import (
threads,
uploads,
)
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
# Configure logging
logging.basicConfig(
@ -147,7 +147,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Load config and check necessary environment variables at startup
try:
get_app_config()
AppConfig.current()
logger.info("Configuration loaded successfully")
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"

View File

@ -6,7 +6,8 @@ from typing import Literal
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["mcp"])
@ -90,9 +91,9 @@ async def get_mcp_configuration() -> McpConfigResponse:
}
```
"""
config = get_extensions_config()
ext = AppConfig.current().extensions
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()})
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in ext.mcp_servers.items()})
@router.put(
@ -143,12 +144,12 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
# Load current config to preserve skills configuration
current_config = get_extensions_config()
current_ext = AppConfig.current().extensions
# Convert request to dict format for JSON serialization
config_data = {
"mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()},
}
# Write the configuration to file
@ -161,8 +162,9 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
# will detect config file changes via mtime and reinitialize MCP tools automatically
# Reload the configuration and update the global cache
reloaded_config = reload_extensions_config()
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()})
AppConfig.init(AppConfig.from_file())
reloaded_ext = AppConfig.current().extensions
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_ext.mcp_servers.items()})
except Exception as e:
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)

View File

@ -12,7 +12,7 @@ from deerflow.agents.memory.updater import (
reload_memory_data,
update_memory_fact,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.user_context import get_effective_user_id
router = APIRouter(prefix="/api", tags=["memory"])
@ -314,7 +314,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
}
```
"""
config = get_memory_config()
config = AppConfig.current().memory
return MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
@ -339,7 +339,7 @@ async def get_memory_status() -> MemoryStatusResponse:
Returns:
Combined memory configuration and current data.
"""
config = get_memory_config()
config = AppConfig.current().memory
memory_data = get_memory_data(user_id=get_effective_user_id())
return MemoryStatusResponse(

View File

@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
router = APIRouter(prefix="/api", tags=["models"])
@ -58,7 +58,7 @@ async def list_models() -> ModelsListResponse:
}
```
"""
config = get_app_config()
config = AppConfig.current()
models = [
ModelResponse(
name=model.name,
@ -101,7 +101,7 @@ async def get_model(model_name: str) -> ModelResponse:
}
```
"""
config = get_app_config()
config = AppConfig.current()
model = config.get_model_config(model_name)
if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

View File

@ -8,7 +8,8 @@ from pydantic import BaseModel, Field
from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import (
@ -325,19 +326,19 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
extensions_config = get_extensions_config()
extensions_config.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
ext = AppConfig.current().extensions
ext.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
config_data = {
"mcpServers": {name: server.model_dump() for name, server in extensions_config.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in extensions_config.skills.items()},
"mcpServers": {name: server.model_dump() for name, server in ext.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in ext.skills.items()},
}
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
logger.info(f"Skills configuration updated and saved to: {config_path}")
reload_extensions_config()
AppConfig.init(AppConfig.from_file())
await refresh_skills_system_prompt_cache_async()
skills = load_skills(enabled_only=False)

View File

@ -3,6 +3,7 @@ import logging
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
@ -16,8 +17,8 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import build_lea
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config
from deerflow.config.app_config import get_app_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
def _resolve_model_name(requested_model_name: str | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config()
app_config = AppConfig.current()
default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
@ -40,7 +41,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
def _create_summarization_middleware() -> SummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
config = get_summarization_config()
config = AppConfig.current().summarization
if not config.enabled:
return None
@ -232,7 +233,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(todo_list_middleware)
# Add TokenUsageMiddleware when token_usage tracking is enabled
if get_app_config().token_usage.enabled:
if AppConfig.current().token_usage.enabled:
middlewares.append(TokenUsageMiddleware())
# Add TitleMiddleware
@ -243,7 +244,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
# Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config()
app_config = AppConfig.current()
model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware())
@ -272,7 +273,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
return middlewares
def make_lead_agent(config: RunnableConfig):
def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent
@ -295,7 +296,7 @@ def make_lead_agent(config: RunnableConfig):
# Final model name resolution: request → agent config → global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name)
app_config = get_app_config()
app_config = AppConfig.current()
model_config = app_config.get_model_config(model_name)
if model_config is None:
@ -338,6 +339,7 @@ def make_lead_agent(config: RunnableConfig):
middleware=_build_middlewares(config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
state_schema=ThreadState,
context_schema=DeerFlowContext,
)
# Default lead agent (unchanged behavior)
@ -349,4 +351,5 @@ def make_lead_agent(config: RunnableConfig):
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
),
state_schema=ThreadState,
context_schema=DeerFlowContext,
)

View File

@ -5,6 +5,7 @@ from datetime import datetime
from functools import lru_cache
from deerflow.config.agents_config import load_agent_soul
from deerflow.config.app_config import AppConfig
from deerflow.skills import load_skills
from deerflow.skills.types import Skill
from deerflow.subagents import get_available_subagent_names
@ -518,10 +519,9 @@ def _get_memory_context(agent_name: str | None = None) -> str:
"""
try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.user_context import get_effective_user_id
config = get_memory_config()
config = AppConfig.current().memory
if not config.enabled or not config.injection_enabled:
return ""
@ -577,9 +577,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
skills = _get_enabled_skills()
try:
from deerflow.config import get_app_config
config = get_app_config()
config = AppConfig.current()
container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled
except Exception:
@ -618,9 +616,7 @@ def get_deferred_tools_prompt_section() -> str:
from deerflow.tools.builtins.tool_search import get_deferred_registry
try:
from deerflow.config import get_app_config
if not get_app_config().tool_search.enabled:
if not AppConfig.current().tool_search.enabled:
return ""
except Exception:
return ""
@ -636,9 +632,7 @@ def get_deferred_tools_prompt_section() -> str:
def _build_acp_section() -> str:
"""Build the ACP agent prompt section, only if ACP agents are configured."""
try:
from deerflow.config.acp_config import get_acp_agents
agents = get_acp_agents()
agents = AppConfig.current().acp_agents
if not agents:
return ""
except Exception:
@ -656,9 +650,7 @@ def _build_acp_section() -> str:
def _build_custom_mounts_section() -> str:
"""Build a prompt section for explicitly configured sandbox mounts."""
try:
from deerflow.config import get_app_config
mounts = get_app_config().sandbox.mounts or []
mounts = AppConfig.current().sandbox.mounts or []
except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return ""

View File

@ -7,7 +7,7 @@ from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@ -61,7 +61,7 @@ class MemoryUpdateQueue:
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
"""
config = get_memory_config()
config = AppConfig.current().memory
if not config.enabled:
return
@ -93,7 +93,7 @@ class MemoryUpdateQueue:
def _reset_timer(self) -> None:
"""Reset the debounce timer."""
config = get_memory_config()
config = AppConfig.current().memory
# Cancel existing timer if any
if self._timer is not None:

View File

@ -9,7 +9,7 @@ from pathlib import Path
from typing import Any
from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
@ -84,7 +84,7 @@ class FileMemoryStorage(MemoryStorage):
if agent_name is not None:
self._validate_agent_name(agent_name)
return get_paths().user_agent_memory_file(user_id, agent_name)
config = get_memory_config()
config = AppConfig.current().memory
if config.storage_path and Path(config.storage_path).is_absolute():
return Path(config.storage_path)
return get_paths().user_memory_file(user_id)
@ -92,7 +92,7 @@ class FileMemoryStorage(MemoryStorage):
if agent_name is not None:
self._validate_agent_name(agent_name)
return get_paths().agent_memory_file(agent_name)
config = get_memory_config()
config = AppConfig.current().memory
if config.storage_path:
p = Path(config.storage_path)
return p if p.is_absolute() else get_paths().base_dir / p
@ -188,7 +188,7 @@ def get_memory_storage() -> MemoryStorage:
if _storage_instance is not None:
return _storage_instance
config = get_memory_config()
config = AppConfig.current().memory
storage_class_path = config.storage_class
try:

View File

@ -16,7 +16,7 @@ from deerflow.agents.memory.storage import (
get_memory_storage,
utc_now_iso_z,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@ -270,7 +270,7 @@ class MemoryUpdater:
def _get_model(self):
"""Get the model for memory updates."""
config = get_memory_config()
config = AppConfig.current().memory
model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False)
@ -296,7 +296,7 @@ class MemoryUpdater:
Returns:
True if update was successful, False otherwise.
"""
config = get_memory_config()
config = AppConfig.current().memory
if not config.enabled:
return False
@ -385,7 +385,7 @@ class MemoryUpdater:
Returns:
Updated memory data.
"""
config = get_memory_config()
config = AppConfig.current().memory
now = utc_now_iso_z()
# Update user sections

View File

@ -24,6 +24,8 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.deer_flow_context import DeerFlowContext
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
@ -159,12 +161,9 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set)
def _get_thread_id(self, runtime: Runtime) -> str:
def _get_thread_id(self, runtime: Runtime[DeerFlowContext]) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
return thread_id
return "default"
return runtime.context.thread_id or "default"
def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit.
@ -288,11 +287,11 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return None
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
def after_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._apply(state, runtime)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
async def aafter_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._apply(state, runtime)
def reset(self, thread_id: str | None = None) -> None:

View File

@ -6,11 +6,10 @@ from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
@ -194,7 +193,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
self._agent_name = agent_name
@override
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
"""Queue conversation for memory update after agent completes.
Args:
@ -204,15 +203,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
Returns:
None (no state changes needed from this middleware).
"""
config = get_memory_config()
if not config.enabled:
memory_config = runtime.context.app_config.memory
if not memory_config.enabled:
return None
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
thread_id = runtime.context.thread_id
if not thread_id:
logger.debug("No thread_id in context, skipping memory update")
return None

View File

@ -3,10 +3,10 @@ from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime.user_context import get_effective_user_id
@ -77,14 +77,10 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
return self._get_thread_paths(thread_id, user_id=user_id)
@override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
context = runtime.context or {}
thread_id = context.get("thread_id")
if thread_id is None:
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
thread_id = runtime.context.thread_id
if thread_id is None:
if not thread_id:
raise ValueError("Thread ID is required in runtime context or config.configurable")
user_id = get_effective_user_id()

View File

@ -8,7 +8,7 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@ -46,7 +46,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
"""Check if we should generate a title for this thread."""
config = get_title_config()
config = AppConfig.current().title
if not config.enabled:
return False
@ -71,7 +71,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
"""
config = get_title_config()
config = AppConfig.current().title
messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
@ -89,13 +89,13 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
def _parse_title(self, content: object) -> str:
"""Normalize model output into a clean title string."""
config = get_title_config()
config = AppConfig.current().title
title_content = self._normalize_content(content)
title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title
def _fallback_title(self, user_msg: str) -> str:
config = get_title_config()
config = AppConfig.current().title
fallback_chars = min(config.max_chars, 50)
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..."
@ -128,7 +128,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
if not self._should_generate_title(state):
return None
config = get_title_config()
config = AppConfig.current().title
prompt, user_msg = self._build_title_prompt(state)
try:

View File

@ -94,9 +94,9 @@ def _build_runtime_middlewares(
middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config
from deerflow.config.app_config import AppConfig
guardrails_config = get_guardrails_config()
guardrails_config = AppConfig.current().guardrails
if guardrails_config.enabled and guardrails_config.provider:
import inspect

View File

@ -9,6 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.utils.file_conversion import extract_outline
@ -185,7 +186,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return files if files else None
@override
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
"""Inject uploaded files information before agent execution.
New files come from the current message's additional_kwargs.files.
@ -214,14 +215,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return None
# Resolve uploads directory for existence checks
thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
thread_id = runtime.context.thread_id
uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files

View File

@ -36,8 +36,9 @@ from deerflow.agents.lead_agent.agent import _build_middlewares
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.app_config import get_app_config, reload_app_config
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
from deerflow.config.paths import get_paths
from deerflow.models import create_chat_model
from deerflow.runtime.user_context import get_effective_user_id
@ -142,8 +143,8 @@ class DeerFlowClient:
middlewares: Optional list of custom middlewares to inject into the agent.
"""
if config_path is not None:
reload_app_config(config_path)
self._app_config = get_app_config()
AppConfig.init(AppConfig.from_file(config_path))
self._app_config = AppConfig.current()
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
@ -552,9 +553,7 @@ class DeerFlowClient:
self._ensure_agent(config)
state: dict[str, Any] = {"messages": [HumanMessage(content=message)]}
context = {"thread_id": thread_id}
if self._agent_name:
context["agent_name"] = self._agent_name
context = DeerFlowContext(app_config=self._app_config, thread_id=thread_id, agent_name=self._agent_name)
seen_ids: set[str] = set()
# Cross-mode handoff: ids already streamed via LangGraph ``messages``
@ -817,8 +816,8 @@ class DeerFlowClient:
Dict with "mcp_servers" key mapping server name to config,
matching the Gateway API ``McpConfigResponse`` schema.
"""
config = get_extensions_config()
return {"mcp_servers": {name: server.model_dump() for name, server in config.mcp_servers.items()}}
ext = AppConfig.current().extensions
return {"mcp_servers": {name: server.model_dump() for name, server in ext.mcp_servers.items()}}
def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict:
"""Update MCP server configurations.
@ -840,18 +839,19 @@ class DeerFlowClient:
if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
current_config = get_extensions_config()
current_ext = AppConfig.current().extensions
config_data = {
"mcpServers": mcp_servers,
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()},
}
self._atomic_write_json(config_path, config_data)
self._agent = None
self._agent_config_key = None
reloaded = reload_extensions_config()
AppConfig.init(AppConfig.from_file())
reloaded = AppConfig.current().extensions
return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}}
# ------------------------------------------------------------------
@ -905,19 +905,19 @@ class DeerFlowClient:
if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
extensions_config = get_extensions_config()
extensions_config.skills[name] = SkillStateConfig(enabled=enabled)
ext = AppConfig.current().extensions
ext.skills[name] = SkillStateConfig(enabled=enabled)
config_data = {
"mcpServers": {n: s.model_dump() for n, s in extensions_config.mcp_servers.items()},
"skills": {n: {"enabled": sc.enabled} for n, sc in extensions_config.skills.items()},
"mcpServers": {n: s.model_dump() for n, s in ext.mcp_servers.items()},
"skills": {n: {"enabled": sc.enabled} for n, sc in ext.skills.items()},
}
self._atomic_write_json(config_path, config_data)
self._agent = None
self._agent_config_key = None
reload_extensions_config()
AppConfig.init(AppConfig.from_file())
updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
if updated is None:
@ -1000,9 +1000,7 @@ class DeerFlowClient:
Returns:
Memory config dict.
"""
from deerflow.config.memory_config import get_memory_config
config = get_memory_config()
config = AppConfig.current().memory
return {
"enabled": config.enabled,
"storage_path": config.storage_path,

View File

@ -25,7 +25,7 @@ except ImportError: # pragma: no cover - Windows fallback
fcntl = None # type: ignore[assignment]
import msvcrt
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox import Sandbox
@ -149,7 +149,7 @@ class AioSandboxProvider(SandboxProvider):
def _load_config(self) -> dict:
"""Load sandbox configuration from app config."""
config = get_app_config()
config = AppConfig.current()
sandbox_config = config.sandbox
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
@ -281,7 +281,7 @@ class AioSandboxProvider(SandboxProvider):
so the host Docker daemon can resolve the path.
"""
try:
config = get_app_config()
config = AppConfig.current()
skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path

View File

@ -7,7 +7,7 @@ import logging
from langchain.tools import tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@ -63,7 +63,7 @@ def web_search_tool(
query: Search keywords describing what you want to find. Be specific for better results.
max_results: Maximum number of results to return. Default is 5.
"""
config = get_app_config().get_tool_config("web_search")
config = AppConfig.current().get_tool_config("web_search")
# Override max_results from config if set
if config is not None and "max_results" in config.model_extra:

View File

@ -3,11 +3,11 @@ import json
from exa_py import Exa
from langchain.tools import tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
def _get_exa_client(tool_name: str = "web_search") -> Exa:
config = get_app_config().get_tool_config(tool_name)
config = AppConfig.current().get_tool_config(tool_name)
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
@ -22,7 +22,7 @@ def web_search_tool(query: str) -> str:
query: The query to search for.
"""
try:
config = get_app_config().get_tool_config("web_search")
config = AppConfig.current().get_tool_config("web_search")
max_results = 5
search_type = "auto"
contents_max_characters = 1000

View File

@ -3,11 +3,11 @@ import json
from firecrawl import FirecrawlApp
from langchain.tools import tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp:
config = get_app_config().get_tool_config(tool_name)
config = AppConfig.current().get_tool_config(tool_name)
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
@ -22,7 +22,7 @@ def web_search_tool(query: str) -> str:
query: The query to search for.
"""
try:
config = get_app_config().get_tool_config("web_search")
config = AppConfig.current().get_tool_config("web_search")
max_results = 5
if config is not None:
max_results = config.model_extra.get("max_results", max_results)

View File

@ -7,7 +7,7 @@ import logging
from langchain.tools import tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@ -99,7 +99,7 @@ def image_search_tool(
type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references.
layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs.
"""
config = get_app_config().get_tool_config("image_search")
config = AppConfig.current().get_tool_config("image_search")
# Override max_results from config if set
if config is not None and "max_results" in config.model_extra:

View File

@ -1,6 +1,6 @@
from langchain.tools import tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.utils.readability import ReadabilityExtractor
from .infoquest_client import InfoQuestClient
@ -9,12 +9,12 @@ readability_extractor = ReadabilityExtractor()
def _get_infoquest_client() -> InfoQuestClient:
search_config = get_app_config().get_tool_config("web_search")
search_config = AppConfig.current().get_tool_config("web_search")
search_time_range = -1
if search_config is not None and "search_time_range" in search_config.model_extra:
search_time_range = search_config.model_extra.get("search_time_range")
fetch_config = get_app_config().get_tool_config("web_fetch")
fetch_config = AppConfig.current().get_tool_config("web_fetch")
fetch_time = -1
if fetch_config is not None and "fetch_time" in fetch_config.model_extra:
fetch_time = fetch_config.model_extra.get("fetch_time")
@ -25,7 +25,7 @@ def _get_infoquest_client() -> InfoQuestClient:
if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra:
navigation_timeout = fetch_config.model_extra.get("navigation_timeout")
image_search_config = get_app_config().get_tool_config("image_search")
image_search_config = AppConfig.current().get_tool_config("image_search")
image_search_time_range = -1
if image_search_config is not None and "image_search_time_range" in image_search_config.model_extra:
image_search_time_range = image_search_config.model_extra.get("image_search_time_range")

View File

@ -1,7 +1,7 @@
from langchain.tools import tool
from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.utils.readability import ReadabilityExtractor
readability_extractor = ReadabilityExtractor()
@ -20,7 +20,7 @@ async def web_fetch_tool(url: str) -> str:
"""
jina_client = JinaClient()
timeout = 10
config = get_app_config().get_tool_config("web_fetch")
config = AppConfig.current().get_tool_config("web_fetch")
if config is not None and "timeout" in config.model_extra:
timeout = config.model_extra.get("timeout")
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)

View File

@ -3,11 +3,11 @@ import json
from langchain.tools import tool
from tavily import TavilyClient
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
def _get_tavily_client() -> TavilyClient:
config = get_app_config().get_tool_config("web_search")
config = AppConfig.current().get_tool_config("web_search")
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
@ -21,7 +21,7 @@ def web_search_tool(query: str) -> str:
Args:
query: The query to search for.
"""
config = get_app_config().get_tool_config("web_search")
config = AppConfig.current().get_tool_config("web_search")
max_results = 5
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results")

View File

@ -1,6 +1,6 @@
from .app_config import get_app_config
from .extensions_config import ExtensionsConfig, get_extensions_config
from .memory_config import MemoryConfig, get_memory_config
from .app_config import AppConfig
from .extensions_config import ExtensionsConfig
from .memory_config import MemoryConfig
from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig
from .skills_config import SkillsConfig
@ -13,18 +13,16 @@ from .tracing_config import (
)
__all__ = [
"get_app_config",
"SkillEvolutionConfig",
"Paths",
"get_paths",
"SkillsConfig",
"AppConfig",
"ExtensionsConfig",
"get_extensions_config",
"MemoryConfig",
"get_memory_config",
"get_tracing_config",
"get_explicitly_enabled_tracing_providers",
"Paths",
"SkillEvolutionConfig",
"SkillsConfig",
"get_enabled_tracing_providers",
"get_explicitly_enabled_tracing_providers",
"get_paths",
"get_tracing_config",
"is_tracing_enabled",
"validate_enabled_tracing_providers",
]

View File

@ -1,16 +1,13 @@
"""ACP (Agent Client Protocol) agent configuration loaded from config.yaml."""
import logging
from collections.abc import Mapping
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
from pydantic import BaseModel, ConfigDict, Field
class ACPAgentConfig(BaseModel):
"""Configuration for a single ACP-compatible agent."""
model_config = ConfigDict(frozen=True)
command: str = Field(description="Command to launch the ACP agent subprocess")
args: list[str] = Field(default_factory=list, description="Additional command arguments")
env: dict[str, str] = Field(default_factory=dict, description="Environment variables to inject into the agent subprocess. Values starting with $ are resolved from host environment variables.")
@ -24,28 +21,3 @@ class ACPAgentConfig(BaseModel):
"are denied — the agent must be configured to operate without requesting permissions."
),
)
_acp_agents: dict[str, ACPAgentConfig] = {}
def get_acp_agents() -> dict[str, ACPAgentConfig]:
"""Get the currently configured ACP agents.
Returns:
Mapping of agent name -> ACPAgentConfig. Empty dict if no ACP agents are configured.
"""
return _acp_agents
def load_acp_config_from_dict(config_dict: Mapping[str, Mapping[str, object]] | None) -> None:
"""Load ACP agent configuration from a dictionary (typically from config.yaml).
Args:
config_dict: Mapping of agent name -> config fields.
"""
global _acp_agents
if config_dict is None:
config_dict = {}
_acp_agents = {name: ACPAgentConfig(**cfg) for name, cfg in config_dict.items()}
logger.info("ACP config loaded: %d agent(s): %s", len(_acp_agents), list(_acp_agents.keys()))

View File

@ -5,7 +5,7 @@ import re
from typing import Any
import yaml
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from deerflow.config.paths import get_paths
@ -18,6 +18,8 @@ AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
class AgentConfig(BaseModel):
"""Configuration for a custom agent."""
model_config = ConfigDict(frozen=True)
name: str
description: str = ""
model: str | None = None

View File

@ -1,31 +1,33 @@
from __future__ import annotations
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
from typing import Any, ClassVar, Self
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.guardrails_config import GuardrailsConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
from deerflow.config.stream_bridge_config import StreamBridgeConfig
from deerflow.config.subagents_config import SubagentsAppConfig
from deerflow.config.summarization_config import SummarizationConfig
from deerflow.config.title_config import TitleConfig
from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
from deerflow.config.tool_search_config import ToolSearchConfig
load_dotenv()
@ -57,11 +59,12 @@ class AppConfig(BaseModel):
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
model_config = ConfigDict(extra="allow", frozen=False)
model_config = ConfigDict(extra="allow", frozen=True)
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
acp_agents: dict[str, ACPAgentConfig] = Field(default_factory=dict, description="ACP agent configurations keyed by agent name")
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path:
@ -109,41 +112,6 @@ class AppConfig(BaseModel):
config_data = cls.resolve_env_variables(config_data)
# Load title config if present
if "title" in config_data:
load_title_config_from_dict(config_data["title"])
# Load summarization config if present
if "summarization" in config_data:
load_summarization_config_from_dict(config_data["summarization"])
# Load memory config if present
if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"])
# Load subagents config if present
if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"])
# Load tool_search config if present
if "tool_search" in config_data:
load_tool_search_config_from_dict(config_data["tool_search"])
# Load guardrails config if present
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])
# Load stream bridge config if present
if "stream_bridge" in config_data:
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
# Always refresh ACP agent config so removed entries do not linger across reloads.
load_acp_config_from_dict(config_data.get("acp_agents", {}))
# Load extensions config separately (it's in a different file)
extensions_config = ExtensionsConfig.from_file()
config_data["extensions"] = extensions_config.model_dump()
@ -254,130 +222,26 @@ class AppConfig(BaseModel):
"""
return next((group for group in self.tool_groups if group.name == name), None)
# -- Lifecycle (class-level singleton via ContextVar) --
_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
_current: ClassVar[ContextVar[AppConfig]] = ContextVar("deerflow_app_config")
@classmethod
def init(cls, config: AppConfig) -> None:
"""Set the AppConfig for the current context. Call once at process startup."""
cls._current.set(config)
def _get_config_mtime(config_path: Path) -> float | None:
"""Get the modification time of a config file if it exists."""
try:
return config_path.stat().st_mtime
except OSError:
return None
@classmethod
def current(cls) -> AppConfig:
"""Get the current AppConfig.
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
"""Load config from disk and refresh cache metadata."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
resolved_path = AppConfig.resolve_config_path(config_path)
_app_config = AppConfig.from_file(str(resolved_path))
_app_config_path = resolved_path
_app_config_mtime = _get_config_mtime(resolved_path)
_app_config_is_custom = False
return _app_config
def get_app_config() -> AppConfig:
"""Get the DeerFlow config instance.
Returns a cached singleton instance and automatically reloads it when the
underlying config file path or modification time changes. Use
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
the cache.
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
resolved_path = AppConfig.resolve_config_path()
current_mtime = _get_config_mtime(resolved_path)
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
if should_reload:
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
logger.info(
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
_app_config_mtime,
current_mtime,
)
_load_and_cache_app_config(str(resolved_path))
return _app_config
def reload_app_config(config_path: str | None = None) -> AppConfig:
"""Reload the config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded AppConfig instance.
"""
return _load_and_cache_app_config(config_path)
def reset_app_config() -> None:
"""Reset the cached config instance.
This clears the singleton cache, causing the next call to
`get_app_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = None
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = False
def set_app_config(config: AppConfig) -> None:
"""Set a custom config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The AppConfig instance to use.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = config
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)
Auto-initializes from config file on first access for backward compatibility.
Prefer calling AppConfig.init() explicitly at process startup.
"""
try:
return cls._current.get()
except LookupError:
logger.debug("AppConfig not initialized, auto-loading from file")
config = cls.from_file()
cls._current.set(config)
return config

View File

@ -2,7 +2,7 @@
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
CheckpointerType = Literal["memory", "sqlite", "postgres"]
@ -10,6 +10,8 @@ CheckpointerType = Literal["memory", "sqlite", "postgres"]
class CheckpointerConfig(BaseModel):
"""Configuration for LangGraph state persistence checkpointer."""
model_config = ConfigDict(frozen=True)
type: CheckpointerType = Field(
description="Checkpointer backend type. "
"'memory' is in-process only (lost on restart). "
@ -23,24 +25,3 @@ class CheckpointerConfig(BaseModel):
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
)
# Global configuration instance — None means no checkpointer is configured.
_checkpointer_config: CheckpointerConfig | None = None
def get_checkpointer_config() -> CheckpointerConfig | None:
"""Get the current checkpointer configuration, or None if not configured."""
return _checkpointer_config
def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
"""Set the checkpointer configuration."""
global _checkpointer_config
_checkpointer_config = config
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config
_checkpointer_config = CheckpointerConfig(**config_dict)

View File

@ -0,0 +1,59 @@
"""Per-invocation context for DeerFlow agent execution.
Injected via LangGraph Runtime. Middleware and tools access this
via Runtime[DeerFlowContext] parameters, through resolve_context().
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class DeerFlowContext:
"""Typed, immutable, per-invocation context injected via LangGraph Runtime.
Fields are all known at run start and never change during execution.
Mutable runtime state (e.g. sandbox_id) flows through ThreadState, not here.
"""
app_config: Any # AppConfig — typed as Any to avoid circular import at module level
thread_id: str
agent_name: str | None = None
def resolve_context(runtime: Any) -> DeerFlowContext:
"""Extract or construct DeerFlowContext from runtime.
Gateway/Client paths: runtime.context is already DeerFlowContext return directly.
LangGraph Server / legacy dict path: construct from dict context or configurable fallback.
"""
ctx = getattr(runtime, "context", None)
if isinstance(ctx, DeerFlowContext):
return ctx
from deerflow.config.app_config import AppConfig
# Try dict context first (legacy path, tests), then configurable
if isinstance(ctx, dict):
return DeerFlowContext(
app_config=AppConfig.current(),
thread_id=ctx.get("thread_id", ""),
agent_name=ctx.get("agent_name"),
)
# No context at all — fall back to LangGraph configurable
try:
from langgraph.config import get_config
cfg = get_config().get("configurable", {})
except RuntimeError:
# Outside runnable context (e.g. unit tests)
cfg = {}
return DeerFlowContext(
app_config=AppConfig.current(),
thread_id=cfg.get("thread_id", ""),
agent_name=cfg.get("agent_name"),
)

View File

@ -11,6 +11,8 @@ from pydantic import BaseModel, ConfigDict, Field
class McpOAuthConfig(BaseModel):
"""OAuth configuration for an MCP server (HTTP/SSE transports)."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
token_url: str = Field(description="OAuth token endpoint URL")
grant_type: Literal["client_credentials", "refresh_token"] = Field(
@ -28,12 +30,13 @@ class McpOAuthConfig(BaseModel):
default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response")
refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry")
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
model_config = ConfigDict(extra="allow")
class McpServerConfig(BaseModel):
"""Configuration for a single MCP server."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
@ -43,12 +46,13 @@ class McpServerConfig(BaseModel):
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)")
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
model_config = ConfigDict(extra="allow")
class SkillStateConfig(BaseModel):
"""Configuration for a single skill's state."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=True, description="Whether this skill is enabled")
@ -64,7 +68,7 @@ class ExtensionsConfig(BaseModel):
default_factory=dict,
description="Map of skill name to state configuration",
)
model_config = ConfigDict(extra="allow", populate_by_name=True)
model_config = ConfigDict(extra="allow", frozen=True, populate_by_name=True)
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path | None:
@ -195,62 +199,3 @@ class ExtensionsConfig(BaseModel):
# Default to enable for public & custom skill
return skill_category in ("public", "custom")
return skill_config.enabled
_extensions_config: ExtensionsConfig | None = None
def get_extensions_config() -> ExtensionsConfig:
"""Get the extensions config instance.
Returns a cached singleton instance. Use `reload_extensions_config()` to reload
from file, or `reset_extensions_config()` to clear the cache.
Returns:
The cached ExtensionsConfig instance.
"""
global _extensions_config
if _extensions_config is None:
_extensions_config = ExtensionsConfig.from_file()
return _extensions_config
def reload_extensions_config(config_path: str | None = None) -> ExtensionsConfig:
"""Reload the extensions config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to extensions config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded ExtensionsConfig instance.
"""
global _extensions_config
_extensions_config = ExtensionsConfig.from_file(config_path)
return _extensions_config
def reset_extensions_config() -> None:
"""Reset the cached extensions config instance.
This clears the singleton cache, causing the next call to
`get_extensions_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _extensions_config
_extensions_config = None
def set_extensions_config(config: ExtensionsConfig) -> None:
"""Set a custom extensions config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The ExtensionsConfig instance to use.
"""
global _extensions_config
_extensions_config = config

View File

@ -1,11 +1,13 @@
"""Configuration for pre-tool-call authorization."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class GuardrailProviderConfig(BaseModel):
"""Configuration for a guardrail provider."""
model_config = ConfigDict(frozen=True)
use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')")
config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs")
@ -18,31 +20,9 @@ class GuardrailsConfig(BaseModel):
agent's passport reference, and returns an allow/deny decision.
"""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable guardrail middleware")
fail_closed: bool = Field(default=True, description="Block tool calls if provider errors")
passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID")
provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration")
_guardrails_config: GuardrailsConfig | None = None
def get_guardrails_config() -> GuardrailsConfig:
"""Get the guardrails config, returning defaults if not loaded."""
global _guardrails_config
if _guardrails_config is None:
_guardrails_config = GuardrailsConfig()
return _guardrails_config
def load_guardrails_config_from_dict(data: dict) -> GuardrailsConfig:
"""Load guardrails config from a dict (called during AppConfig loading)."""
global _guardrails_config
_guardrails_config = GuardrailsConfig.model_validate(data)
return _guardrails_config
def reset_guardrails_config() -> None:
"""Reset the cached config instance. Used in tests to prevent singleton leaks."""
global _guardrails_config
_guardrails_config = None

View File

@ -1,11 +1,13 @@
"""Configuration for memory mechanism."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class MemoryConfig(BaseModel):
"""Configuration for global memory mechanism."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=True,
description="Whether to enable memory mechanism",
@ -60,24 +62,3 @@ class MemoryConfig(BaseModel):
le=8000,
description="Maximum tokens to use for memory injection",
)
# Global configuration instance
_memory_config: MemoryConfig = MemoryConfig()
def get_memory_config() -> MemoryConfig:
"""Get the current memory configuration."""
return _memory_config
def set_memory_config(config: MemoryConfig) -> None:
"""Set the memory configuration."""
global _memory_config
_memory_config = config
def load_memory_config_from_dict(config_dict: dict) -> None:
"""Load memory configuration from a dictionary."""
global _memory_config
_memory_config = MemoryConfig(**config_dict)

View File

@ -12,7 +12,7 @@ class ModelConfig(BaseModel):
description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)",
)
model: str = Field(..., description="Model name")
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)
use_responses_api: bool | None = Field(
default=None,
description="Whether to route OpenAI ChatOpenAI calls through the /v1/responses API",

View File

@ -4,6 +4,8 @@ from pydantic import BaseModel, ConfigDict, Field
class VolumeMountConfig(BaseModel):
"""Configuration for a volume mount."""
model_config = ConfigDict(frozen=True)
host_path: str = Field(..., description="Path on the host machine")
container_path: str = Field(..., description="Path inside the container")
read_only: bool = Field(default=False, description="Whether the mount is read-only")
@ -80,4 +82,4 @@ class SandboxConfig(BaseModel):
description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
)
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)

View File

@ -1,9 +1,11 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class SkillEvolutionConfig(BaseModel):
"""Configuration for agent-managed skill evolution."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=False,
description="Whether the agent can create and modify skills under skills/custom.",

View File

@ -1,6 +1,6 @@
from pathlib import Path
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
def _default_repo_root() -> Path:
@ -11,6 +11,8 @@ def _default_repo_root() -> Path:
class SkillsConfig(BaseModel):
"""Configuration for skills system"""
model_config = ConfigDict(frozen=True)
path: str | None = Field(
default=None,
description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",

View File

@ -2,7 +2,7 @@
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
StreamBridgeType = Literal["memory", "redis"]
@ -10,6 +10,8 @@ StreamBridgeType = Literal["memory", "redis"]
class StreamBridgeConfig(BaseModel):
"""Configuration for the stream bridge that connects agent workers to SSE endpoints."""
model_config = ConfigDict(frozen=True)
type: StreamBridgeType = Field(
default="memory",
description="Stream bridge backend type. 'memory' uses in-process asyncio.Queue (single-process only). 'redis' uses Redis Streams (planned for Phase 2, not yet implemented).",
@ -22,25 +24,3 @@ class StreamBridgeConfig(BaseModel):
default=256,
description="Maximum number of events buffered per run in the memory bridge.",
)
# Global configuration instance — None means no stream bridge is configured
# (falls back to memory with defaults).
_stream_bridge_config: StreamBridgeConfig | None = None
def get_stream_bridge_config() -> StreamBridgeConfig | None:
"""Get the current stream bridge configuration, or None if not configured."""
return _stream_bridge_config
def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
"""Set the stream bridge configuration."""
global _stream_bridge_config
_stream_bridge_config = config
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
"""Load stream bridge configuration from a dictionary."""
global _stream_bridge_config
_stream_bridge_config = StreamBridgeConfig(**config_dict)

View File

@ -1,15 +1,13 @@
"""Configuration for the subagent system loaded from config.yaml."""
import logging
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
from pydantic import BaseModel, ConfigDict, Field
class SubagentOverrideConfig(BaseModel):
"""Per-agent configuration overrides."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int | None = Field(
default=None,
ge=1,
@ -25,6 +23,8 @@ class SubagentOverrideConfig(BaseModel):
class SubagentsAppConfig(BaseModel):
"""Configuration for the subagent system."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int = Field(
default=900,
ge=1,
@ -62,41 +62,3 @@ class SubagentsAppConfig(BaseModel):
if self.max_turns is not None:
return self.max_turns
return builtin_default
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
def get_subagents_app_config() -> SubagentsAppConfig:
"""Get the current subagents configuration."""
return _subagents_config
def load_subagents_config_from_dict(config_dict: dict) -> None:
"""Load subagents configuration from a dictionary."""
global _subagents_config
_subagents_config = SubagentsAppConfig(**config_dict)
overrides_summary = {}
for name, override in _subagents_config.agents.items():
parts = []
if override.timeout_seconds is not None:
parts.append(f"timeout={override.timeout_seconds}s")
if override.max_turns is not None:
parts.append(f"max_turns={override.max_turns}")
if parts:
overrides_summary[name] = ", ".join(parts)
if overrides_summary:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
overrides_summary,
)
else:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)

View File

@ -2,7 +2,7 @@
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
ContextSizeType = Literal["fraction", "tokens", "messages"]
@ -10,6 +10,8 @@ ContextSizeType = Literal["fraction", "tokens", "messages"]
class ContextSize(BaseModel):
"""Context size specification for trigger or keep parameters."""
model_config = ConfigDict(frozen=True)
type: ContextSizeType = Field(description="Type of context size specification")
value: int | float = Field(description="Value for the context size specification")
@ -21,6 +23,8 @@ class ContextSize(BaseModel):
class SummarizationConfig(BaseModel):
"""Configuration for automatic conversation summarization."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=False,
description="Whether to enable automatic conversation summarization",
@ -51,24 +55,3 @@ class SummarizationConfig(BaseModel):
default=None,
description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.",
)
# Global configuration instance
_summarization_config: SummarizationConfig = SummarizationConfig()
def get_summarization_config() -> SummarizationConfig:
"""Get the current summarization configuration."""
return _summarization_config
def set_summarization_config(config: SummarizationConfig) -> None:
"""Set the summarization configuration."""
global _summarization_config
_summarization_config = config
def load_summarization_config_from_dict(config_dict: dict) -> None:
"""Load summarization configuration from a dictionary."""
global _summarization_config
_summarization_config = SummarizationConfig(**config_dict)

View File

@ -1,11 +1,13 @@
"""Configuration for automatic thread title generation."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class TitleConfig(BaseModel):
"""Configuration for automatic thread title generation."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=True,
description="Whether to enable automatic title generation",
@ -30,24 +32,3 @@ class TitleConfig(BaseModel):
default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."),
description="Prompt template for title generation",
)
# Global configuration instance
_title_config: TitleConfig = TitleConfig()
def get_title_config() -> TitleConfig:
"""Get the current title configuration."""
return _title_config
def set_title_config(config: TitleConfig) -> None:
"""Set the title configuration."""
global _title_config
_title_config = config
def load_title_config_from_dict(config_dict: dict) -> None:
"""Load title configuration from a dictionary."""
global _title_config
_title_config = TitleConfig(**config_dict)

View File

@ -1,7 +1,9 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class TokenUsageConfig(BaseModel):
"""Configuration for token usage tracking."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")

View File

@ -5,7 +5,7 @@ class ToolGroupConfig(BaseModel):
"""Config section for a tool group"""
name: str = Field(..., description="Unique name for the tool group")
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)
class ToolConfig(BaseModel):
@ -17,4 +17,4 @@ class ToolConfig(BaseModel):
...,
description="Variable name of the tool provider(e.g. deerflow.sandbox.tools:bash_tool)",
)
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)

View File

@ -1,6 +1,6 @@
"""Configuration for deferred tool loading via tool_search."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class ToolSearchConfig(BaseModel):
@ -11,25 +11,9 @@ class ToolSearchConfig(BaseModel):
via the tool_search tool at runtime.
"""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=False,
description="Defer tools and enable tool_search",
)
_tool_search_config: ToolSearchConfig | None = None
def get_tool_search_config() -> ToolSearchConfig:
"""Get the tool search config, loading from AppConfig if needed."""
global _tool_search_config
if _tool_search_config is None:
_tool_search_config = ToolSearchConfig()
return _tool_search_config
def load_tool_search_config_from_dict(data: dict) -> ToolSearchConfig:
"""Load tool search config from a dict (called during AppConfig loading)."""
global _tool_search_config
_tool_search_config = ToolSearchConfig.model_validate(data)
return _tool_search_config

View File

@ -1,7 +1,7 @@
import os
import threading
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
_config_lock = threading.Lock()
@ -9,6 +9,8 @@ _config_lock = threading.Lock()
class LangSmithTracingConfig(BaseModel):
"""Configuration for LangSmith tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...)
api_key: str | None = Field(...)
project: str = Field(...)
@ -26,6 +28,8 @@ class LangSmithTracingConfig(BaseModel):
class LangfuseTracingConfig(BaseModel):
"""Configuration for Langfuse tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...)
public_key: str | None = Field(...)
secret_key: str | None = Field(...)
@ -50,6 +54,8 @@ class LangfuseTracingConfig(BaseModel):
class TracingConfig(BaseModel):
"""Tracing configuration for supported providers."""
model_config = ConfigDict(frozen=True)
langsmith: LangSmithTracingConfig = Field(...)
langfuse: LangfuseTracingConfig = Field(...)

View File

@ -2,7 +2,7 @@ import logging
from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks
@ -39,7 +39,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
Returns:
A chat model instance.
"""
config = get_app_config()
config = AppConfig.current()
if name is None:
name = config.models[0].name
model_config = config.get_model_config(name)

View File

@ -24,7 +24,7 @@ from collections.abc import AsyncIterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.checkpointer.provider import (
POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL,
@ -138,7 +138,7 @@ async def make_checkpointer() -> AsyncIterator[Checkpointer]:
3. Default InMemorySaver
"""
config = get_app_config()
config = AppConfig.current()
# Legacy: standalone checkpointer config takes precedence
if config.checkpointer is not None:

View File

@ -25,7 +25,7 @@ from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str
@ -113,25 +113,10 @@ def get_checkpointer() -> Checkpointer:
if _checkpointer is not None:
return _checkpointer
# Ensure app config is loaded before checking checkpointer config
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
# but hasn't been loaded yet
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
try:
config = AppConfig.current().checkpointer
except (LookupError, FileNotFoundError):
config = None
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
@ -180,7 +165,7 @@ def checkpointer_context() -> Iterator[Checkpointer]:
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
config = AppConfig.current()
if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver

View File

@ -25,6 +25,8 @@ from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge
@ -162,17 +164,14 @@ async def run_agent(
# 3. Build the agent
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
# Inject runtime context so middlewares can access thread_id
# (langgraph-cli does this automatically; we must do it manually)
runtime = Runtime(context={"thread_id": thread_id}, store=store)
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
# prefers it over ``configurable`` for thread-level data), make
# sure ``thread_id`` is available there too.
if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Construct typed context for the agent run.
# LangGraph's astream(context=...) injects this into Runtime.context
# so middleware/tools can access it via resolve_context().
deer_flow_context = DeerFlowContext(
app_config=AppConfig.current(),
thread_id=thread_id,
)
# Inject RunJournal as a LangChain callback handler.
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
@ -224,7 +223,7 @@ async def run_agent(
if len(lg_modes) == 1 and not stream_subgraphs:
# Single mode, no subgraphs: astream yields raw chunks
single_mode = lg_modes[0]
async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode):
async for chunk in agent.astream(graph_input, config=runnable_config, context=deer_flow_context, stream_mode=single_mode):
if record.abort_event.is_set():
logger.info("Run %s abort requested — stopping", run_id)
break
@ -235,6 +234,7 @@ async def run_agent(
async for item in agent.astream(
graph_input,
config=runnable_config,
context=deer_flow_context,
stream_mode=lg_modes,
subgraphs=stream_subgraphs,
):

View File

@ -23,7 +23,7 @@ from collections.abc import AsyncIterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@ -100,7 +100,7 @@ async def make_store() -> AsyncIterator[BaseStore]:
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
``checkpointer`` section is configured (emits a WARNING in that case).
"""
config = get_app_config()
config = AppConfig.current()
if config.checkpointer is None:
from langgraph.store.memory import InMemoryStore

View File

@ -26,7 +26,7 @@ from collections.abc import Iterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@ -115,19 +115,10 @@ def get_store() -> BaseStore:
if _store is not None:
return _store
# Lazily load app config, mirroring the checkpointer singleton pattern so
# that tests that set the global checkpointer config explicitly remain isolated.
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
try:
get_app_config()
except FileNotFoundError:
pass
config = get_checkpointer_config()
try:
config = AppConfig.current().checkpointer
except (LookupError, FileNotFoundError):
config = None
if config is None:
from langgraph.store.memory import InMemoryStore
@ -176,7 +167,7 @@ def store_context() -> Iterator[BaseStore]:
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
config = AppConfig.current()
if config.checkpointer is None:
from langgraph.store.memory import InMemoryStore

View File

@ -17,7 +17,7 @@ import contextlib
import logging
from collections.abc import AsyncIterator
from deerflow.config.stream_bridge_config import get_stream_bridge_config
from deerflow.config.app_config import AppConfig
from .base import StreamBridge
@ -32,7 +32,7 @@ async def make_stream_bridge(config=None) -> AsyncIterator[StreamBridge]:
provided and nothing is set globally.
"""
if config is None:
config = get_stream_bridge_config()
config = AppConfig.current().stream_bridge
if config is None or config.type == "memory":
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge

View File

@ -29,9 +29,9 @@ class LocalSandboxProvider(SandboxProvider):
# Map skills container path to local skills directory
try:
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
config = get_app_config()
config = AppConfig.current()
skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path

View File

@ -6,6 +6,7 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import SandboxState, ThreadDataState
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.sandbox import get_sandbox_provider
logger = logging.getLogger(__name__)
@ -49,15 +50,15 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
return sandbox_id
@override
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
# Skip acquisition if lazy_init is enabled
if self._lazy_init:
return super().before_agent(state, runtime)
# Eager initialization (original behavior)
if "sandbox" not in state or state["sandbox"] is None:
thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
thread_id = runtime.context.thread_id
if not thread_id:
return super().before_agent(state, runtime)
sandbox_id = self._acquire_sandbox(thread_id)
logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
@ -65,7 +66,7 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
return super().before_agent(state, runtime)
@override
def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
sandbox = state.get("sandbox")
if sandbox is not None:
sandbox_id = sandbox["sandbox_id"]
@ -73,11 +74,5 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
get_sandbox_provider().release(sandbox_id)
return None
if (runtime.context or {}).get("sandbox_id") is not None:
sandbox_id = runtime.context.get("sandbox_id")
logger.info(f"Releasing sandbox {sandbox_id} from context")
get_sandbox_provider().release(sandbox_id)
return None
# No sandbox to release
return super().after_agent(state, runtime)

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class
from deerflow.sandbox.sandbox import Sandbox
@ -50,7 +50,7 @@ def get_sandbox_provider(**kwargs) -> SandboxProvider:
"""
global _default_sandbox_provider
if _default_sandbox_provider is None:
config = get_app_config()
config = AppConfig.current()
cls = resolve_class(config.sandbox.use, SandboxProvider)
_default_sandbox_provider = cls(**kwargs)
return _default_sandbox_provider

View File

@ -1,6 +1,6 @@
"""Security helpers for sandbox capability gating."""
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
_LOCAL_SANDBOX_PROVIDER_MARKERS = (
"deerflow.sandbox.local:LocalSandboxProvider",
@ -23,7 +23,7 @@ LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE = (
def uses_local_sandbox_provider(config=None) -> bool:
"""Return True when the active sandbox provider is the host-local provider."""
if config is None:
config = get_app_config()
config = AppConfig.current()
sandbox_cfg = getattr(config, "sandbox", None)
sandbox_use = getattr(sandbox_cfg, "use", "")
@ -35,7 +35,7 @@ def uses_local_sandbox_provider(config=None) -> bool:
def is_host_bash_allowed(config=None) -> bool:
"""Return whether host bash execution is explicitly allowed."""
if config is None:
config = get_app_config()
config = AppConfig.current()
sandbox_cfg = getattr(config, "sandbox", None)
if sandbox_cfg is None:

View File

@ -7,7 +7,7 @@ from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.sandbox.exceptions import (
SandboxError,
@ -50,9 +50,7 @@ def _get_skills_container_path() -> str:
if cached is not None:
return cached
try:
from deerflow.config import get_app_config
value = get_app_config().skills.container_path
value = AppConfig.current().skills.container_path
_get_skills_container_path._cached = value # type: ignore[attr-defined]
return value
except Exception:
@ -71,9 +69,7 @@ def _get_skills_host_path() -> str | None:
if cached is not None:
return cached
try:
from deerflow.config import get_app_config
config = get_app_config()
config = AppConfig.current()
skills_path = config.skills.get_skills_path()
if skills_path.exists():
value = str(skills_path)
@ -132,9 +128,7 @@ def _get_custom_mounts():
try:
from pathlib import Path
from deerflow.config import get_app_config
config = get_app_config()
config = AppConfig.current()
mounts = []
if config.sandbox and config.sandbox.mounts:
# Only include mounts whose host_path exists, consistent with
@ -275,9 +269,7 @@ def _get_mcp_allowed_paths() -> list[str]:
"""Get the list of allowed paths from MCP config for file system server."""
allowed_paths = []
try:
from deerflow.config.extensions_config import get_extensions_config
extensions_config = get_extensions_config()
extensions_config = AppConfig.current().extensions
for _, server in extensions_config.mcp_servers.items():
if not server.enabled:
@ -302,7 +294,7 @@ def _get_mcp_allowed_paths() -> list[str]:
def _get_tool_config_int(name: str, key: str, default: int) -> int:
try:
tool_config = get_app_config().get_tool_config(name)
tool_config = AppConfig.current().get_tool_config(name)
if tool_config is not None and key in tool_config.model_extra:
value = tool_config.model_extra.get(key)
if isinstance(value, int):
@ -810,8 +802,6 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
if sandbox is None:
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
return sandbox
@ -846,16 +836,12 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox_id is not None:
sandbox = get_sandbox_provider().get(sandbox_id)
if sandbox is not None:
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox
# Sandbox was released, fall through to acquire new one
# Lazy acquisition: get thread_id and acquire sandbox
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
thread_id = runtime.config.get("configurable", {}).get("thread_id") if runtime.config else None
if thread_id is None:
thread_id = runtime.context.thread_id
if not thread_id:
raise SandboxRuntimeError("Thread ID not available in runtime context")
provider = get_sandbox_provider()
@ -869,8 +855,6 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox is None:
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox
@ -1012,18 +996,14 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
command = _apply_cwd_prefix(command, thread_data)
output = sandbox.execute_command(command)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
sandbox_cfg = AppConfig.current().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
ensure_thread_directories_exist(runtime)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
sandbox_cfg = AppConfig.current().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
@ -1063,9 +1043,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
return "(empty)"
output = "\n".join(children)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
sandbox_cfg = AppConfig.current().sandbox
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
@ -1236,9 +1214,7 @@ def read_file_tool(
if start_line is not None and end_line is not None:
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
sandbox_cfg = AppConfig.current().sandbox
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
except Exception:
max_chars = 50000

View File

@ -42,9 +42,9 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable
if skills_path is None:
if use_config:
try:
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
config = get_app_config()
config = AppConfig.current()
skills_path = config.skills.get_skills_path()
except Exception:
# Fallback to default if config fails

View File

@ -9,7 +9,7 @@ from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.skills.loader import load_skills
from deerflow.skills.validation import _validate_skill_frontmatter
@ -21,7 +21,7 @@ _SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
def get_skills_root_dir() -> Path:
return get_app_config().skills.get_skills_path()
return AppConfig.current().skills.get_skills_path()
def get_public_skills_dir() -> Path:

View File

@ -7,7 +7,7 @@ import logging
import re
from dataclasses import dataclass
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@ -47,7 +47,7 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----"
try:
config = get_app_config()
config = AppConfig.current()
model_name = config.skill_evolution.moderation_model_name
model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False)
response = await model.ainvoke(

View File

@ -24,9 +24,9 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
return None
# Apply timeout override from config.yaml (lazy import to avoid circular deps)
from deerflow.config.subagents_config import get_subagents_app_config
from deerflow.config.app_config import AppConfig
app_config = get_subagents_app_config()
app_config = AppConfig.current().subagents
effective_timeout = app_config.get_timeout_for(name)
effective_max_turns = app_config.get_max_turns_for(name, config.max_turns)

View File

@ -34,7 +34,7 @@ def _normalize_presented_filepath(
if runtime.state is None:
raise ValueError("Thread runtime state is not available")
thread_id = runtime.context.get("thread_id") if runtime.context else None
thread_id = runtime.context.thread_id
if not thread_id:
raise ValueError("Thread ID is not available in runtime context")

View File

@ -24,7 +24,7 @@ def setup_agent(
description: One-line description of what the agent does.
"""
agent_name: str | None = runtime.context.get("agent_name") if runtime.context else None
agent_name: str | None = runtime.context.agent_name
try:
paths = get_paths()

View File

@ -92,9 +92,7 @@ async def task_tool(
if runtime is not None:
sandbox_state = runtime.state.get("sandbox")
thread_data = runtime.state.get("thread_data")
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
thread_id = runtime.config.get("configurable", {}).get("thread_id")
thread_id = runtime.context.thread_id
# Try to get parent model from configurable
metadata = runtime.config.get("metadata", {})

View File

@ -45,9 +45,7 @@ def _get_lock(name: str) -> asyncio.Lock:
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
if runtime is None:
return None
if runtime.context and runtime.context.get("thread_id"):
return runtime.context.get("thread_id")
return runtime.config.get("configurable", {}).get("thread_id")
return runtime.context.thread_id or None
def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]:

View File

@ -2,7 +2,7 @@ import logging
from langchain.tools import BaseTool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
@ -52,7 +52,7 @@ def get_available_tools(
Returns:
List of available tools.
"""
config = get_app_config()
config = AppConfig.current()
tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups]
# Do not expose host bash by default when LocalSandboxProvider is active.
@ -123,10 +123,9 @@ def get_available_tools(
# Add invoke_acp_agent tool if any ACP agents are configured
acp_tools: list[BaseTool] = []
try:
from deerflow.config.acp_config import get_acp_agents
from deerflow.tools.builtins.invoke_acp_agent_tool import build_invoke_acp_agent_tool
acp_agents = get_acp_agents()
acp_agents = AppConfig.current().acp_agents
if acp_agents:
acp_tools.append(build_invoke_acp_agent_tool(acp_agents))
logger.info(f"Including invoke_acp_agent tool ({len(acp_agents)} agent(s): {list(acp_agents.keys())})")

View File

@ -294,9 +294,9 @@ def _get_pdf_converter() -> str:
fall through to unexpected behaviour.
"""
try:
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
cfg = get_app_config()
cfg = AppConfig.current()
uploads_cfg = getattr(cfg, "uploads", None)
if uploads_cfg is not None:
raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()

View File

@ -6,17 +6,20 @@ import pytest
import yaml
from pydantic import ValidationError
from deerflow.config.acp_config import ACPAgentConfig, get_acp_agents, load_acp_config_from_dict
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def setup_function():
"""Reset ACP config before each test."""
load_acp_config_from_dict({})
def _make_config(acp_agents: dict | None = None) -> AppConfig:
return AppConfig(
sandbox=SandboxConfig(use="test"),
acp_agents={name: ACPAgentConfig(**cfg) for name, cfg in (acp_agents or {}).items()},
)
def test_load_acp_config_sets_agents():
load_acp_config_from_dict(
def test_acp_agents_via_app_config():
cfg = _make_config(
{
"claude_code": {
"command": "claude-code-acp",
@ -26,39 +29,33 @@ def test_load_acp_config_sets_agents():
}
}
)
agents = get_acp_agents()
agents = cfg.acp_agents
assert "claude_code" in agents
assert agents["claude_code"].command == "claude-code-acp"
assert agents["claude_code"].description == "Claude Code for coding tasks"
assert agents["claude_code"].model is None
def test_load_acp_config_multiple_agents():
load_acp_config_from_dict(
def test_multiple_agents():
cfg = _make_config(
{
"claude_code": {"command": "claude-code-acp", "args": [], "description": "Claude Code"},
"codex": {"command": "codex-acp", "args": ["--flag"], "description": "Codex CLI"},
}
)
agents = get_acp_agents()
agents = cfg.acp_agents
assert len(agents) == 2
assert agents["codex"].args == ["--flag"]
def test_load_acp_config_empty_clears_agents():
load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
assert len(get_acp_agents()) == 1
load_acp_config_from_dict({})
assert len(get_acp_agents()) == 0
def test_empty_acp_agents():
cfg = _make_config({})
assert cfg.acp_agents == {}
def test_load_acp_config_none_clears_agents():
load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
assert len(get_acp_agents()) == 1
load_acp_config_from_dict(None)
assert get_acp_agents() == {}
def test_default_acp_agents_empty():
cfg = AppConfig(sandbox=SandboxConfig(use="test"))
assert cfg.acp_agents == {}
def test_acp_agent_config_defaults():
@ -79,8 +76,8 @@ def test_acp_agent_config_env_default_is_empty():
assert cfg.env == {}
def test_load_acp_config_preserves_env():
load_acp_config_from_dict(
def test_acp_agent_preserves_env():
cfg = _make_config(
{
"codex": {
"command": "codex-acp",
@ -90,8 +87,7 @@ def test_load_acp_config_preserves_env():
}
}
)
cfg = get_acp_agents()["codex"]
assert cfg.env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"}
assert cfg.acp_agents["codex"].env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"}
def test_acp_agent_config_with_model():
@ -115,13 +111,7 @@ def test_acp_agent_config_missing_description_raises():
ACPAgentConfig(command="my-agent")
def test_get_acp_agents_returns_empty_by_default():
"""After clearing, should return empty dict."""
load_acp_config_from_dict({})
assert get_acp_agents() == {}
def test_app_config_reload_without_acp_agents_clears_previous_state(tmp_path, monkeypatch):
def test_app_config_from_file_with_acp_agents(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
extensions_path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
@ -157,9 +147,9 @@ def test_app_config_reload_without_acp_agents_clears_previous_state(tmp_path, mo
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
config_path.write_text(yaml.safe_dump(config_with_acp), encoding="utf-8")
AppConfig.from_file(str(config_path))
assert set(get_acp_agents()) == {"codex"}
app = AppConfig.from_file(str(config_path))
assert set(app.acp_agents) == {"codex"}
config_path.write_text(yaml.safe_dump(config_without_acp), encoding="utf-8")
AppConfig.from_file(str(config_path))
assert get_acp_agents() == {}
app = AppConfig.from_file(str(config_path))
assert app.acp_agents == {}

View File

@ -1,12 +1,11 @@
from __future__ import annotations
import json
import os
from pathlib import Path
import yaml
from deerflow.config.app_config import get_app_config, reset_app_config
from deerflow.config.app_config import AppConfig
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
@ -32,50 +31,61 @@ def _write_extensions_config(path: Path) -> None:
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
def test_get_app_config_reloads_when_file_changes(tmp_path, monkeypatch):
def test_init_then_get(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config(config_path, model_name="first-model", supports_thinking=False)
_write_config(config_path, model_name="test-model", supports_thinking=False)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_app_config()
try:
initial = get_app_config()
assert initial.models[0].supports_thinking is False
config = AppConfig.from_file(str(config_path))
AppConfig.init(config)
_write_config(config_path, model_name="first-model", supports_thinking=True)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
reloaded = get_app_config()
assert reloaded.models[0].supports_thinking is True
assert reloaded is not initial
finally:
reset_app_config()
result = AppConfig.current()
assert result is config
assert result.models[0].name == "test-model"
def test_get_app_config_reloads_when_config_path_changes(tmp_path, monkeypatch):
config_a = tmp_path / "config-a.yaml"
config_b = tmp_path / "config-b.yaml"
def test_init_replaces_previous(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config(config_a, model_name="model-a", supports_thinking=False)
_write_config(config_b, model_name="model-b", supports_thinking=True)
_write_config(config_path, model_name="model-a", supports_thinking=False)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_a))
reset_app_config()
try:
first = get_app_config()
assert first.models[0].name == "model-a"
config_a = AppConfig.from_file(str(config_path))
AppConfig.init(config_a)
assert AppConfig.current().models[0].name == "model-a"
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_b))
second = get_app_config()
assert second.models[0].name == "model-b"
assert second is not first
finally:
reset_app_config()
_write_config(config_path, model_name="model-b", supports_thinking=True)
config_b = AppConfig.from_file(str(config_path))
AppConfig.init(config_b)
assert AppConfig.current().models[0].name == "model-b"
assert AppConfig.current() is config_b
def test_config_version_check(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
config_path.write_text(
yaml.safe_dump(
{
"config_version": 1,
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"models": [],
}
),
encoding="utf-8",
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
config = AppConfig.from_file(str(config_path))
assert config is not None

View File

@ -5,25 +5,21 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import deerflow.config.app_config as app_config_module
from deerflow.config.checkpointer_config import (
CheckpointerConfig,
get_checkpointer_config,
load_checkpointer_config_from_dict,
set_checkpointer_config,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
def _make_config(checkpointer: CheckpointerConfig | None = None) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), checkpointer=checkpointer)
@pytest.fixture(autouse=True)
def reset_state():
"""Reset singleton state before each test."""
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer()
yield
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer()
@ -33,24 +29,18 @@ def reset_state():
class TestCheckpointerConfig:
def test_load_memory_config(self):
load_checkpointer_config_from_dict({"type": "memory"})
config = get_checkpointer_config()
assert config is not None
def test_memory_config(self):
config = CheckpointerConfig(type="memory")
assert config.type == "memory"
assert config.connection_string is None
def test_load_sqlite_config(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
config = get_checkpointer_config()
assert config is not None
def test_sqlite_config(self):
config = CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db")
assert config.type == "sqlite"
assert config.connection_string == "/tmp/test.db"
def test_load_postgres_config(self):
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
config = get_checkpointer_config()
assert config is not None
def test_postgres_config(self):
config = CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db")
assert config.type == "postgres"
assert config.connection_string == "postgresql://localhost/db"
@ -58,14 +48,9 @@ class TestCheckpointerConfig:
config = CheckpointerConfig(type="memory")
assert config.connection_string is None
def test_set_config_to_none(self):
load_checkpointer_config_from_dict({"type": "memory"})
set_checkpointer_config(None)
assert get_checkpointer_config() is None
def test_invalid_type_raises(self):
with pytest.raises(Exception):
load_checkpointer_config_from_dict({"type": "unknown"})
CheckpointerConfig(type="unknown")
# ---------------------------------------------------------------------------
@ -78,58 +63,78 @@ class TestGetCheckpointer:
"""get_checkpointer should return InMemorySaver when not configured."""
from langgraph.checkpoint.memory import InMemorySaver
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
with patch.object(AppConfig, "current", return_value=_make_config()):
cp = get_checkpointer()
assert cp is not None
assert isinstance(cp, InMemorySaver)
def test_returns_in_memory_saver_when_config_not_found(self):
from langgraph.checkpoint.memory import InMemorySaver
with patch.object(AppConfig, "current", side_effect=FileNotFoundError):
cp = get_checkpointer()
assert cp is not None
assert isinstance(cp, InMemorySaver)
def test_memory_returns_in_memory_saver(self):
load_checkpointer_config_from_dict({"type": "memory"})
from langgraph.checkpoint.memory import InMemorySaver
cp = get_checkpointer()
cfg = _make_config(CheckpointerConfig(type="memory"))
with patch.object(AppConfig, "current", return_value=cfg):
cp = get_checkpointer()
assert isinstance(cp, InMemorySaver)
def test_memory_singleton(self):
load_checkpointer_config_from_dict({"type": "memory"})
cp1 = get_checkpointer()
cp2 = get_checkpointer()
cfg = _make_config(CheckpointerConfig(type="memory"))
with patch.object(AppConfig, "current", return_value=cfg):
cp1 = get_checkpointer()
cp2 = get_checkpointer()
assert cp1 is cp2
def test_reset_clears_singleton(self):
load_checkpointer_config_from_dict({"type": "memory"})
cp1 = get_checkpointer()
reset_checkpointer()
cp2 = get_checkpointer()
cfg = _make_config(CheckpointerConfig(type="memory"))
with patch.object(AppConfig, "current", return_value=cfg):
cp1 = get_checkpointer()
reset_checkpointer()
cp2 = get_checkpointer()
assert cp1 is not cp2
def test_sqlite_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db"))
with (
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}),
):
reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
get_checkpointer()
def test_postgres_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}):
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
with (
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}),
):
reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
get_checkpointer()
def test_postgres_raises_when_connection_string_missing(self):
load_checkpointer_config_from_dict({"type": "postgres"})
cfg = _make_config(CheckpointerConfig(type="postgres"))
mock_saver = MagicMock()
mock_module = MagicMock()
mock_module.PostgresSaver = mock_saver
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}):
with (
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}),
):
reset_checkpointer()
with pytest.raises(ValueError, match="connection_string is required"):
get_checkpointer()
def test_sqlite_creates_saver(self):
"""SQLite checkpointer is created when package is available."""
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db"))
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
@ -142,7 +147,10 @@ class TestGetCheckpointer:
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}):
with (
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
):
reset_checkpointer()
cp = get_checkpointer()
@ -152,7 +160,7 @@ class TestGetCheckpointer:
def test_postgres_creates_saver(self):
"""Postgres checkpointer is created when packages are available."""
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
@ -165,7 +173,10 @@ class TestGetCheckpointer:
mock_pg_module = MagicMock()
mock_pg_module.PostgresSaver = mock_saver_cls
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}):
with (
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}),
):
reset_checkpointer()
cp = get_checkpointer()
@ -195,7 +206,7 @@ class TestAsyncCheckpointer:
mock_module.AsyncSqliteSaver = mock_saver_cls
with (
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
patch.object(AppConfig, "current", return_value=mock_config),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
patch(
@ -221,12 +232,10 @@ class TestAsyncCheckpointer:
class TestAppConfigLoadsCheckpointer:
def test_load_checkpointer_section(self):
"""load_checkpointer_config_from_dict populates the global config."""
set_checkpointer_config(None)
load_checkpointer_config_from_dict({"type": "memory"})
cfg = get_checkpointer_config()
assert cfg is not None
assert cfg.type == "memory"
"""AppConfig with checkpointer section has the correct config."""
cfg = _make_config(CheckpointerConfig(type="memory"))
assert cfg.checkpointer is not None
assert cfg.checkpointer.type == "memory"
# ---------------------------------------------------------------------------
@ -237,68 +246,6 @@ class TestAppConfigLoadsCheckpointer:
class TestClientCheckpointerFallback:
def test_client_uses_config_checkpointer_when_none_provided(self):
"""DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.client import DeerFlowClient
load_checkpointer_config_from_dict({"type": "memory"})
captured_kwargs = {}
def fake_create_agent(**kwargs):
captured_kwargs.update(kwargs)
return MagicMock()
model_mock = MagicMock()
config_mock = MagicMock()
config_mock.models = [model_mock]
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
config_mock.checkpointer = None
with (
patch("deerflow.client.get_app_config", return_value=config_mock),
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value=""),
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
):
client = DeerFlowClient(checkpointer=None)
config = client._get_runnable_config("test-thread")
client._ensure_agent(config)
assert "checkpointer" in captured_kwargs
assert isinstance(captured_kwargs["checkpointer"], InMemorySaver)
def test_client_explicit_checkpointer_takes_precedence(self):
"""An explicitly provided checkpointer is used even when config checkpointer is set."""
from deerflow.client import DeerFlowClient
load_checkpointer_config_from_dict({"type": "memory"})
explicit_cp = MagicMock()
captured_kwargs = {}
def fake_create_agent(**kwargs):
captured_kwargs.update(kwargs)
return MagicMock()
model_mock = MagicMock()
config_mock = MagicMock()
config_mock.models = [model_mock]
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
config_mock.checkpointer = None
with (
patch("deerflow.client.get_app_config", return_value=config_mock),
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value=""),
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
):
client = DeerFlowClient(checkpointer=explicit_cp)
config = client._get_runnable_config("test-thread")
client._ensure_agent(config)
assert captured_kwargs["checkpointer"] is explicit_cp
# This is a structural test — verifying the fallback path exists.
cfg = _make_config(CheckpointerConfig(type="memory"))
assert cfg.checkpointer is not None

View File

@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
import pytest
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.config.app_config import AppConfig
class TestCheckpointerNoneFix:
"""Tests that checkpointer context managers return InMemorySaver instead of None."""
@ -14,12 +16,12 @@ class TestCheckpointerNoneFix:
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
# Mock get_app_config to return a config with checkpointer=None and database=None
# Mock AppConfig.current to return a config with checkpointer=None and database=None
mock_config = MagicMock()
mock_config.checkpointer = None
mock_config.database = None
with patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config):
with patch.object(AppConfig, "current", return_value=mock_config):
async with make_checkpointer() as checkpointer:
# Should return InMemorySaver, not None
assert checkpointer is not None
@ -38,11 +40,11 @@ class TestCheckpointerNoneFix:
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
from deerflow.runtime.checkpointer.provider import checkpointer_context
# Mock get_app_config to return a config with checkpointer=None
# Mock AppConfig.get to return a config with checkpointer=None
mock_config = MagicMock()
mock_config.checkpointer = None
with patch("deerflow.runtime.checkpointer.provider.get_app_config", return_value=mock_config):
with patch.object(AppConfig, "current", return_value=mock_config):
with checkpointer_context() as checkpointer:
# Should return InMemorySaver, not None
assert checkpointer is not None

View File

@ -18,6 +18,7 @@ from app.gateway.routers.models import ModelResponse, ModelsListResponse
from app.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse
from app.gateway.routers.uploads import UploadResponse
from deerflow.client import DeerFlowClient
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import Paths
from deerflow.uploads.manager import PathTraversalError
@ -44,7 +45,7 @@ def mock_app_config():
@pytest.fixture
def client(mock_app_config):
"""Create a DeerFlowClient with mocked config loading."""
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
return DeerFlowClient()
@ -66,7 +67,7 @@ class TestClientInit:
def test_custom_params(self, mock_app_config):
mock_middleware = MagicMock()
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
assert c._model_name == "gpt-4"
assert c._thinking_enabled is False
@ -77,7 +78,7 @@ class TestClientInit:
assert c._middlewares == [mock_middleware]
def test_invalid_agent_name(self, mock_app_config):
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
with pytest.raises(ValueError, match="Invalid agent name"):
DeerFlowClient(agent_name="invalid name with spaces!")
with pytest.raises(ValueError, match="Invalid agent name"):
@ -85,15 +86,17 @@ class TestClientInit:
def test_custom_config_path(self, mock_app_config):
with (
patch("deerflow.client.reload_app_config") as mock_reload,
patch("deerflow.client.get_app_config", return_value=mock_app_config),
patch.object(AppConfig, "from_file", return_value=mock_app_config) as mock_from_file,
patch.object(AppConfig, "init") as mock_init,
patch.object(AppConfig, "current", return_value=mock_app_config),
):
DeerFlowClient(config_path="/tmp/custom.yaml")
mock_reload.assert_called_once_with("/tmp/custom.yaml")
mock_from_file.assert_called_once_with("/tmp/custom.yaml")
mock_init.assert_called_once_with(mock_app_config)
def test_checkpointer_stored(self, mock_app_config):
cp = MagicMock()
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
c = DeerFlowClient(checkpointer=cp)
assert c._checkpointer is cp
@ -249,8 +252,8 @@ class TestStream:
# Verify context passed to agent.stream
agent.stream.assert_called_once()
call_kwargs = agent.stream.call_args.kwargs
assert call_kwargs["context"]["thread_id"] == "t1"
assert call_kwargs["context"]["agent_name"] == "test-agent-1"
ctx = call_kwargs["context"]
assert ctx.app_config is client._app_config
def test_custom_mode_is_normalized_to_string(self, client):
"""stream() forwards custom events even when the mode is not a plain string."""
@ -1089,7 +1092,7 @@ class TestMcpConfig:
ext_config = MagicMock()
ext_config.mcp_servers = {"github": server}
with patch("deerflow.client.get_extensions_config", return_value=ext_config):
with patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)):
result = client.get_mcp_config()
assert "mcp_servers" in result
@ -1114,10 +1117,12 @@ class TestMcpConfig:
# Pre-set agent to verify it gets invalidated
client._agent = MagicMock()
# Set initial AppConfig with current extensions
AppConfig.init(MagicMock(extensions=current_config))
with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
):
result = client.update_mcp_config({"new-server": {"enabled": True, "type": "sse"}})
@ -1179,8 +1184,8 @@ class TestSkillsManagement:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated_skill]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
result = client.update_skill("test-skill", enabled=False)
assert result["enabled"] is False
@ -1314,35 +1319,40 @@ class TestMemoryManagement:
assert result == data
def test_get_memory_config(self, client):
config = MagicMock()
config.enabled = True
config.storage_path = ".deer-flow/memory.json"
config.debounce_seconds = 30
config.max_facts = 100
config.fact_confidence_threshold = 0.7
config.injection_enabled = True
config.max_injection_tokens = 2000
mem_config = MagicMock()
mem_config.enabled = True
mem_config.storage_path = ".deer-flow/memory.json"
mem_config.debounce_seconds = 30
mem_config.max_facts = 100
mem_config.fact_confidence_threshold = 0.7
mem_config.injection_enabled = True
mem_config.max_injection_tokens = 2000
with patch("deerflow.config.memory_config.get_memory_config", return_value=config):
app_cfg = MagicMock()
app_cfg.memory = mem_config
with patch.object(AppConfig, "current", return_value=app_cfg):
result = client.get_memory_config()
assert result["enabled"] is True
assert result["max_facts"] == 100
def test_get_memory_status(self, client):
config = MagicMock()
config.enabled = True
config.storage_path = ".deer-flow/memory.json"
config.debounce_seconds = 30
config.max_facts = 100
config.fact_confidence_threshold = 0.7
config.injection_enabled = True
config.max_injection_tokens = 2000
mem_config = MagicMock()
mem_config.enabled = True
mem_config.storage_path = ".deer-flow/memory.json"
mem_config.debounce_seconds = 30
mem_config.max_facts = 100
mem_config.fact_confidence_threshold = 0.7
mem_config.injection_enabled = True
mem_config.max_injection_tokens = 2000
app_cfg = MagicMock()
app_cfg.memory = mem_config
data = {"version": "1.0", "facts": []}
with (
patch("deerflow.config.memory_config.get_memory_config", return_value=config),
patch.object(AppConfig, "current", return_value=app_cfg),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=data),
):
result = client.get_memory_status()
@ -1798,10 +1808,10 @@ class TestScenarioConfigManagement:
reloaded_config.mcp_servers = {"my-mcp": reloaded_server}
client._agent = MagicMock() # Simulate existing agent
AppConfig.init(MagicMock(extensions=current_config))
with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
):
mcp_result = client.update_mcp_config({"my-mcp": {"enabled": True}})
assert "my-mcp" in mcp_result["mcp_servers"]
@ -1830,8 +1840,8 @@ class TestScenarioConfigManagement:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [toggled]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
skill_result = client.update_skill("code-gen", enabled=False)
assert skill_result["enabled"] is False
@ -2019,8 +2029,10 @@ class TestScenarioMemoryWorkflow:
refreshed = client.reload_memory()
assert len(refreshed["facts"]) == 2
app_cfg = MagicMock()
app_cfg.memory = config
with (
patch("deerflow.config.memory_config.get_memory_config", return_value=config),
patch.object(AppConfig, "current", return_value=app_cfg),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=updated_data),
):
status = client.get_memory_status()
@ -2083,8 +2095,8 @@ class TestScenarioSkillInstallAndUse:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[installed_skill], [disabled_skill]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
toggled = client.update_skill("my-analyzer", enabled=False)
assert toggled["enabled"] is False
@ -2216,7 +2228,7 @@ class TestGatewayConformance:
model.supports_thinking = False
mock_app_config.models = [model]
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
client = DeerFlowClient()
result = client.list_models()
@ -2235,7 +2247,7 @@ class TestGatewayConformance:
mock_app_config.models = [model]
mock_app_config.get_model_config.return_value = model
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config):
client = DeerFlowClient()
result = client.get_model("test-model")
@ -2305,7 +2317,7 @@ class TestGatewayConformance:
ext_config = MagicMock()
ext_config.mcp_servers = {"test": server}
with patch("deerflow.client.get_extensions_config", return_value=ext_config):
with patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)):
result = client.get_mcp_config()
parsed = McpConfigResponse(**result)
@ -2331,9 +2343,9 @@ class TestGatewayConformance:
config_file.write_text("{}")
with (
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.reload_extensions_config", return_value=ext_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=ext_config)),
):
result = client.update_mcp_config({"srv": server.model_dump.return_value})
@ -2364,7 +2376,10 @@ class TestGatewayConformance:
mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
app_cfg = MagicMock()
app_cfg.memory = mem_cfg
with patch.object(AppConfig, "current", return_value=app_cfg):
result = client.get_memory_config()
parsed = MemoryConfigResponse(**result)
@ -2381,6 +2396,8 @@ class TestGatewayConformance:
mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000
app_cfg = MagicMock()
app_cfg.memory = mem_cfg
memory_data = {
"version": "1.0",
"lastUpdated": "",
@ -2398,7 +2415,7 @@ class TestGatewayConformance:
}
with (
patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg),
patch.object(AppConfig, "current", return_value=app_cfg),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=memory_data),
):
result = client.get_memory_status()
@ -2689,8 +2706,8 @@ class TestConfigUpdateErrors:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], []]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
with pytest.raises(RuntimeError, match="disappeared"):
client.update_skill("ghost-skill", enabled=False)
@ -3069,10 +3086,10 @@ class TestBugAgentInvalidationInconsistency:
config_file = Path(tmp) / "ext.json"
config_file.write_text("{}")
AppConfig.init(MagicMock(extensions=current_config))
with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded)),
):
client.update_mcp_config({})
@ -3104,8 +3121,8 @@ class TestBugAgentInvalidationInconsistency:
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()),
):
client.update_skill("s1", enabled=False)

View File

@ -0,0 +1,73 @@
"""Verify that all sub-config Pydantic models are frozen (immutable).
Frozen models reject attribute assignment after construction, raising
pydantic.ValidationError. This test collects every BaseModel subclass
defined in the deerflow.config package and asserts that mutation is
blocked.
"""
import inspect
import pkgutil
import pytest
from pydantic import BaseModel, ValidationError
import deerflow.config as config_pkg
def _collect_config_models() -> list[type[BaseModel]]:
"""Walk deerflow.config.* and return all concrete BaseModel subclasses."""
import importlib
models: list[type[BaseModel]] = []
package_path = config_pkg.__path__
package_prefix = config_pkg.__name__ + "."
for _importer, modname, _ispkg in pkgutil.walk_packages(package_path, prefix=package_prefix):
try:
mod = importlib.import_module(modname)
except Exception:
continue
for _name, obj in inspect.getmembers(mod, inspect.isclass):
if (
issubclass(obj, BaseModel)
and obj is not BaseModel
and obj.__module__ == mod.__name__
):
models.append(obj)
return models
_EXCLUDED: set[str] = set()
_ALL_MODELS = [m for m in _collect_config_models() if m.__name__ not in _EXCLUDED]
# Sanity: make sure we actually collected a meaningful set.
assert len(_ALL_MODELS) >= 15, f"Expected at least 15 config models, found {len(_ALL_MODELS)}: {[m.__name__ for m in _ALL_MODELS]}"
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
def test_config_model_is_frozen(model_cls: type[BaseModel]):
"""Every sub-config model must have frozen=True in its model_config."""
cfg = model_cls.model_config
assert cfg.get("frozen") is True, (
f"{model_cls.__name__} is not frozen. "
f"Add `model_config = ConfigDict(frozen=True)` or add `frozen=True` to the existing ConfigDict."
)
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
def test_config_model_rejects_mutation(model_cls: type[BaseModel]):
"""Constructing then mutating any field must raise ValidationError."""
# Build a minimal instance -- use model_construct to skip validation for
# required fields, then pick the first field to try mutating.
fields = list(model_cls.model_fields.keys())
if not fields:
pytest.skip(f"{model_cls.__name__} has no fields")
instance = model_cls.model_construct()
first_field = fields[0]
with pytest.raises(ValidationError):
setattr(instance, first_field, "MUTATED")

View File

@ -3,12 +3,14 @@
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
import yaml
from fastapi.testclient import TestClient
from deerflow.config.app_config import AppConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@ -331,7 +333,7 @@ class TestMemoryFilePath:
with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
):
storage = FileMemoryStorage()
path = storage._get_memory_file_path(None)
@ -344,7 +346,7 @@ class TestMemoryFilePath:
with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
):
storage = FileMemoryStorage()
path = storage._get_memory_file_path("code-reviewer")
@ -356,7 +358,7 @@ class TestMemoryFilePath:
with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
):
storage = FileMemoryStorage()
path_global = storage._get_memory_file_path(None)

View File

@ -0,0 +1,86 @@
"""Tests for DeerFlowContext and resolve_context()."""
from dataclasses import FrozenInstanceError
from unittest.mock import MagicMock, patch
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
from deerflow.config.sandbox_config import SandboxConfig
def _make_config(**overrides) -> AppConfig:
defaults = {"sandbox": SandboxConfig(use="test")}
defaults.update(overrides)
return AppConfig(**defaults)
class TestDeerFlowContext:
def test_frozen(self):
ctx = DeerFlowContext(app_config=_make_config(), thread_id="t1")
with pytest.raises(FrozenInstanceError):
ctx.app_config = _make_config()
def test_fields(self):
config = _make_config()
ctx = DeerFlowContext(app_config=config, thread_id="t1", agent_name="test-agent")
assert ctx.thread_id == "t1"
assert ctx.agent_name == "test-agent"
assert ctx.app_config is config
def test_agent_name_default(self):
ctx = DeerFlowContext(app_config=_make_config(), thread_id="t1")
assert ctx.agent_name is None
def test_thread_id_required(self):
with pytest.raises(TypeError):
DeerFlowContext(app_config=_make_config()) # type: ignore[call-arg]
class TestResolveContext:
def test_returns_typed_context_directly(self):
"""Gateway/Client path: runtime.context is DeerFlowContext → return as-is."""
config = _make_config()
ctx = DeerFlowContext(app_config=config, thread_id="t1")
runtime = MagicMock()
runtime.context = ctx
assert resolve_context(runtime) is ctx
def test_fallback_from_configurable(self):
"""LangGraph Server path: runtime.context is None → construct from ContextVar + configurable."""
runtime = MagicMock()
runtime.context = None
config = _make_config()
with (
patch.object(AppConfig, "current", return_value=config),
patch("langgraph.config.get_config", return_value={"configurable": {"thread_id": "t2", "agent_name": "ag"}}),
):
ctx = resolve_context(runtime)
assert ctx.thread_id == "t2"
assert ctx.agent_name == "ag"
assert ctx.app_config is config
def test_fallback_empty_configurable(self):
"""LangGraph Server path with no thread_id in configurable → empty string."""
runtime = MagicMock()
runtime.context = None
config = _make_config()
with (
patch.object(AppConfig, "current", return_value=config),
patch("langgraph.config.get_config", return_value={"configurable": {}}),
):
ctx = resolve_context(runtime)
assert ctx.thread_id == ""
assert ctx.agent_name is None
def test_fallback_from_dict_context(self):
"""Legacy path: runtime.context is a dict → extract from dict directly."""
runtime = MagicMock()
runtime.context = {"thread_id": "old-dict", "agent_name": "from-dict"}
config = _make_config()
with patch.object(AppConfig, "current", return_value=config):
ctx = resolve_context(runtime)
assert ctx.thread_id == "old-dict"
assert ctx.agent_name == "from-dict"
assert ctx.app_config is config

View File

@ -5,11 +5,13 @@ from unittest.mock import MagicMock, patch
import pytest
from deerflow.config.app_config import AppConfig
@pytest.fixture
def mock_app_config():
"""Mock the app config to return tool configurations."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
with patch.object(AppConfig, "current") as mock_config:
tool_config = MagicMock()
tool_config.model_extra = {
"max_results": 5,
@ -67,7 +69,7 @@ class TestWebSearchTool:
def test_search_with_custom_config(self, mock_exa_client):
"""Test search respects custom configuration values."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
with patch.object(AppConfig, "current") as mock_config:
tool_config = MagicMock()
tool_config.model_extra = {
"max_results": 10,
@ -195,7 +197,7 @@ class TestWebFetchTool:
def test_fetch_reads_web_fetch_config(self, mock_exa_client):
"""Test that web_fetch_tool reads 'web_fetch' config, not 'web_search'."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
with patch.object(AppConfig, "current") as mock_config:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "exa-fetch-key"}
mock_config.return_value.get_tool_config.return_value = tool_config
@ -215,7 +217,7 @@ class TestWebFetchTool:
def test_fetch_uses_independent_api_key(self, mock_exa_client):
"""Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
with patch.object(AppConfig, "current") as mock_config:
with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
mock_exa_cls.return_value = mock_exa_client
fetch_config = MagicMock()

View File

@ -3,10 +3,12 @@
import json
from unittest.mock import MagicMock, patch
from deerflow.config.app_config import AppConfig
class TestWebSearchTool:
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
@patch("deerflow.community.firecrawl.tools.get_app_config")
@patch.object(AppConfig, "current")
def test_search_uses_web_search_config(self, mock_get_app_config, mock_firecrawl_cls):
search_config = MagicMock()
search_config.model_extra = {"api_key": "firecrawl-search-key", "max_results": 7}
@ -36,7 +38,7 @@ class TestWebSearchTool:
class TestWebFetchTool:
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
@patch("deerflow.community.firecrawl.tools.get_app_config")
@patch.object(AppConfig, "current")
def test_fetch_uses_web_fetch_config(self, mock_get_app_config, mock_firecrawl_cls):
fetch_config = MagicMock()
fetch_config.model_extra = {"api_key": "firecrawl-fetch-key"}

View File

@ -333,12 +333,17 @@ class TestGuardrailsConfig:
assert config.provider.use == "deerflow.guardrails.builtin:AllowlistProvider"
assert config.provider.config == {"denied_tools": ["bash"]}
def test_singleton_load_and_get(self):
from deerflow.config.guardrails_config import get_guardrails_config, load_guardrails_config_from_dict, reset_guardrails_config
def test_guardrails_config_via_app_config(self):
from unittest.mock import patch
try:
load_guardrails_config_from_dict({"enabled": True, "provider": {"use": "test:Foo"}})
config = get_guardrails_config()
from deerflow.config.app_config import AppConfig
from deerflow.config.guardrails_config import GuardrailProviderConfig, GuardrailsConfig
from deerflow.config.sandbox_config import SandboxConfig
cfg = AppConfig(
sandbox=SandboxConfig(use="test"),
guardrails=GuardrailsConfig(enabled=True, provider=GuardrailProviderConfig(use="test:Foo")),
)
with patch.object(AppConfig, "current", return_value=cfg):
config = AppConfig.current().guardrails
assert config.enabled is True
finally:
reset_guardrails_config()

View File

@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch
from deerflow.community.infoquest import tools
from deerflow.community.infoquest.infoquest_client import InfoQuestClient
from deerflow.config.app_config import AppConfig
class TestInfoQuestClient:
@ -149,8 +150,8 @@ class TestInfoQuestClient:
mock_get_client.assert_called_once()
mock_client.fetch.assert_called_once_with("https://example.com")
@patch("deerflow.community.infoquest.tools.get_app_config")
def test_get_infoquest_client(self, mock_get_app_config):
@patch.object(AppConfig, "current")
def test_get_infoquest_client(self, mock_get):
"""Test _get_infoquest_client function with config."""
mock_config = MagicMock()
# Add image_search config to the side_effect
@ -159,7 +160,7 @@ class TestInfoQuestClient:
MagicMock(model_extra={"fetch_time": 10, "timeout": 30, "navigation_timeout": 60}), # web_fetch config
MagicMock(model_extra={"image_search_time_range": 7, "image_size": "l"}), # image_search config
]
mock_get_app_config.return_value = mock_config
mock_get.return_value = mock_config
client = tools._get_infoquest_client()

View File

@ -6,7 +6,8 @@ from types import SimpleNamespace
import pytest
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig, set_extensions_config
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
from deerflow.tools.builtins.invoke_acp_agent_tool import (
_build_acp_mcp_servers,
_build_mcp_servers,
@ -18,7 +19,6 @@ from deerflow.tools.tools import get_available_tools
def test_build_mcp_servers_filters_disabled_and_maps_transports():
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
fresh_config = ExtensionsConfig(
mcp_servers={
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"]),
@ -40,11 +40,9 @@ def test_build_mcp_servers_filters_disabled_and_maps_transports():
}
finally:
monkeypatch.undo()
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_acp_mcp_servers_formats_list_payload():
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
fresh_config = ExtensionsConfig(
mcp_servers={
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"], env={"FOO": "bar"}),
@ -77,7 +75,6 @@ def test_build_acp_mcp_servers_formats_list_payload():
]
finally:
monkeypatch.undo()
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_permission_response_prefers_allow_once():
@ -669,25 +666,20 @@ async def test_invoke_acp_agent_passes_none_env_when_not_configured(monkeypatch,
def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(monkeypatch):
from deerflow.config.acp_config import load_acp_config_from_dict
load_acp_config_from_dict(
{
"codex": {
"command": "codex-acp",
"args": [],
"description": "Codex CLI",
}
}
)
fake_config = SimpleNamespace(
tools=[],
models=[],
tool_search=SimpleNamespace(enabled=False),
acp_agents={
"codex": ACPAgentConfig(
command="codex-acp",
args=[],
description="Codex CLI",
)
},
get_model_config=lambda name: None,
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: fake_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: fake_config))
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})),
@ -695,5 +687,3 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo
tools = get_available_tools(include_mcp=True, subagent_enabled=False)
assert "invoke_acp_agent" in [tool.name for tool in tools]
load_acp_config_from_dict({})

View File

@ -9,6 +9,7 @@ import pytest
import deerflow.community.jina_ai.jina_client as jina_client_module
from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.community.jina_ai.tools import web_fetch_tool
from deerflow.config.app_config import AppConfig
@pytest.fixture
@ -154,7 +155,7 @@ async def test_web_fetch_tool_returns_error_on_crawl_failure(monkeypatch):
mock_config = MagicMock()
mock_config.get_tool_config.return_value = None
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: mock_config))
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
result = await web_fetch_tool.ainvoke("https://example.com")
assert result.startswith("Error:")
@ -170,7 +171,7 @@ async def test_web_fetch_tool_returns_markdown_on_success(monkeypatch):
mock_config = MagicMock()
mock_config.get_tool_config.return_value = None
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: mock_config))
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
result = await web_fetch_tool.ainvoke("https://example.com")
assert "Hello world" in result

View File

@ -40,7 +40,7 @@ def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
with caplog.at_level("WARNING"):
resolved = lead_agent_module._resolve_model_name("missing-model")
@ -57,7 +57,7 @@ def test_resolve_model_name_uses_default_when_none(monkeypatch):
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
resolved = lead_agent_module._resolve_model_name(None)
@ -67,7 +67,7 @@ def test_resolve_model_name_uses_default_when_none(monkeypatch):
def test_resolve_model_name_raises_when_no_models_configured(monkeypatch):
app_config = _make_app_config([])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
with pytest.raises(
ValueError,
@ -81,7 +81,7 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
import deerflow.tools as tools_module
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
@ -128,7 +128,8 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
]
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
AppConfig.init(app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
@ -140,11 +141,10 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
monkeypatch.setattr(
lead_agent_module,
"get_summarization_config",
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
)
app_config = _make_app_config([_make_model("default", supports_thinking=False)])
patched = app_config.model_copy(update={"summarization": SummarizationConfig(enabled=True, model_name="model-masswork")})
AppConfig.init(patched)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: patched))
from unittest.mock import MagicMock

View File

@ -4,12 +4,13 @@ from types import SimpleNamespace
import anyio
from deerflow.agents.lead_agent import prompt as prompt_module
from deerflow.config.app_config import AppConfig
from deerflow.skills.types import Skill
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
assert prompt_module._build_custom_mounts_section() == ""
@ -20,7 +21,7 @@ def test_build_custom_mounts_section_lists_configured_mounts(monkeypatch):
SimpleNamespace(container_path="/mnt/reference", read_only=True),
]
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=mounts))
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
section = prompt_module._build_custom_mounts_section()
@ -37,7 +38,7 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
sandbox=SimpleNamespace(mounts=mounts),
skills=SimpleNamespace(container_path="/mnt/skills"),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
@ -55,7 +56,7 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
sandbox=SimpleNamespace(mounts=[]),
skills=SimpleNamespace(container_path="/mnt/skills"),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")

View File

@ -3,6 +3,7 @@ from types import SimpleNamespace
from deerflow.agents.lead_agent.prompt import get_skills_prompt_section
from deerflow.config.agents_config import AgentConfig
from deerflow.config.app_config import AppConfig
from deerflow.skills.types import Skill
@ -58,11 +59,11 @@ def test_get_skills_prompt_section_includes_self_evolution_rules(monkeypatch):
skills = [_make_skill("skill1")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr(
"deerflow.config.get_app_config",
lambda: SimpleNamespace(
AppConfig, "current",
staticmethod(lambda: SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
),
)),
)
result = get_skills_prompt_section(available_skills=None)
@ -72,11 +73,11 @@ def test_get_skills_prompt_section_includes_self_evolution_rules(monkeypatch):
def test_get_skills_prompt_section_includes_self_evolution_rules_without_skills(monkeypatch):
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [])
monkeypatch.setattr(
"deerflow.config.get_app_config",
lambda: SimpleNamespace(
AppConfig, "current",
staticmethod(lambda: SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
),
)),
)
result = get_skills_prompt_section(available_skills=None)
@ -90,7 +91,7 @@ def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeyp
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
enabled_result = get_skills_prompt_section(available_skills=None)
assert "Skill Self-Evolution" in enabled_result
@ -106,7 +107,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
from deerflow.agents.lead_agent import agent as lead_agent_module
# Mock dependencies
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock())
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: MagicMock()))
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
@ -118,7 +119,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = MockModelConfig()
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: mock_app_config))
captured_skills = []

View File

@ -1,5 +1,6 @@
from types import SimpleNamespace
from deerflow.config.app_config import AppConfig
from deerflow.tools.tools import get_available_tools
@ -21,7 +22,7 @@ def _make_config(*, allow_host_bash: bool, sandbox_use: str = "deerflow.sandbox.
def test_get_available_tools_hides_bash_for_default_local_sandbox(monkeypatch):
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: _make_config(allow_host_bash=False))
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: _make_config(allow_host_bash=False)))
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
@ -34,7 +35,7 @@ def test_get_available_tools_hides_bash_for_default_local_sandbox(monkeypatch):
def test_get_available_tools_keeps_bash_when_explicitly_enabled(monkeypatch):
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: _make_config(allow_host_bash=True))
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: _make_config(allow_host_bash=True)))
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
@ -51,7 +52,7 @@ def test_get_available_tools_hides_renamed_host_bash_alias(monkeypatch):
allow_host_bash=False,
extra_tools=[SimpleNamespace(name="shell", group="bash", use="deerflow.sandbox.tools:bash_tool")],
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
@ -69,7 +70,7 @@ def test_get_available_tools_keeps_bash_for_aio_sandbox(monkeypatch):
allow_host_bash=False,
sandbox_use="deerflow.community.aio_sandbox:AioSandboxProvider",
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr(
"deerflow.tools.tools.resolve_variable",
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider
@ -312,7 +313,7 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]
@ -334,7 +335,7 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
@ -358,7 +359,7 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
@ -382,7 +383,7 @@ class TestLocalSandboxProviderMounts:
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]

View File

@ -10,12 +10,22 @@ from deerflow.agents.middlewares.loop_detection_middleware import (
LoopDetectionMiddleware,
_hash_tool_calls,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _make_runtime(thread_id="test-thread"):
"""Build a minimal Runtime mock with context."""
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
runtime.context = _make_context(thread_id)
return runtime
@ -291,10 +301,10 @@ class TestLoopDetection:
assert isinstance(mw._lock, type(mw._lock))
def test_fallback_thread_id_when_missing(self):
"""When runtime context has no thread_id, should use 'default'."""
"""When runtime context has empty thread_id, should use 'default'."""
mw = LoopDetectionMiddleware(warn_threshold=2)
runtime = MagicMock()
runtime.context = {}
runtime.context = _make_context("")
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)

View File

@ -1,21 +1,20 @@
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
def _memory_config(**overrides: object) -> MemoryConfig:
config = MemoryConfig()
for key, value in overrides.items():
setattr(config, key, value)
return config
def _make_config(**memory_overrides) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(**memory_overrides))
def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_make_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["first"], correction_detected=True)
@ -56,7 +55,7 @@ def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> No
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_make_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)

View File

@ -11,7 +11,13 @@ from deerflow.agents.memory.storage import (
create_empty_memory,
get_memory_storage,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
def _app_config(**memory_overrides) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(**memory_overrides))
class TestCreateEmptyMemory:
@ -53,7 +59,7 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
path = storage._get_memory_file_path(None)
assert path == tmp_path / "memory.json"
@ -87,7 +93,7 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
memory = storage.load()
assert isinstance(memory, dict)
@ -103,7 +109,7 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
test_memory = {"version": "1.0", "facts": [{"content": "test fact"}]}
result = storage.save(test_memory)
@ -122,7 +128,7 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
# First load
memory1 = storage.load()
@ -150,19 +156,19 @@ class TestGetMemoryStorage:
def test_returns_file_memory_storage_by_default(self):
"""Should return FileMemoryStorage by default."""
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)
def test_falls_back_to_file_memory_storage_on_error(self):
"""Should fall back to FileMemoryStorage if configured storage fails to load."""
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="non.existent.StorageClass")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="non.existent.StorageClass")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)
def test_returns_singleton_instance(self):
"""Should return the same instance on subsequent calls."""
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
storage1 = get_memory_storage()
storage2 = get_memory_storage()
assert storage1 is storage2
@ -173,11 +179,11 @@ class TestGetMemoryStorage:
def get_storage():
# get_memory_storage is called concurrently from multiple threads while
# get_memory_config is patched once around thread creation. This verifies
# AppConfig.get is patched once around thread creation. This verifies
# that the singleton initialization remains thread-safe.
results.append(get_memory_storage())
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
threads = [threading.Thread(target=get_storage) for _ in range(10)]
for t in threads:
t.start()
@ -191,13 +197,13 @@ class TestGetMemoryStorage:
def test_get_memory_storage_invalid_class_fallback(self):
"""Should fall back to FileMemoryStorage if the configured class is not actually a class."""
# Using a built-in function instead of a class
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="os.path.join")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="os.path.join")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)
def test_get_memory_storage_non_subclass_fallback(self):
"""Should fall back to FileMemoryStorage if the configured class is not a subclass of MemoryStorage."""
# Using 'dict' as a class that is not a MemoryStorage subclass
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="builtins.dict")):
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="builtins.dict")):
storage = get_memory_storage()
assert isinstance(storage, FileMemoryStorage)

View File

@ -10,7 +10,9 @@ from deerflow.agents.memory.updater import (
import_memory_data,
update_memory_fact,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, object]:
@ -31,11 +33,8 @@ def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, obje
}
def _memory_config(**overrides: object) -> MemoryConfig:
config = MemoryConfig()
for key, value in overrides.items():
setattr(config, key, value)
return config
def _memory_config(**overrides: object) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig().model_copy(update=overrides))
def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None:
@ -67,8 +66,7 @@ def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
@ -88,8 +86,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
@ -132,8 +129,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=2, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
@ -160,8 +156,7 @@ def test_apply_updates_preserves_source_error() -> None:
]
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
@ -184,8 +179,7 @@ def test_apply_updates_ignores_empty_source_error() -> None:
]
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
@ -532,7 +526,7 @@ class TestUpdateMemoryStructuredResponse:
with (
patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@ -555,7 +549,7 @@ class TestUpdateMemoryStructuredResponse:
with (
patch.object(updater, "_get_model", return_value=self._make_mock_model(list_content)),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@ -577,7 +571,7 @@ class TestUpdateMemoryStructuredResponse:
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@ -602,7 +596,7 @@ class TestUpdateMemoryStructuredResponse:
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@ -646,8 +640,7 @@ class TestFactDeduplicationCaseInsensitive:
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
@ -677,8 +670,7 @@ class TestFactDeduplicationCaseInsensitive:
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
@ -704,7 +696,7 @@ class TestReinforcementHint:
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@ -729,7 +721,7 @@ class TestReinforcementHint:
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
@ -754,7 +746,7 @@ class TestReinforcementHint:
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(AppConfig, "current", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):

View File

@ -72,8 +72,8 @@ class FakeChatModel(BaseChatModel):
def _patch_factory(monkeypatch, app_config: AppConfig, model_class=FakeChatModel):
"""Patch get_app_config, resolve_class, and tracing for isolated unit tests."""
monkeypatch.setattr(factory_module, "get_app_config", lambda: app_config)
"""Patch AppConfig.get, resolve_class, and tracing for isolated unit tests."""
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: model_class)
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
@ -96,7 +96,7 @@ def test_uses_first_model_when_name_is_none(monkeypatch):
def test_raises_when_model_not_found(monkeypatch):
cfg = _make_app_config([_make_model("only-model")])
monkeypatch.setattr(factory_module, "get_app_config", lambda: cfg)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: cfg))
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
with pytest.raises(ValueError, match="ghost-model"):
@ -744,7 +744,7 @@ def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
supports_thinking=True,
when_thinking_enabled=wte,
)
model.extra_body = {"top_k": 20}
model = model.model_copy(update={"extra_body": {"top_k": 20}})
cfg = _make_app_config([model])
_patch_factory(monkeypatch, cfg)
@ -771,7 +771,7 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
supports_thinking=True,
when_thinking_enabled=wte,
)
model.extra_body = {"top_k": 20}
model = model.model_copy(update={"extra_body": {"top_k": 20}})
cfg = _make_app_config([model])
_patch_factory(monkeypatch, cfg)

View File

@ -3,13 +3,24 @@
import importlib
from types import SimpleNamespace
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
present_file_tool_module = importlib.import_module("deerflow.tools.builtins.present_file_tool")
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _make_runtime(outputs_path: str) -> SimpleNamespace:
return SimpleNamespace(
state={"thread_data": {"outputs_path": outputs_path}},
context={"thread_id": "thread-1"},
context=_make_context("thread-1"),
)

View File

@ -2,6 +2,7 @@ from types import SimpleNamespace
from unittest.mock import patch
from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox
from deerflow.config.app_config import AppConfig
from deerflow.sandbox.local.local_sandbox import LocalSandbox
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches
from deerflow.sandbox.tools import glob_tool, grep_tool
@ -104,7 +105,7 @@ def test_grep_tool_truncates_results(tmp_path, monkeypatch) -> None:
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
# Prevent config.yaml tool config from overriding the caller-supplied max_results=2.
monkeypatch.setattr("deerflow.sandbox.tools.get_app_config", lambda: SimpleNamespace(get_tool_config=lambda name: None))
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: SimpleNamespace(get_tool_config=lambda name: None)))
result = grep_tool.func(
runtime=runtime,
@ -325,8 +326,8 @@ def test_glob_tool_honors_smaller_requested_max_results(tmp_path, monkeypatch) -
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
monkeypatch.setattr(
"deerflow.sandbox.tools.get_app_config",
lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50})),
AppConfig, "current",
staticmethod(lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50}))),
)
result = glob_tool.func(

View File

@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.sandbox.tools import (
VIRTUAL_PATH_PREFIX,
_apply_cwd_prefix,
@ -617,18 +618,25 @@ def test_apply_cwd_prefix_quotes_path_with_spaces() -> None:
def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None:
"""Bash commands referencing MCP filesystem server paths should be allowed."""
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
from deerflow.config.sandbox_config import SandboxConfig
mock_config = ExtensionsConfig(
mcp_servers={
"filesystem": McpServerConfig(
enabled=True,
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
)
}
)
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=mock_config):
def _make_app_config(enabled: bool) -> AppConfig:
return AppConfig(
sandbox=SandboxConfig(use="test"),
extensions=ExtensionsConfig(
mcp_servers={
"filesystem": McpServerConfig(
enabled=enabled,
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
)
}
),
)
with patch.object(AppConfig, "current", return_value=_make_app_config(True)):
# Should not raise - MCP filesystem paths are allowed
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA)
@ -637,19 +645,10 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA)
# Disabled servers should not expose paths
disabled_config = ExtensionsConfig(
mcp_servers={
"filesystem": McpServerConfig(
enabled=False,
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/mnt/d/workspace"],
)
}
)
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=disabled_config):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
# Disabled servers should not expose paths
with patch.object(AppConfig, "current", return_value=_make_app_config(False)):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
# ---------- Custom mount path tests ----------
@ -757,7 +756,7 @@ def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
with patch.object(AppConfig, "current", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 2
@ -786,7 +785,7 @@ def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path)
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
with patch.object(AppConfig, "current", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 1
assert result[0].container_path == "/mnt/existing"

View File

@ -2,13 +2,14 @@ from types import SimpleNamespace
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.skills.security_scanner import scan_skill_content
@pytest.mark.anyio
async def test_scan_skill_content_blocks_when_model_unavailable(monkeypatch):
config = SimpleNamespace(skill_evolution=SimpleNamespace(moderation_model_name=None))
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr("deerflow.skills.security_scanner.create_chat_model", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
result = await scan_skill_content("---\nname: demo-skill\ndescription: demo\n---\n", executable=False)

View File

@ -4,9 +4,20 @@ from types import SimpleNamespace
import anyio
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
skill_manage_module = importlib.import_module("deerflow.tools.skill_manage_tool")
def _make_context(thread_id: str) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
def _skill_content(name: str, description: str = "Demo skill") -> str:
return f"---\nname: {name}\ndescription: {description}\n---\n\n# {name}\n"
@ -23,9 +34,7 @@ def test_skill_manage_create_and_patch(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
refresh_calls = []
async def _refresh():
@ -34,7 +43,7 @@ def test_skill_manage_create_and_patch(monkeypatch, tmp_path):
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
result = anyio.run(
skill_manage_module.skill_manage_tool.coroutine,
@ -67,9 +76,7 @@ def test_skill_manage_patch_replaces_single_occurrence_by_default(monkeypatch, t
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
async def _refresh():
return None
@ -77,7 +84,7 @@ def test_skill_manage_patch_replaces_single_occurrence_by_default(monkeypatch, t
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
content = _skill_content("demo-skill", "Demo skill") + "\nRepeated: Demo skill\n"
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", content)
@ -107,10 +114,9 @@ def test_skill_manage_rejects_public_skill_patch(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
runtime = SimpleNamespace(context={}, config={"configurable": {}})
runtime = SimpleNamespace(context=_make_context(""), config={"configurable": {}})
with pytest.raises(ValueError, match="built-in skill"):
anyio.run(
@ -131,8 +137,7 @@ def test_skill_manage_sync_wrapper_supported(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
refresh_calls = []
async def _refresh():
@ -141,7 +146,7 @@ def test_skill_manage_sync_wrapper_supported(monkeypatch, tmp_path):
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-sync"}, config={"configurable": {"thread_id": "thread-sync"}})
runtime = SimpleNamespace(context=_make_context("thread-sync"), config={"configurable": {"thread_id": "thread-sync"}})
result = skill_manage_module.skill_manage_tool.func(
runtime=runtime,
action="create",
@ -159,9 +164,7 @@ def test_skill_manage_rejects_support_path_traversal(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
async def _refresh():
return None
@ -169,7 +172,7 @@ def test_skill_manage_rejects_support_path_traversal(monkeypatch, tmp_path):
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context={"thread_id": "thread-1"}, config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", _skill_content("demo-skill"))
with pytest.raises(ValueError, match="parent-directory traversal|selected support directory"):

View File

@ -6,6 +6,9 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.gateway.routers import skills as skills_router
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.skills.manager import get_skill_history_file
from deerflow.skills.types import Skill
@ -43,8 +46,7 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
refresh_calls = []
@ -93,8 +95,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
get_skill_history_file("demo-skill").write_text(
'{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n",
encoding="utf-8",
@ -135,8 +136,7 @@ def test_custom_skill_delete_preserves_history_and_allows_restore(monkeypatch, t
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr("deerflow.skills.manager.get_app_config", lambda: config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
refresh_calls = []
@ -179,9 +179,12 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
refresh_calls.append("refresh")
enabled_state["value"] = False
_app_cfg = AppConfig(sandbox=SandboxConfig(use="test"), extensions=ExtensionsConfig(mcp_servers={}, skills={}))
monkeypatch.setattr("app.gateway.routers.skills.load_skills", _load_skills)
monkeypatch.setattr("app.gateway.routers.skills.get_extensions_config", lambda: SimpleNamespace(mcp_servers={}, skills={}))
monkeypatch.setattr("app.gateway.routers.skills.reload_extensions_config", lambda: None)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: _app_cfg))
monkeypatch.setattr(AppConfig, "init", staticmethod(lambda _cfg: None))
monkeypatch.setattr(AppConfig, "from_file", staticmethod(lambda: _app_cfg))
monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path))
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)

Some files were not shown because too many files have changed in this diff Show More