refactor(config): Phase 2 — eliminate AppConfig.current() ambient lookup

Finish Phase 2 of the config refactor: production code no longer calls
AppConfig.current() anywhere. AppConfig now flows as an explicit parameter
through every consumer path.

Call-site migrations
--------------------
- Memory subsystem (queue/updater/storage): MemoryConfig captured at
  enqueue time so the Timer closure survives the ContextVar boundary.
- Sandbox layer: tools.py, security.py, sandbox_provider.py, local_sandbox_provider.py,
  aio_sandbox_provider.py all take app_config explicitly. Module-level
  caching in tools.py's path helpers is removed — pure parameter flow.
- Skills layer: manager.py + loader.py + lead_agent.prompt cache refresh
  all thread app_config; cache worker closes over it.
- Community tools (tavily, jina, firecrawl, exa, ddg, image_search,
  infoquest, aio_sandbox): read runtime.context.app_config.
- Subagents registry: get_subagent_config / list_subagents /
  get_available_subagent_names require app_config.
- Runtime worker: requires RunContext.app_config; no fallback.
- Gateway routers (uploads, skills): add Depends(get_config).
- Channels feishu: uses AppConfig.from_file() (pure) at its sync boundary.
- LangGraph Server bootstrap (make_lead_agent): falls back to
  AppConfig.from_file() — pure load, not ambient lookup.
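The migrations above all follow one pattern: ambient lookup replaced by an explicit argument. A minimal sketch, with a stand-in class and an illustrative provider (these are not the real DeerFlow APIs):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class AppConfig:
    """Illustrative stand-in for the real AppConfig."""
    sandbox_backend: str = "local"

# Before (removed in Phase 2): the provider reached for a hidden global.
#   def get_sandbox_provider():
#       app_config = AppConfig.current()  # ambient lookup
#       ...

# After: the config arrives as an explicit parameter at every call site,
# so the provider is a pure function of its arguments.
def get_sandbox_provider(app_config: AppConfig) -> str:
    return f"{app_config.sandbox_backend}-provider"

provider = get_sandbox_provider(AppConfig(sandbox_backend="aio"))
```

Explicit flow also makes tests simpler: construct an AppConfig and pass it in, with no monkey-patching of process globals.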

Context resolution
------------------
- resolve_context(runtime) now raises on non-DeerFlowContext runtime.context.
  Every entry point attaches typed context; dict/None shapes are rejected
  loudly instead of being papered over with an ambient AppConfig lookup.
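The fail-loud behavior amounts to roughly this sketch (DeerFlowContext is reduced to a stand-in dataclass here; the real types live in the runtime package):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class DeerFlowContext:
    app_config: dict  # stand-in field; the real one is an AppConfig

def resolve_context(runtime_context: Any) -> DeerFlowContext:
    # Reject dict/None shapes loudly instead of papering over them
    # with an ambient AppConfig lookup.
    if not isinstance(runtime_context, DeerFlowContext):
        raise TypeError(
            f"runtime.context must be DeerFlowContext, got "
            f"{type(runtime_context).__name__}; every entry point "
            "must attach typed context."
        )
    return runtime_context
```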

AppConfig lifecycle
-------------------
- AppConfig.current() kept as a deprecated slot that raises RuntimeError,
  purely so legacy tests that still run `patch.object(AppConfig, "current")`
  don't trip AttributeError at teardown. Production never calls it.
- conftest autouse fixture no longer monkey-patches `current` — it only
  stubs `from_file()` so tests don't need a real config.yaml.
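The deprecated slot plus the narrowed test stub can be sketched like this (a stand-in class; the real fixture lives in conftest and the real AppConfig in deerflow.config.app_config):

```python
from unittest.mock import patch

class AppConfig:
    """Stand-in; illustrative only."""

    @classmethod
    def current(cls):
        # Deprecated slot: kept only so legacy `patch.object(AppConfig,
        # "current")` finds an attribute to patch and restore at teardown.
        raise RuntimeError("AppConfig.current() was removed in Phase 2")

    @classmethod
    def from_file(cls) -> "AppConfig":
        return cls()  # pure load; tests stub this instead of `current`

# Legacy tests can still patch the slot without AttributeError:
with patch.object(AppConfig, "current", return_value=AppConfig()):
    cfg = AppConfig.current()  # patched inside the block
```

Outside the patch, any stray call raises RuntimeError immediately, which is exactly the loud failure the refactor wants.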

Design refs
-----------
- docs/plans/2026-04-12-config-refactor-plan.md (Phase 2: P2-6..P2-10)
- docs/plans/2026-04-12-config-refactor-design.md §8

All 2338 non-e2e tests pass. Zero AppConfig.current() call sites remain
in backend/packages or backend/app (docstrings in deps.py excepted).
greatmengqi 2026-04-17 11:14:13 +08:00
parent 0e5ff6f431
commit 84dccef230
89 changed files with 4704 additions and 3833 deletions

View File

@ -375,7 +375,9 @@ class FeishuChannel(Channel):
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
try:
sandbox_provider = get_sandbox_provider()
from deerflow.config.app_config import AppConfig
sandbox_provider = get_sandbox_provider(AppConfig.from_file())
sandbox_id = sandbox_provider.acquire(thread_id)
if sandbox_id != "local":
sandbox = sandbox_provider.get(sandbox_id)

View File

@ -67,17 +67,8 @@ class ChannelService:
self._running = False
@classmethod
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
"""Create a ChannelService from the application config.
Pass ``app_config`` explicitly when available (e.g. from Gateway
startup). Falls back to ``AppConfig.current()`` for legacy callers;
that fallback is removed in Phase 2 task P2-10.
"""
if app_config is None:
from deerflow.config.app_config import AppConfig as _AppConfig
app_config = _AppConfig.current()
def from_app_config(cls, app_config: AppConfig) -> ChannelService:
"""Create a ChannelService from an explicit application config."""
channels_config = {}
# extra fields are allowed by AppConfig (extra="allow")
extra = app_config.model_extra or {}
@ -190,12 +181,12 @@ def get_channel_service() -> ChannelService | None:
return _channel_service
async def start_channel_service() -> ChannelService:
async def start_channel_service(app_config: AppConfig) -> ChannelService:
"""Create and start the global ChannelService from app config."""
global _channel_service
if _channel_service is not None:
return _channel_service
_channel_service = ChannelService.from_app_config()
_channel_service = ChannelService.from_app_config(app_config)
await _channel_service.start()
return _channel_service

View File

@ -146,11 +146,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler."""
try:
# app.state.config is the source of truth for Depends(get_config).
# AppConfig.init() mirrors it to the process-global for not-yet-migrated
# callers; both go away in P2-10 once AppConfig.current() is removed.
# ``app.state.config`` is the sole source of truth for
# ``Depends(get_config)``. Consumers that want AppConfig must receive
# it as an explicit parameter; there is no ambient singleton.
app.state.config = AppConfig.from_file()
AppConfig.init(app.state.config)
logger.info("Configuration loaded successfully")
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"
@ -171,7 +170,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
try:
from app.channels.service import start_channel_service
channel_service = await start_channel_service()
channel_service = await start_channel_service(app.state.config)
logger.info("Channel service started: %s", channel_service.get_status())
except Exception:
logger.exception("No IM channels configured or channel service failed to start")

View File

@ -53,16 +53,18 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
from deerflow.runtime.events.store import make_run_event_store
async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
# app.state.config is populated earlier in lifespan(); thread it into
# every provider that used to reach for AppConfig.current().
config = app.state.config
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))
# Initialize persistence engine BEFORE checkpointer so that
# auto-create-database logic runs first (postgres backend).
# Use app.state.config which was populated earlier in lifespan().
config = app.state.config
await init_engine_from_config(config.database)
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store())
app.state.checkpointer = await stack.enter_async_context(make_checkpointer(config))
app.state.store = await stack.enter_async_context(make_store(config))
# Initialize repositories — one get_session_factory() call for all.
sf = get_session_factory()

View File

@ -166,12 +166,10 @@ async def update_mcp_configuration(
# NOTE: No need to reload/reset cache here - LangGraph Server (separate process)
# will detect config file changes via mtime and reinitialize MCP tools automatically
# Reload the configuration. Swap app.state.config (new primitive) and
# AppConfig.init() (legacy) so both Depends(get_config) and the not-yet-migrated
# AppConfig.current() callers see the new config.
# Reload the configuration and swap ``app.state.config`` so subsequent
# ``Depends(get_config)`` calls see the refreshed value.
reloaded = AppConfig.from_file()
http_request.app.state.config = reloaded
AppConfig.init(reloaded)
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded.extensions.mcp_servers.items()})
except Exception as e:

View File

@ -115,7 +115,7 @@ class MemoryStatusResponse(BaseModel):
summary="Get Memory Data",
description="Retrieve the current global memory data including user context, history, and facts.",
)
async def get_memory() -> MemoryResponse:
async def get_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Get the current global memory data.
Returns:
@ -149,7 +149,7 @@ async def get_memory() -> MemoryResponse:
}
```
"""
memory_data = get_memory_data(user_id=get_effective_user_id())
memory_data = get_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@ -160,7 +160,7 @@ async def get_memory() -> MemoryResponse:
summary="Reload Memory Data",
description="Reload memory data from the storage file, refreshing the in-memory cache.",
)
async def reload_memory() -> MemoryResponse:
async def reload_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Reload memory data from file.
This forces a reload of the memory data from the storage file,
@ -169,7 +169,7 @@ async def reload_memory() -> MemoryResponse:
Returns:
The reloaded memory data.
"""
memory_data = reload_memory_data(user_id=get_effective_user_id())
memory_data = reload_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@ -180,10 +180,10 @@ async def reload_memory() -> MemoryResponse:
summary="Clear All Memory Data",
description="Delete all saved memory data and reset the memory structure to an empty state.",
)
async def clear_memory() -> MemoryResponse:
async def clear_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Clear all persisted memory data."""
try:
memory_data = clear_memory_data(user_id=get_effective_user_id())
memory_data = clear_memory_data(app_config.memory, user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
@ -197,10 +197,11 @@ async def clear_memory() -> MemoryResponse:
summary="Create Memory Fact",
description="Create a single saved memory fact manually.",
)
async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse:
async def create_memory_fact_endpoint(request: FactCreateRequest, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Create a single fact manually."""
try:
memory_data = create_memory_fact(
app_config.memory,
content=request.content,
category=request.category,
confidence=request.confidence,
@ -221,10 +222,10 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
summary="Delete Memory Fact",
description="Delete a single saved memory fact by its fact id.",
)
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
async def delete_memory_fact_endpoint(fact_id: str, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Delete a single fact from memory by fact id."""
try:
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
memory_data = delete_memory_fact(app_config.memory, fact_id, user_id=get_effective_user_id())
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
@ -240,10 +241,11 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
summary="Patch Memory Fact",
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
)
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse:
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Partially update a single fact manually."""
try:
memory_data = update_memory_fact(
app_config.memory,
fact_id=fact_id,
content=request.content,
category=request.category,
@ -267,9 +269,9 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
summary="Export Memory Data",
description="Export the current global memory data as JSON for backup or transfer.",
)
async def export_memory() -> MemoryResponse:
async def export_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Export the current memory data."""
memory_data = get_memory_data(user_id=get_effective_user_id())
memory_data = get_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@ -280,10 +282,10 @@ async def export_memory() -> MemoryResponse:
summary="Import Memory Data",
description="Import and overwrite the current global memory data from a JSON payload.",
)
async def import_memory(request: MemoryResponse) -> MemoryResponse:
async def import_memory(request: MemoryResponse, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Import and persist memory data."""
try:
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id())
memory_data = import_memory_data(app_config.memory, request.model_dump(), user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
@ -345,7 +347,7 @@ async def get_memory_status(
Combined memory configuration and current data.
"""
config = app_config.memory
memory_data = get_memory_data(user_id=get_effective_user_id())
memory_data = get_memory_data(config, user_id=get_effective_user_id())
return MemoryStatusResponse(
config=MemoryConfigResponse(

View File

@ -10,7 +10,7 @@ from app.gateway.deps import get_config
from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import (
@ -102,9 +102,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
summary="List All Skills",
description="Retrieve a list of all available skills from both public and custom directories.",
)
async def list_skills() -> SkillsListResponse:
async def list_skills(app_config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error(f"Failed to load skills: {e}", exc_info=True)
@ -117,11 +117,11 @@ async def list_skills() -> SkillsListResponse:
summary="Install Skill",
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
)
async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
async def install_skill(request: SkillInstallRequest, app_config: AppConfig = Depends(get_config)) -> SkillInstallResponse:
try:
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
result = install_skill_from_archive(skill_file_path)
await refresh_skills_system_prompt_cache_async()
await refresh_skills_system_prompt_cache_async(app_config)
return SkillInstallResponse(**result)
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@ -137,9 +137,9 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
async def list_custom_skills() -> SkillsListResponse:
async def list_custom_skills(app_config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
skills = [skill for skill in load_skills(app_config, enabled_only=False) if skill.category == "custom"]
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error("Failed to list custom skills: %s", e, exc_info=True)
@ -147,13 +147,13 @@ async def list_custom_skills() -> SkillsListResponse:
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
async def get_custom_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name, app_config))
except HTTPException:
raise
except Exception as e:
@ -162,14 +162,18 @@ async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
async def update_custom_skill(
skill_name: str,
request: CustomSkillUpdateRequest,
app_config: AppConfig = Depends(get_config),
) -> CustomSkillContentResponse:
try:
ensure_custom_skill_is_editable(skill_name)
ensure_custom_skill_is_editable(skill_name, app_config)
validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
scan = await scan_skill_content(app_config, request.content, executable=False, location=f"{skill_name}/SKILL.md")
if scan.decision == "block":
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
skill_file = get_custom_skill_dir(skill_name, app_config) / "SKILL.md"
prev_content = skill_file.read_text(encoding="utf-8")
atomic_write(skill_file, request.content)
append_history(
@ -183,9 +187,10 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
"new_content": request.content,
"scanner": {"decision": scan.decision, "reason": scan.reason},
},
app_config,
)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
await refresh_skills_system_prompt_cache_async(app_config)
return await get_custom_skill(skill_name, app_config)
except HTTPException:
raise
except FileNotFoundError as e:
@ -198,11 +203,11 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
async def delete_custom_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> dict[str, bool]:
try:
ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name)
ensure_custom_skill_is_editable(skill_name, app_config)
skill_dir = get_custom_skill_dir(skill_name, app_config)
prev_content = read_custom_skill_content(skill_name, app_config)
append_history(
skill_name,
{
@ -214,9 +219,10 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
app_config,
)
shutil.rmtree(skill_dir)
await refresh_skills_system_prompt_cache_async()
await refresh_skills_system_prompt_cache_async(app_config)
return {"success": True}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@ -228,11 +234,11 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
async def get_custom_skill_history(skill_name: str, app_config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
if not custom_skill_exists(skill_name, app_config) and not get_skill_history_file(skill_name, app_config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillHistoryResponse(history=read_history(skill_name))
return CustomSkillHistoryResponse(history=read_history(skill_name, app_config))
except HTTPException:
raise
except Exception as e:
@ -241,11 +247,15 @@ async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryRespons
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
async def rollback_custom_skill(
skill_name: str,
request: SkillRollbackRequest,
app_config: AppConfig = Depends(get_config),
) -> CustomSkillContentResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
if not custom_skill_exists(skill_name, app_config) and not get_skill_history_file(skill_name, app_config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
history = read_history(skill_name)
history = read_history(skill_name, app_config)
if not history:
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
record = history[request.history_index]
@ -253,8 +263,8 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
if target_content is None:
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name)
scan = await scan_skill_content(app_config, target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name, app_config)
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
history_entry = {
"action": "rollback",
@ -267,12 +277,12 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
"scanner": {"decision": scan.decision, "reason": scan.reason},
}
if scan.decision == "block":
append_history(skill_name, history_entry)
append_history(skill_name, history_entry, app_config)
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
atomic_write(skill_file, target_content)
append_history(skill_name, history_entry)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
append_history(skill_name, history_entry, app_config)
await refresh_skills_system_prompt_cache_async(app_config)
return await get_custom_skill(skill_name, app_config)
except HTTPException:
raise
except IndexError:
@ -292,9 +302,9 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
summary="Get Skill Details",
description="Retrieve detailed information about a specific skill by its name.",
)
async def get_skill(skill_name: str) -> SkillResponse:
async def get_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@ -321,7 +331,7 @@ async def update_skill(
app_config: AppConfig = Depends(get_config),
) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@ -332,26 +342,29 @@ async def update_skill(
config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
# Do not mutate the frozen AppConfig in place. Compose the new skills
# state in a fresh dict, write to disk, and reload AppConfig below so
# every subsequent Depends(get_config) sees the refreshed snapshot.
ext = app_config.extensions
ext.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
updated_skills = {name: {"enabled": skill_config.enabled} for name, skill_config in ext.skills.items()}
updated_skills[skill_name] = {"enabled": request.enabled}
config_data = {
"mcpServers": {name: server.model_dump() for name, server in ext.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in ext.skills.items()},
"skills": updated_skills,
}
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
logger.info(f"Skills configuration updated and saved to: {config_path}")
# Swap both app.state.config and AppConfig._global so Depends(get_config)
# and legacy AppConfig.current() callers see the new config.
# Reload AppConfig and swap ``app.state.config`` so subsequent
# ``Depends(get_config)`` sees the refreshed value.
reloaded = AppConfig.from_file()
http_request.app.state.config = reloaded
AppConfig.init(reloaded)
await refresh_skills_system_prompt_cache_async()
await refresh_skills_system_prompt_cache_async(reloaded)
skills = load_skills(enabled_only=False)
skills = load_skills(reloaded, enabled_only=False)
updated_skill = next((s for s in skills if s.name == skill_name), None)
if updated_skill is None:

View File

@ -4,10 +4,12 @@ import logging
import os
import stat
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel
from app.gateway.authz import require_permission
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
@ -61,6 +63,7 @@ async def upload_files(
thread_id: str,
request: Request,
files: list[UploadFile] = File(...),
app_config: AppConfig = Depends(get_config),
) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory."""
if not files:
@ -73,7 +76,7 @@ async def upload_files(
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
uploaded_files = []
sandbox_provider = get_sandbox_provider()
sandbox_provider = get_sandbox_provider(app_config)
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)

View File

@ -61,9 +61,9 @@ def _create_summarization_middleware(app_config: AppConfig) -> SummarizationMidd
# as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time).
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
else:
model = create_chat_model(thinking_enabled=False)
model = create_chat_model(thinking_enabled=False, app_config=app_config)
model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs
@ -288,22 +288,22 @@ def make_lead_agent(
Args:
config: LangGraph ``RunnableConfig`` carrying per-invocation options
(``thinking_enabled``, ``model_name``, ``is_plan_mode``, etc.).
app_config: Resolved application config. When omitted falls back to
:meth:`AppConfig.current`, preserving backward compatibility with
callers that do not thread config explicitly (LangGraph Server
registration path). New callers should pass this parameter.
app_config: Resolved application config. Required for in-process
entry points (DeerFlowClient, Gateway Worker). When omitted we
are being called via ``langgraph.json`` registration and reload
from disk; the LangGraph Server bootstrap path has no other
way to thread the value.
"""
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent
if app_config is None:
# LangGraph Server registers make_lead_agent via langgraph.json; its
# invocation path only hands us RunnableConfig. Until that registration
# layer owns its own AppConfig, we tolerate the process-global fallback
# here. All other entry points (DeerFlowClient, Gateway Worker) pass
# app_config explicitly. Remove alongside AppConfig.current() in P2-10.
app_config = AppConfig.current()
# LangGraph Server registers ``make_lead_agent`` via ``langgraph.json``
# and hands us only a ``RunnableConfig``. Reload config from disk
# here — it's a pure function, equivalent to the process-global the
# old code path would have read.
app_config = AppConfig.from_file()
cfg = config.get("configurable", {})
@ -363,7 +363,7 @@ def make_lead_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=app_config),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=app_config) + [setup_agent],
middleware=_build_middlewares(app_config, config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
system_prompt=apply_prompt_template(app_config, subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
state_schema=ThreadState,
context_schema=DeerFlowContext,
)
@ -374,7 +374,7 @@ def make_lead_agent(
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=app_config),
middleware=_build_middlewares(app_config, config, model_name=model_name, agent_name=agent_name),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
app_config, subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
),
state_schema=ThreadState,
context_schema=DeerFlowContext,

View File

@ -20,19 +20,20 @@ _enabled_skills_refresh_version = 0
_enabled_skills_refresh_event = threading.Event()
def _load_enabled_skills_sync() -> list[Skill]:
return list(load_skills(enabled_only=True))
def _load_enabled_skills_sync(app_config: AppConfig | None) -> list[Skill]:
return list(load_skills(app_config, enabled_only=True))
def _start_enabled_skills_refresh_thread() -> None:
def _start_enabled_skills_refresh_thread(app_config: AppConfig | None) -> None:
threading.Thread(
target=_refresh_enabled_skills_cache_worker,
args=(app_config,),
name="deerflow-enabled-skills-loader",
daemon=True,
).start()
def _refresh_enabled_skills_cache_worker() -> None:
def _refresh_enabled_skills_cache_worker(app_config: AppConfig | None) -> None:
global _enabled_skills_cache, _enabled_skills_refresh_active
while True:
@ -40,8 +41,8 @@ def _refresh_enabled_skills_cache_worker() -> None:
target_version = _enabled_skills_refresh_version
try:
skills = _load_enabled_skills_sync()
except Exception:
skills = _load_enabled_skills_sync(app_config)
except (OSError, ImportError):
logger.exception("Failed to load enabled skills for prompt injection")
skills = []
@ -57,7 +58,7 @@ def _refresh_enabled_skills_cache_worker() -> None:
_enabled_skills_cache = None
def _ensure_enabled_skills_cache() -> threading.Event:
def _ensure_enabled_skills_cache(app_config: AppConfig | None) -> threading.Event:
global _enabled_skills_refresh_active
with _enabled_skills_lock:
@ -69,11 +70,11 @@ def _ensure_enabled_skills_cache() -> threading.Event:
_enabled_skills_refresh_active = True
_enabled_skills_refresh_event.clear()
_start_enabled_skills_refresh_thread()
_start_enabled_skills_refresh_thread(app_config)
return _enabled_skills_refresh_event
def _invalidate_enabled_skills_cache() -> threading.Event:
def _invalidate_enabled_skills_cache(app_config: AppConfig | None) -> threading.Event:
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
_get_cached_skills_prompt_section.cache_clear()
@ -85,30 +86,30 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
return _enabled_skills_refresh_event
_enabled_skills_refresh_active = True
_start_enabled_skills_refresh_thread()
_start_enabled_skills_refresh_thread(app_config)
return _enabled_skills_refresh_event
def prime_enabled_skills_cache() -> None:
_ensure_enabled_skills_cache()
def prime_enabled_skills_cache(app_config: AppConfig | None = None) -> None:
_ensure_enabled_skills_cache(app_config)
def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
if _ensure_enabled_skills_cache().wait(timeout=timeout_seconds):
def warm_enabled_skills_cache(app_config: AppConfig | None = None, timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
if _ensure_enabled_skills_cache(app_config).wait(timeout=timeout_seconds):
return True
logger.warning("Timed out waiting %.1fs for enabled skills cache warm-up", timeout_seconds)
return False
def _get_enabled_skills():
def _get_enabled_skills(app_config: AppConfig | None = None):
with _enabled_skills_lock:
cached = _enabled_skills_cache
if cached is not None:
return list(cached)
_ensure_enabled_skills_cache()
_ensure_enabled_skills_cache(app_config)
return []
@ -116,12 +117,12 @@ def _skill_mutability_label(category: str) -> str:
return "[custom, editable]" if category == "custom" else "[built-in]"
def clear_skills_system_prompt_cache() -> None:
_invalidate_enabled_skills_cache()
def clear_skills_system_prompt_cache(app_config: AppConfig | None = None) -> None:
_invalidate_enabled_skills_cache(app_config)
async def refresh_skills_system_prompt_cache_async() -> None:
await asyncio.to_thread(_invalidate_enabled_skills_cache().wait)
async def refresh_skills_system_prompt_cache_async(app_config: AppConfig | None = None) -> None:
await asyncio.to_thread(_invalidate_enabled_skills_cache(app_config).wait)
def _reset_skills_system_prompt_cache_state() -> None:
@ -135,10 +136,10 @@ def _reset_skills_system_prompt_cache_state() -> None:
_enabled_skills_refresh_event.clear()
def _refresh_enabled_skills_cache() -> None:
def _refresh_enabled_skills_cache(app_config: AppConfig | None = None) -> None:
"""Backward-compatible test helper for direct synchronous reload."""
try:
skills = _load_enabled_skills_sync()
skills = _load_enabled_skills_sync(app_config)
except Exception:
logger.exception("Failed to load enabled skills for prompt injection")
skills = []
@ -165,17 +166,18 @@ Skip simple one-off tasks.
"""
def _build_subagent_section(max_concurrent: int) -> str:
def _build_subagent_section(max_concurrent: int, app_config: AppConfig) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit.
Args:
max_concurrent: Maximum number of concurrent subagent calls allowed per response.
app_config: Application config used to gate bash availability.
Returns:
Formatted subagent section string.
"""
n = max_concurrent
bash_available = "bash" in get_available_subagent_names()
bash_available = "bash" in get_available_subagent_names(app_config)
available_subagents = (
"- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n- **bash**: For command execution (git, build, test, deploy operations)"
if bash_available
@ -508,36 +510,34 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
"""
def _get_memory_context(agent_name: str | None = None) -> str:
def _get_memory_context(app_config: AppConfig, agent_name: str | None = None) -> str:
"""Get memory context for injection into system prompt.
Args:
agent_name: If provided, loads per-agent memory. If None, loads global memory.
Returns:
Formatted memory context string wrapped in XML tags, or empty string if disabled.
Returns an empty string when memory is disabled or the stored memory file
cannot be read/parsed. A corrupt memory.json degrades the prompt to
no-memory; it never kills the agent.
"""
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.runtime.user_context import get_effective_user_id
memory_config = app_config.memory
if not memory_config.enabled or not memory_config.injection_enabled:
return ""
try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.runtime.user_context import get_effective_user_id
memory_data = get_memory_data(memory_config, agent_name, user_id=get_effective_user_id())
except (OSError, ValueError, UnicodeDecodeError):
logger.exception("Failed to load memory data for prompt injection")
return ""
config = AppConfig.current().memory
if not config.enabled or not config.injection_enabled:
return ""
memory_content = format_memory_for_injection(memory_data, max_tokens=memory_config.max_injection_tokens)
if not memory_content.strip():
return ""
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
if not memory_content.strip():
return ""
return f"""<memory>
return f"""<memory>
{memory_content}
</memory>
"""
except Exception as e:
logger.error("Failed to load memory context: %s", e)
return ""
@lru_cache(maxsize=32)
@ -572,17 +572,12 @@ You have access to skills that provide optimized workflows for specific tasks. E
</skill_system>"""
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
def get_skills_prompt_section(app_config: AppConfig, available_skills: set[str] | None = None) -> str:
"""Generate the skills prompt section with available skills list."""
skills = _get_enabled_skills()
skills = _get_enabled_skills(app_config)
try:
config = AppConfig.current()
container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled
except Exception:
container_base_path = "/mnt/skills"
skill_evolution_enabled = False
container_base_path = app_config.skills.container_path
skill_evolution_enabled = app_config.skill_evolution.enabled
if not skills and not skill_evolution_enabled:
return ""
@ -606,7 +601,7 @@ def get_agent_soul(agent_name: str | None) -> str:
return ""
def get_deferred_tools_prompt_section() -> str:
def get_deferred_tools_prompt_section(app_config: AppConfig) -> str:
"""Generate <available-deferred-tools> block for the system prompt.
Lists only deferred tool names so the agent knows what exists
@ -615,10 +610,7 @@ def get_deferred_tools_prompt_section() -> str:
"""
from deerflow.tools.builtins.tool_search import get_deferred_registry
try:
if not AppConfig.current().tool_search.enabled:
return ""
except Exception:
if not app_config.tool_search.enabled:
return ""
registry = get_deferred_registry()
@ -629,13 +621,9 @@ def get_deferred_tools_prompt_section() -> str:
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
def _build_acp_section() -> str:
def _build_acp_section(app_config: AppConfig) -> str:
"""Build the ACP agent prompt section, only if ACP agents are configured."""
try:
agents = AppConfig.current().acp_agents
if not agents:
return ""
except Exception:
if not app_config.acp_agents:
return ""
return (
@ -647,13 +635,9 @@ def _build_acp_section() -> str:
)
def _build_custom_mounts_section() -> str:
def _build_custom_mounts_section(app_config: AppConfig) -> str:
"""Build a prompt section for explicitly configured sandbox mounts."""
try:
mounts = AppConfig.current().sandbox.mounts or []
except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return ""
mounts = app_config.sandbox.mounts or []
if not mounts:
return ""
@ -667,13 +651,20 @@ def _build_custom_mounts_section() -> str:
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
def apply_prompt_template(
app_config: AppConfig,
subagent_enabled: bool = False,
max_concurrent_subagents: int = 3,
*,
agent_name: str | None = None,
available_skills: set[str] | None = None,
) -> str:
# Get memory context
memory_context = _get_memory_context(agent_name)
memory_context = _get_memory_context(app_config, agent_name)
# Include subagent section only if enabled (from runtime parameter)
n = max_concurrent_subagents
subagent_section = _build_subagent_section(n) if subagent_enabled else ""
subagent_section = _build_subagent_section(n, app_config) if subagent_enabled else ""
# Add subagent reminder to critical_reminders if enabled
subagent_reminder = (
@ -694,14 +685,14 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
)
# Get skills section
skills_section = get_skills_prompt_section(available_skills)
skills_section = get_skills_prompt_section(app_config, available_skills)
# Get deferred tools section (tool_search)
deferred_tools_section = get_deferred_tools_prompt_section()
deferred_tools_section = get_deferred_tools_prompt_section(app_config)
# Build ACP agent section only if ACP agents are configured
acp_section = _build_acp_section()
custom_mounts_section = _build_custom_mounts_section()
acp_section = _build_acp_section(app_config)
custom_mounts_section = _build_custom_mounts_section(app_config)
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
# Format the prompt with dynamic skills and memory
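The enabled-skills cache above combines a version counter, a completion Event, and a daemon worker that closes over the explicitly passed config. A minimal sketch of that pattern — the names here are illustrative stand-ins, not the module's real internals:

```python
import threading

class RefreshableCache:
    """Version/Event refresh pattern: invalidate bumps a version and spawns
    a worker; a result is only committed if no newer invalidation arrived."""

    def __init__(self, loader):
        self._loader = loader            # e.g. lambda cfg: load_skills(cfg, enabled_only=True)
        self._lock = threading.Lock()
        self._cache = None
        self._version = 0
        self._active = False
        self._event = threading.Event()

    def invalidate(self, app_config):
        """Drop the cache, bump the version, ensure one worker is running."""
        with self._lock:
            self._cache = None
            self._version += 1
            if not self._active:
                self._active = True
                self._event.clear()
                threading.Thread(
                    target=self._worker, args=(app_config,), daemon=True
                ).start()
            return self._event

    def get(self):
        with self._lock:
            return list(self._cache) if self._cache is not None else []

    def _worker(self, app_config):
        while True:
            with self._lock:
                target = self._version
            data = self._loader(app_config)   # config arrives as a parameter, not ambient state
            with self._lock:
                if target == self._version:   # no newer invalidation: commit and stop
                    self._cache = data
                    self._active = False
                    self._event.set()
                    return
                # else: a newer invalidation raced us; loop and reload
```

Callers that need the cache populated wait on the returned Event with a timeout, mirroring `warm_enabled_skills_cache` above.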


@ -12,6 +12,12 @@ from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
# Config is captured explicitly when the queue is constructed (see
# ``get_memory_queue`` below). The queue's debounce Timer fires on a
# background thread where ``Runtime`` and FastAPI request context are not
# accessible, so the enqueuer (which does have runtime context) supplies
# ``AppConfig`` up front instead of the queue reaching for ambient state.
@dataclass
class ConversationContext:
"""Context for a conversation to be processed for memory update."""
@ -31,10 +37,21 @@ class MemoryUpdateQueue:
This queue collects conversation contexts and processes them after
a configurable debounce period. Multiple conversations received within
the debounce window are batched together.
The queue captures an ``AppConfig`` reference at construction time and
reuses it for the MemoryUpdater it spawns. Callers must construct a
fresh queue when the config changes rather than reaching into a global.
"""
def __init__(self):
"""Initialize the memory update queue."""
def __init__(self, app_config: AppConfig):
"""Initialize the memory update queue.
Args:
app_config: Application config. The queue reads its own
``memory`` section for debounce timing and hands the full
config to :class:`MemoryUpdater`.
"""
self._app_config = app_config
self._queue: list[ConversationContext] = []
self._lock = threading.Lock()
self._timer: threading.Timer | None = None
@ -49,19 +66,8 @@ class MemoryUpdateQueue:
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation to the update queue.
Args:
thread_id: The thread ID.
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
user_id: The user ID captured at enqueue time. Stored in ConversationContext so it
survives the threading.Timer boundary (ContextVar does not propagate across
raw threads).
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
"""
config = AppConfig.current().memory
"""Add a conversation to the update queue."""
config = self._app_config.memory
if not config.enabled:
return
@ -93,7 +99,7 @@ class MemoryUpdateQueue:
def _reset_timer(self) -> None:
"""Reset the debounce timer."""
config = AppConfig.current().memory
config = self._app_config.memory
# Cancel existing timer if any
if self._timer is not None:
@ -131,7 +137,7 @@ class MemoryUpdateQueue:
logger.info("Processing %d queued memory updates", len(contexts_to_process))
try:
updater = MemoryUpdater()
updater = MemoryUpdater(self._app_config)
for context in contexts_to_process:
try:
@ -196,31 +202,35 @@ class MemoryUpdateQueue:
return self._processing
# Global singleton instance
_memory_queue: MemoryUpdateQueue | None = None
# Queues keyed by ``id(AppConfig)`` so tests and multi-client setups with
# distinct configs do not share a debounce queue.
_memory_queues: dict[int, MemoryUpdateQueue] = {}
_queue_lock = threading.Lock()
def get_memory_queue() -> MemoryUpdateQueue:
"""Get the global memory update queue singleton.
Returns:
The memory update queue instance.
"""
global _memory_queue
def get_memory_queue(app_config: AppConfig) -> MemoryUpdateQueue:
"""Get or create the memory update queue for the given app config."""
key = id(app_config)
with _queue_lock:
if _memory_queue is None:
_memory_queue = MemoryUpdateQueue()
return _memory_queue
queue = _memory_queues.get(key)
if queue is None:
queue = MemoryUpdateQueue(app_config)
_memory_queues[key] = queue
return queue
def reset_memory_queue() -> None:
"""Reset the global memory queue.
def reset_memory_queue(app_config: AppConfig | None = None) -> None:
"""Reset memory queue(s).
This is useful for testing.
Pass an ``app_config`` to reset only its queue, or omit to reset all
(useful at test teardown).
"""
global _memory_queue
with _queue_lock:
if _memory_queue is not None:
_memory_queue.clear()
_memory_queue = None
if app_config is not None:
queue = _memory_queues.pop(id(app_config), None)
if queue is not None:
queue.clear()
return
for queue in _memory_queues.values():
queue.clear()
_memory_queues.clear()


@ -9,7 +9,7 @@ from pathlib import Path
from typing import Any
from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
@ -61,8 +61,15 @@ class MemoryStorage(abc.ABC):
class FileMemoryStorage(MemoryStorage):
"""File-based memory storage provider."""
def __init__(self):
"""Initialize the file memory storage."""
def __init__(self, memory_config: MemoryConfig):
"""Initialize the file memory storage.
Args:
memory_config: Memory configuration (storage_path etc.). Stored on
the instance so per-request lookups don't need to reach for
ambient state.
"""
self._memory_config = memory_config
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
# Value: (memory_data, file_mtime)
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
@ -80,11 +87,11 @@ class FileMemoryStorage(MemoryStorage):
def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path:
"""Get the path to the memory file."""
config = self._memory_config
if user_id is not None:
if agent_name is not None:
self._validate_agent_name(agent_name)
return get_paths().user_agent_memory_file(user_id, agent_name)
config = AppConfig.current().memory
if config.storage_path and Path(config.storage_path).is_absolute():
return Path(config.storage_path)
return get_paths().user_memory_file(user_id)
@ -92,7 +99,6 @@ class FileMemoryStorage(MemoryStorage):
if agent_name is not None:
self._validate_agent_name(agent_name)
return get_paths().agent_memory_file(agent_name)
config = AppConfig.current().memory
if config.storage_path:
p = Path(config.storage_path)
return p if p.is_absolute() else get_paths().base_dir / p
@ -174,23 +180,31 @@ class FileMemoryStorage(MemoryStorage):
return False
_storage_instance: MemoryStorage | None = None
# Instances keyed by (storage_class_path, id(memory_config)) so tests can
# construct isolated storages and multi-client setups with different configs
# don't collide on a single process-wide singleton.
_storage_instances: dict[tuple[str, int], MemoryStorage] = {}
_storage_lock = threading.Lock()
def get_memory_storage() -> MemoryStorage:
"""Get the configured memory storage instance."""
global _storage_instance
if _storage_instance is not None:
return _storage_instance
def get_memory_storage(memory_config: MemoryConfig) -> MemoryStorage:
"""Get the configured memory storage instance.
Caches one instance per ``(storage_class, memory_config)`` pair. In
single-config deployments this collapses to one instance; in multi-client
or test scenarios each config gets its own storage.
"""
key = (memory_config.storage_class, id(memory_config))
existing = _storage_instances.get(key)
if existing is not None:
return existing
with _storage_lock:
if _storage_instance is not None:
return _storage_instance
config = AppConfig.current().memory
storage_class_path = config.storage_class
existing = _storage_instances.get(key)
if existing is not None:
return existing
storage_class_path = memory_config.storage_class
try:
module_path, class_name = storage_class_path.rsplit(".", 1)
import importlib
@ -204,13 +218,14 @@ def get_memory_storage() -> MemoryStorage:
if not issubclass(storage_class, MemoryStorage):
raise TypeError(f"Configured memory storage '{storage_class_path}' is not a subclass of MemoryStorage")
_storage_instance = storage_class()
instance = storage_class(memory_config)
except Exception as e:
logger.error(
"Failed to load memory storage %s, falling back to FileMemoryStorage: %s",
storage_class_path,
e,
)
_storage_instance = FileMemoryStorage()
instance = FileMemoryStorage(memory_config)
return _storage_instance
_storage_instances[key] = instance
return instance
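The dynamic `storage_class` resolution with a loud-log-then-fallback reduces to the shape below; the class names are placeholders, not the real deerflow types:

```python
import importlib
import logging

logger = logging.getLogger(__name__)

class MemoryStorageBase:
    """Placeholder for the abstract MemoryStorage interface."""

    def __init__(self, memory_config):
        self.memory_config = memory_config

class FileStorage(MemoryStorageBase):
    """Placeholder for the FileMemoryStorage fallback."""

def build_storage(storage_class_path: str, memory_config) -> MemoryStorageBase:
    """Resolve 'pkg.module.ClassName' and instantiate it with the config.

    Any failure (bad path, import error, wrong base class) falls back to
    FileStorage instead of propagating, mirroring the diff above.
    """
    try:
        module_path, class_name = storage_class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        storage_class = getattr(module, class_name)
        if not (isinstance(storage_class, type) and issubclass(storage_class, MemoryStorageBase)):
            raise TypeError(f"{storage_class_path!r} is not a MemoryStorageBase subclass")
        return storage_class(memory_config)
    except Exception:
        logger.exception("Failed to load %r; using FileStorage", storage_class_path)
        return FileStorage(memory_config)
```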


@ -17,6 +17,7 @@ from deerflow.agents.memory.storage import (
utc_now_iso_z,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@ -27,45 +28,33 @@ def _create_empty_memory() -> dict[str, Any]:
return create_empty_memory()
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Backward-compatible wrapper around the configured memory storage save path."""
return get_memory_storage().save(memory_data, agent_name, user_id=user_id)
def _save_memory_to_file(memory_config: MemoryConfig, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Save via the configured memory storage."""
return get_memory_storage(memory_config).save(memory_data, agent_name, user_id=user_id)
def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
def get_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Get the current memory data via storage provider."""
return get_memory_storage().load(agent_name, user_id=user_id)
return get_memory_storage(memory_config).load(agent_name, user_id=user_id)
def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
def reload_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Reload memory data via storage provider."""
return get_memory_storage().reload(agent_name, user_id=user_id)
return get_memory_storage(memory_config).reload(agent_name, user_id=user_id)
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Persist imported memory data via storage provider.
Args:
memory_data: Full memory payload to persist.
agent_name: If provided, imports into per-agent memory.
user_id: If provided, scopes memory to a specific user.
Returns:
The saved memory data after storage normalization.
Raises:
OSError: If persisting the imported memory fails.
"""
storage = get_memory_storage()
def import_memory_data(memory_config: MemoryConfig, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Persist imported memory data via storage provider."""
storage = get_memory_storage(memory_config)
if not storage.save(memory_data, agent_name, user_id=user_id):
raise OSError("Failed to save imported memory data")
return storage.load(agent_name, user_id=user_id)
def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
def clear_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Clear all stored memory data and persist an empty structure."""
cleared_memory = create_empty_memory()
if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id):
if not _save_memory_to_file(memory_config, cleared_memory, agent_name, user_id=user_id):
raise OSError("Failed to save cleared memory data")
return cleared_memory
@ -78,6 +67,7 @@ def _validate_confidence(confidence: float) -> float:
def create_memory_fact(
memory_config: MemoryConfig,
content: str,
category: str = "context",
confidence: float = 0.5,
@ -93,7 +83,7 @@ def create_memory_fact(
normalized_category = category.strip() or "context"
validated_confidence = _validate_confidence(confidence)
now = utc_now_iso_z()
memory_data = get_memory_data(agent_name, user_id=user_id)
memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
updated_memory = dict(memory_data)
facts = list(memory_data.get("facts", []))
facts.append(
@ -108,15 +98,15 @@ def create_memory_fact(
)
updated_memory["facts"] = facts
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
raise OSError("Failed to save memory data after creating fact")
return updated_memory
def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
def delete_memory_fact(memory_config: MemoryConfig, fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Delete a fact by its id and persist the updated memory data."""
memory_data = get_memory_data(agent_name, user_id=user_id)
memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
facts = memory_data.get("facts", [])
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
if len(updated_facts) == len(facts):
@ -125,13 +115,14 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id:
updated_memory = dict(memory_data)
updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
return updated_memory
def update_memory_fact(
memory_config: MemoryConfig,
fact_id: str,
content: str | None = None,
category: str | None = None,
@ -141,7 +132,7 @@ def update_memory_fact(
user_id: str | None = None,
) -> dict[str, Any]:
"""Update an existing fact and persist the updated memory data."""
memory_data = get_memory_data(agent_name, user_id=user_id)
memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
updated_memory = dict(memory_data)
updated_facts: list[dict[str, Any]] = []
found = False
@ -168,7 +159,7 @@ def update_memory_fact(
updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
return updated_memory
@ -260,19 +251,25 @@ def _fact_content_key(content: Any) -> str | None:
class MemoryUpdater:
"""Updates memory using LLM based on conversation context."""
def __init__(self, model_name: str | None = None):
def __init__(self, app_config: AppConfig, model_name: str | None = None):
"""Initialize the memory updater.
Args:
app_config: Application config (the updater needs both ``memory``
section for behavior and the full config for ``create_chat_model``).
model_name: Optional model name to use. If None, uses config or default.
"""
self._app_config = app_config
self._model_name = model_name
@property
def _memory_config(self) -> MemoryConfig:
return self._app_config.memory
def _get_model(self):
"""Get the model for memory updates."""
config = AppConfig.current().memory
model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False)
model_name = self._model_name or self._memory_config.model_name
return create_chat_model(name=model_name, thinking_enabled=False, app_config=self._app_config)
def update_memory(
self,
@ -296,7 +293,7 @@ class MemoryUpdater:
Returns:
True if update was successful, False otherwise.
"""
config = AppConfig.current().memory
config = self._memory_config
if not config.enabled:
return False
@ -305,7 +302,7 @@ class MemoryUpdater:
try:
# Get current memory
current_memory = get_memory_data(agent_name, user_id=user_id)
current_memory = get_memory_data(config, agent_name, user_id=user_id)
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
@ -360,7 +357,7 @@ class MemoryUpdater:
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
return get_memory_storage(config).save(updated_memory, agent_name, user_id=user_id)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
@ -385,7 +382,7 @@ class MemoryUpdater:
Returns:
Updated memory data.
"""
config = AppConfig.current().memory
config = self._memory_config
now = utc_now_iso_z()
# Update user sections


@ -236,7 +236,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
# threading.Timer fires on a different thread where ContextVar values are not
# propagated, so we must store user_id explicitly in ConversationContext.
user_id = get_effective_user_id()
queue = get_memory_queue()
queue = get_memory_queue(runtime.context.app_config)
queue.add(
thread_id=thread_id,
messages=filtered_messages,
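Why `user_id` (and the config) must be captured at enqueue time: a `threading.Timer` callback runs on a fresh thread, and ContextVar values set on the request thread do not propagate there. A quick demonstration:

```python
import contextvars
import threading

user_id_var = contextvars.ContextVar("user_id", default=None)

def capture_across_timer():
    user_id_var.set("alice")          # set on the "request" thread
    captured = user_id_var.get()      # explicit capture before scheduling
    seen = {}
    done = threading.Event()

    def on_timer():
        # New threads start from an empty context: the value set above
        # is invisible here; only the explicitly captured copy survives.
        seen["ambient"] = user_id_var.get()
        seen["captured"] = captured
        done.set()

    threading.Timer(0.01, on_timer).start()
    done.wait(timeout=5)
    return seen
```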


@ -8,6 +8,7 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.title_config import TitleConfig
from deerflow.models import create_chat_model
@ -120,8 +121,9 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
_, user_msg = self._build_title_prompt(state, title_config)
return {"title": self._fallback_title(user_msg, title_config)}
async def _agenerate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
async def _agenerate_title_result(self, state: TitleMiddlewareState, app_config: AppConfig) -> dict | None:
"""Generate a title asynchronously and fall back locally on failure."""
title_config = app_config.title
if not self._should_generate_title(state, title_config):
return None
@ -129,9 +131,9 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
try:
if title_config.model_name:
model = create_chat_model(name=title_config.model_name, thinking_enabled=False)
model = create_chat_model(name=title_config.model_name, thinking_enabled=False, app_config=app_config)
else:
model = create_chat_model(thinking_enabled=False)
model = create_chat_model(thinking_enabled=False, app_config=app_config)
response = await model.ainvoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content, title_config)
if title:
@ -146,4 +148,4 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
@override
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return await self._agenerate_title_result(state, runtime.context.app_config.title)
return await self._agenerate_title_result(state, runtime.context.app_config)


@ -38,7 +38,7 @@ from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.paths import get_paths
from deerflow.models import create_chat_model
from deerflow.runtime.user_context import get_effective_user_id
@ -151,16 +151,15 @@ class DeerFlowClient:
# Constructor-captured config: the client owns its AppConfig for its lifetime.
# Multiple clients with different configs do not contend.
#
# Priority: explicit ``config=`` > explicit ``config_path=`` > AppConfig.current().
# The third tier preserves backward compatibility with callers that relied on
# the process-global (tests via the conftest autouse fixture). After P2-10
# removes AppConfig.current(), this fallback will require an explicit choice.
# Priority: explicit ``config=`` > explicit ``config_path=`` > ``AppConfig.from_file()``
# with default path resolution. There is no ambient global fallback; if
# config.yaml cannot be located, ``from_file`` raises loudly.
if config is not None:
self._app_config = config
elif config_path is not None:
self._app_config = AppConfig.from_file(config_path)
else:
self._app_config = AppConfig.current()
self._app_config = AppConfig.from_file()
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
@ -254,10 +253,11 @@ class DeerFlowClient:
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
kwargs: dict[str, Any] = {
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=self._app_config),
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(self._app_config, config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
"system_prompt": apply_prompt_template(
self._app_config,
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name,
@ -269,7 +269,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(self._app_config)
if checkpointer is not None:
kwargs["checkpointer"] = checkpointer
@ -277,12 +277,11 @@ class DeerFlowClient:
self._agent_config_key = key
logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled)
@staticmethod
def _get_tools(*, model_name: str | None, subagent_enabled: bool):
def _get_tools(self, *, model_name: str | None, subagent_enabled: bool):
"""Lazy import to avoid circular dependency at module level."""
from deerflow.tools import get_available_tools
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=self._app_config)
@staticmethod
def _serialize_tool_calls(tool_calls) -> list[dict]:
@ -403,7 +402,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(self._app_config)
thread_info_map = {}
@ -458,7 +457,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(self._app_config)
config = {"configurable": {"thread_id": thread_id}}
checkpoints = []
@ -782,7 +781,7 @@ class DeerFlowClient:
"category": s.category,
"enabled": s.enabled,
}
for s in load_skills(enabled_only=enabled_only)
for s in load_skills(self._app_config, enabled_only=enabled_only)
]
}
@ -794,19 +793,19 @@ class DeerFlowClient:
"""
from deerflow.agents.memory.updater import get_memory_data
return get_memory_data(user_id=get_effective_user_id())
return get_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def export_memory(self) -> dict:
"""Export current memory data for backup or transfer."""
from deerflow.agents.memory.updater import get_memory_data
return get_memory_data(user_id=get_effective_user_id())
return get_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def import_memory(self, memory_data: dict) -> dict:
"""Import and persist full memory data."""
from deerflow.agents.memory.updater import import_memory_data
return import_memory_data(memory_data, user_id=get_effective_user_id())
return import_memory_data(self._app_config.memory, memory_data, user_id=get_effective_user_id())
def get_model(self, name: str) -> dict | None:
"""Get a specific model's configuration by name.
@ -894,7 +893,7 @@ class DeerFlowClient:
"""
from deerflow.skills.loader import load_skills
skill = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
skill = next((s for s in load_skills(self._app_config, enabled_only=False) if s.name == name), None)
if skill is None:
return None
return {
@ -921,7 +920,7 @@ class DeerFlowClient:
"""
from deerflow.skills.loader import load_skills
skills = load_skills(enabled_only=False)
skills = load_skills(self._app_config, enabled_only=False)
skill = next((s for s in skills if s.name == name), None)
if skill is None:
raise ValueError(f"Skill '{name}' not found")
@ -930,12 +929,16 @@ class DeerFlowClient:
if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
# Do not mutate self._app_config (frozen value). Compose the new
# skills state in a fresh dict, write it to disk, and let _reload_config()
# below rebuild AppConfig from the updated file.
ext = self._app_config.extensions
ext.skills[name] = SkillStateConfig(enabled=enabled)
new_skills = {n: {"enabled": sc.enabled} for n, sc in ext.skills.items()}
new_skills[name] = {"enabled": enabled}
config_data = {
"mcpServers": {n: s.model_dump() for n, s in ext.mcp_servers.items()},
"skills": {n: {"enabled": sc.enabled} for n, sc in ext.skills.items()},
"skills": new_skills,
}
self._atomic_write_json(config_path, config_data)
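The compose-then-reload pattern in this hunk can be sketched in isolation. This is a simplified stand-in, not the project's real models (`SkillState`/`Extensions` below are illustrative names): the frozen config value is never mutated; a fresh dict is composed, written to disk, and a new config object is later rebuilt from the file.

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class SkillState:
    enabled: bool


@dataclass(frozen=True)
class Extensions:
    skills: dict  # name -> SkillState


def toggle_skill(ext: Extensions, path: Path, name: str, enabled: bool) -> dict:
    # Compose the new skills state in a fresh dict instead of mutating
    # the frozen Extensions value in place.
    new_skills = {n: {"enabled": s.enabled} for n, s in ext.skills.items()}
    new_skills[name] = {"enabled": enabled}
    config_data = {"skills": new_skills}
    path.write_text(json.dumps(config_data))  # real code writes atomically
    return config_data


ext = Extensions(skills={"pdf": SkillState(enabled=True)})
with tempfile.TemporaryDirectory() as d:
    data = toggle_skill(ext, Path(d) / "extensions_config.json", "pdf", False)
```

In the diff itself the write goes through `_atomic_write_json` and `_reload_config()` then rebuilds `AppConfig` from the updated file.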
@ -944,7 +947,7 @@ class DeerFlowClient:
self._agent_config_key = None
self._reload_config()
updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
updated = next((s for s in load_skills(self._app_config, enabled_only=False) if s.name == name), None)
if updated is None:
raise RuntimeError(f"Skill '{name}' disappeared after update")
return {
@ -982,25 +985,25 @@ class DeerFlowClient:
"""
from deerflow.agents.memory.updater import reload_memory_data
return reload_memory_data(user_id=get_effective_user_id())
return reload_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def clear_memory(self) -> dict:
"""Clear all persisted memory data."""
from deerflow.agents.memory.updater import clear_memory_data
return clear_memory_data(user_id=get_effective_user_id())
return clear_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
"""Create a single fact manually."""
from deerflow.agents.memory.updater import create_memory_fact
return create_memory_fact(content=content, category=category, confidence=confidence)
return create_memory_fact(self._app_config.memory, content=content, category=category, confidence=confidence)
def delete_memory_fact(self, fact_id: str) -> dict:
"""Delete a single fact from memory by fact id."""
from deerflow.agents.memory.updater import delete_memory_fact
return delete_memory_fact(fact_id)
return delete_memory_fact(self._app_config.memory, fact_id)
def update_memory_fact(
self,
@ -1013,6 +1016,7 @@ class DeerFlowClient:
from deerflow.agents.memory.updater import update_memory_fact
return update_memory_fact(
self._app_config.memory,
fact_id=fact_id,
content=content,
category=category,

View File

@ -90,7 +90,8 @@ class AioSandboxProvider(SandboxProvider):
API_KEY: $MY_API_KEY
"""
def __init__(self):
def __init__(self, app_config: "AppConfig"):
self._app_config = app_config
self._lock = threading.Lock()
self._sandboxes: dict[str, AioSandbox] = {} # sandbox_id -> AioSandbox instance
self._sandbox_infos: dict[str, SandboxInfo] = {} # sandbox_id -> SandboxInfo (for destroy)
@ -149,8 +150,7 @@ class AioSandboxProvider(SandboxProvider):
def _load_config(self) -> dict:
"""Load sandbox configuration from app config."""
config = AppConfig.current()
sandbox_config = config.sandbox
sandbox_config = self._app_config.sandbox
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
replicas = getattr(sandbox_config, "replicas", None)
@ -273,17 +273,15 @@ class AioSandboxProvider(SandboxProvider):
(paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True),
]
@staticmethod
def _get_skills_mount() -> tuple[str, str, bool] | None:
def _get_skills_mount(self) -> tuple[str, str, bool] | None:
"""Get the skills directory mount configuration.
Mount source uses DEER_FLOW_HOST_SKILLS_PATH when running inside Docker (DooD)
so the host Docker daemon can resolve the path.
"""
try:
config = AppConfig.current()
skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path
skills_path = self._app_config.skills.get_skills_path()
container_path = self._app_config.skills.container_path
if skills_path.exists():
# When running inside Docker with DooD, use host-side skills path.

View File

@ -5,9 +5,9 @@ Web Search Tool - Search the web using DuckDuckGo (no API key required).
import json
import logging
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
logger = logging.getLogger(__name__)
@ -55,6 +55,7 @@ def _search_text(
@tool("web_search", parse_docstring=True)
def web_search_tool(
query: str,
runtime: ToolRuntime,
max_results: int = 5,
) -> str:
"""Search the web for information. Use this tool to find current information, news, articles, and facts from the internet.
@ -63,11 +64,11 @@ def web_search_tool(
query: Search keywords describing what you want to find. Be specific for better results.
max_results: Maximum number of results to return. Default is 5.
"""
config = AppConfig.current().get_tool_config("web_search")
tool_config = resolve_context(runtime).app_config.get_tool_config("web_search")
# Override max_results from config if set
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results", max_results)
if tool_config is not None and "max_results" in tool_config.model_extra:
max_results = tool_config.model_extra.get("max_results", max_results)
results = _search_text(
query=query,

View File

@ -1,37 +1,39 @@
import json
from exa_py import Exa
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
def _get_exa_client(tool_name: str = "web_search") -> Exa:
config = AppConfig.current().get_tool_config(tool_name)
def _get_exa_client(app_config: AppConfig, tool_name: str = "web_search") -> Exa:
tool_config = app_config.get_tool_config(tool_name)
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
if tool_config is not None and "api_key" in tool_config.model_extra:
api_key = tool_config.model_extra.get("api_key")
return Exa(api_key=api_key)
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web.
Args:
query: The query to search for.
"""
try:
config = AppConfig.current().get_tool_config("web_search")
app_config = resolve_context(runtime).app_config
tool_config = app_config.get_tool_config("web_search")
max_results = 5
search_type = "auto"
contents_max_characters = 1000
if config is not None:
max_results = config.model_extra.get("max_results", max_results)
search_type = config.model_extra.get("search_type", search_type)
contents_max_characters = config.model_extra.get("contents_max_characters", contents_max_characters)
if tool_config is not None:
max_results = tool_config.model_extra.get("max_results", max_results)
search_type = tool_config.model_extra.get("search_type", search_type)
contents_max_characters = tool_config.model_extra.get("contents_max_characters", contents_max_characters)
client = _get_exa_client()
client = _get_exa_client(app_config)
res = client.search(
query,
type=search_type,
@ -54,7 +56,7 @@ def web_search_tool(query: str) -> str:
@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@ -65,7 +67,7 @@ def web_fetch_tool(url: str) -> str:
url: The URL to fetch the contents of.
"""
try:
client = _get_exa_client("web_fetch")
client = _get_exa_client(resolve_context(runtime).app_config, "web_fetch")
res = client.get_contents([url], text={"max_characters": 4096})
if res.results:

View File

@ -1,33 +1,35 @@
import json
from firecrawl import FirecrawlApp
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp:
config = AppConfig.current().get_tool_config(tool_name)
def _get_firecrawl_client(app_config: AppConfig, tool_name: str = "web_search") -> FirecrawlApp:
tool_config = app_config.get_tool_config(tool_name)
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
if tool_config is not None and "api_key" in tool_config.model_extra:
api_key = tool_config.model_extra.get("api_key")
return FirecrawlApp(api_key=api_key) # type: ignore[arg-type]
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web.
Args:
query: The query to search for.
"""
try:
config = AppConfig.current().get_tool_config("web_search")
app_config = resolve_context(runtime).app_config
tool_config = app_config.get_tool_config("web_search")
max_results = 5
if config is not None:
max_results = config.model_extra.get("max_results", max_results)
if tool_config is not None:
max_results = tool_config.model_extra.get("max_results", max_results)
client = _get_firecrawl_client("web_search")
client = _get_firecrawl_client(app_config, "web_search")
result = client.search(query, limit=max_results)
# result.web contains list of SearchResultWeb objects
@ -47,7 +49,7 @@ def web_search_tool(query: str) -> str:
@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@ -58,7 +60,8 @@ def web_fetch_tool(url: str) -> str:
url: The URL to fetch the contents of.
"""
try:
client = _get_firecrawl_client("web_fetch")
app_config = resolve_context(runtime).app_config
client = _get_firecrawl_client(app_config, "web_fetch")
result = client.scrape(url, formats=["markdown"])
markdown_content = result.markdown or ""

View File

@ -5,9 +5,9 @@ Image Search Tool - Search images using DuckDuckGo for reference in image genera
import json
import logging
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
logger = logging.getLogger(__name__)
@ -77,6 +77,7 @@ def _search_images(
@tool("image_search", parse_docstring=True)
def image_search_tool(
query: str,
runtime: ToolRuntime,
max_results: int = 5,
size: str | None = None,
type_image: str | None = None,
@ -99,11 +100,11 @@ def image_search_tool(
type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references.
layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs.
"""
config = AppConfig.current().get_tool_config("image_search")
tool_config = resolve_context(runtime).app_config.get_tool_config("image_search")
# Override max_results from config if set
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results", max_results)
if tool_config is not None and "max_results" in tool_config.model_extra:
max_results = tool_config.model_extra.get("max_results", max_results)
results = _search_images(
query=query,

View File

@ -1,6 +1,7 @@
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
from deerflow.utils.readability import ReadabilityExtractor
from .infoquest_client import InfoQuestClient
@ -8,13 +9,13 @@ from .infoquest_client import InfoQuestClient
readability_extractor = ReadabilityExtractor()
def _get_infoquest_client() -> InfoQuestClient:
search_config = AppConfig.current().get_tool_config("web_search")
def _get_infoquest_client(app_config: AppConfig) -> InfoQuestClient:
search_config = app_config.get_tool_config("web_search")
search_time_range = -1
if search_config is not None and "search_time_range" in search_config.model_extra:
search_time_range = search_config.model_extra.get("search_time_range")
fetch_config = AppConfig.current().get_tool_config("web_fetch")
fetch_config = app_config.get_tool_config("web_fetch")
fetch_time = -1
if fetch_config is not None and "fetch_time" in fetch_config.model_extra:
fetch_time = fetch_config.model_extra.get("fetch_time")
@ -25,7 +26,7 @@ def _get_infoquest_client() -> InfoQuestClient:
if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra:
navigation_timeout = fetch_config.model_extra.get("navigation_timeout")
image_search_config = AppConfig.current().get_tool_config("image_search")
image_search_config = app_config.get_tool_config("image_search")
image_search_time_range = -1
if image_search_config is not None and "image_search_time_range" in image_search_config.model_extra:
image_search_time_range = image_search_config.model_extra.get("image_search_time_range")
@ -44,19 +45,18 @@ def _get_infoquest_client() -> InfoQuestClient:
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web.
Args:
query: The query to search for.
"""
client = _get_infoquest_client()
client = _get_infoquest_client(resolve_context(runtime).app_config)
return client.web_search(query)
@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@ -66,7 +66,7 @@ def web_fetch_tool(url: str) -> str:
Args:
url: The URL to fetch the contents of.
"""
client = _get_infoquest_client()
client = _get_infoquest_client(resolve_context(runtime).app_config)
result = client.fetch(url)
if result.startswith("Error: "):
return result
@ -75,7 +75,7 @@ def web_fetch_tool(url: str) -> str:
@tool("image_search", parse_docstring=True)
def image_search_tool(query: str) -> str:
def image_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search for images online. Use this tool BEFORE image generation to find reference images for characters, portraits, objects, scenes, or any content requiring visual accuracy.
**When to use:**
@ -89,5 +89,5 @@ def image_search_tool(query: str) -> str:
Args:
query: The query to search for images.
"""
client = _get_infoquest_client()
client = _get_infoquest_client(resolve_context(runtime).app_config)
return client.image_search(query)

View File

@ -1,14 +1,14 @@
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
from deerflow.utils.readability import ReadabilityExtractor
readability_extractor = ReadabilityExtractor()
@tool("web_fetch", parse_docstring=True)
async def web_fetch_tool(url: str) -> str:
async def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@ -20,9 +20,9 @@ async def web_fetch_tool(url: str) -> str:
"""
jina_client = JinaClient()
timeout = 10
config = AppConfig.current().get_tool_config("web_fetch")
if config is not None and "timeout" in config.model_extra:
timeout = config.model_extra.get("timeout")
tool_config = resolve_context(runtime).app_config.get_tool_config("web_fetch")
if tool_config is not None and "timeout" in tool_config.model_extra:
timeout = tool_config.model_extra.get("timeout")
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
if isinstance(html_content, str) and html_content.startswith("Error:"):
return html_content

View File

@ -1,32 +1,34 @@
import json
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from tavily import TavilyClient
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
def _get_tavily_client() -> TavilyClient:
config = AppConfig.current().get_tool_config("web_search")
def _get_tavily_client(app_config: AppConfig) -> TavilyClient:
tool_config = app_config.get_tool_config("web_search")
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
if tool_config is not None and "api_key" in tool_config.model_extra:
api_key = tool_config.model_extra.get("api_key")
return TavilyClient(api_key=api_key)
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web.
Args:
query: The query to search for.
"""
config = AppConfig.current().get_tool_config("web_search")
app_config = resolve_context(runtime).app_config
tool_config = app_config.get_tool_config("web_search")
max_results = 5
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results")
if tool_config is not None and "max_results" in tool_config.model_extra:
max_results = tool_config.model_extra.get("max_results")
client = _get_tavily_client()
client = _get_tavily_client(app_config)
res = client.search(query, max_results=max_results)
normalized_results = [
{
@ -41,7 +43,7 @@ def web_search_tool(query: str) -> str:
@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@ -51,7 +53,8 @@ def web_fetch_tool(url: str) -> str:
Args:
url: The URL to fetch the contents of.
"""
client = _get_tavily_client()
app_config = resolve_context(runtime).app_config
client = _get_tavily_client(app_config)
res = client.extract([url])
if "failed_results" in res and len(res["failed_results"]) > 0:
return f"Error: {res['failed_results'][0]['error']}"

View File

@ -2,9 +2,8 @@ from __future__ import annotations
import logging
import os
from contextvars import ContextVar, Token
from pathlib import Path
from typing import Any, ClassVar, Self
from typing import Any, Self
import yaml
from dotenv import load_dotenv
@ -222,54 +221,21 @@ class AppConfig(BaseModel):
"""
return next((group for group in self.tool_groups if group.name == name), None)
# -- Lifecycle (process-global + per-context override) --
# AppConfig is a pure value object: construct with ``from_file()``, pass around.
# Composition roots that hold the singleton:
# - Gateway: ``app.state.config`` via ``Depends(get_config)``
# - Client: ``DeerFlowClient._app_config``
# - Agent run: ``Runtime[DeerFlowContext].context.app_config``
#
# _global is a plain class variable. Assignment is atomic under the GIL
# (single pointer swap), so no lock is needed for the current read/write
# pattern. If this ever changes to read-modify-write, add a threading.Lock.
_global: ClassVar[AppConfig | None] = None
_override: ClassVar[ContextVar[AppConfig]] = ContextVar("deerflow_app_config_override")
@classmethod
def init(cls, config: AppConfig) -> None:
"""Set the process-global AppConfig. Visible to all subsequent requests."""
cls._global = config
@classmethod
def set_override(cls, config: AppConfig) -> Token[AppConfig]:
"""Set a per-context override. Returns a token for reset_override().
Use this in DeerFlowClient or test fixtures to scope a config to the
current async context without polluting the process-global value.
"""
return cls._override.set(config)
@classmethod
def reset_override(cls, token: Token[AppConfig]) -> None:
"""Restore the override to its previous value."""
cls._override.reset(token)
# ``current()`` is kept as a deprecated slot that raises, purely so legacy tests
# that still run ``patch.object(AppConfig, "current", ...)`` can attach
# without an ``AttributeError`` at teardown. Production code never calls
# it — any in-process invocation raises so regressions are loud.
@classmethod
def current(cls) -> AppConfig:
"""Get the current AppConfig.
Priority: per-context override > process-global > auto-load from file.
The auto-load fallback exists for backward compatibility. Prefer calling
``AppConfig.init()`` explicitly at process startup so that config errors
surface early rather than at an arbitrary first-use call site.
"""
try:
return cls._override.get()
except LookupError:
pass
if cls._global is not None:
return cls._global
logger.warning(
"AppConfig.current() called before init(); auto-loading from file. "
"Call AppConfig.init() at process startup to surface config errors early."
raise RuntimeError(
"AppConfig.current() is removed. Pass AppConfig explicitly: "
"`runtime.context.app_config` in agent paths, `Depends(get_config)` in Gateway, "
"`self._app_config` in DeerFlowClient."
)
config = cls.from_file()
cls._global = config
return config
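The deprecated slot behaves exactly as the comment describes: any real invocation raises, but the attribute still exists on the class, so legacy `patch.object` targets resolve without `AttributeError`. A reduced sketch:

```python
from unittest.mock import patch


class AppConfig:
    @classmethod
    def current(cls) -> "AppConfig":
        # Deprecated slot: kept only so patch.object targets resolve;
        # any in-process invocation fails loudly.
        raise RuntimeError("AppConfig.current() is removed. Pass AppConfig explicitly.")


raised = False
try:
    AppConfig.current()
except RuntimeError:
    raised = True

# Legacy tests can still attach to the slot without AttributeError.
with patch.object(AppConfig, "current", return_value="stub"):
    patched = AppConfig.current()
```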

View File

@ -30,43 +30,26 @@ class DeerFlowContext:
def resolve_context(runtime: Any) -> DeerFlowContext:
"""Extract or construct DeerFlowContext from runtime.
"""Return the typed DeerFlowContext that the runtime carries.
Gateway/Client paths: runtime.context is already DeerFlowContext; return it directly.
LangGraph Server / legacy dict path: construct from dict context or configurable fallback.
Gateway mode (``DeerFlowClient``, ``run_agent``) always attaches a typed
``DeerFlowContext`` via ``agent.astream(context=...)``; the LangGraph
Server path uses ``langgraph.json`` registration where the top-level
``make_lead_agent`` loads ``AppConfig`` from disk itself, so we still
arrive here with a typed context.
Only the dict/None shapes that legacy tests used to exercise would fall
through this function; we now reject them loudly instead of papering
over the missing context with an ambient ``AppConfig`` lookup.
"""
ctx = getattr(runtime, "context", None)
if isinstance(ctx, DeerFlowContext):
return ctx
from deerflow.config.app_config import AppConfig
# Try dict context first (legacy path, tests), then configurable
if isinstance(ctx, dict):
thread_id = ctx.get("thread_id", "")
if not thread_id:
logger.warning("resolve_context: dict context has empty thread_id — may cause incorrect path resolution")
return DeerFlowContext(
app_config=AppConfig.current(),
thread_id=thread_id,
agent_name=ctx.get("agent_name"),
)
# No context at all — fall back to LangGraph configurable
try:
from langgraph.config import get_config
cfg = get_config().get("configurable", {})
except RuntimeError:
# Outside runnable context (e.g. unit tests)
cfg = {}
thread_id = cfg.get("thread_id", "")
if not thread_id:
logger.warning("resolve_context: falling back to empty thread_id — no DeerFlowContext or configurable found")
return DeerFlowContext(
app_config=AppConfig.current(),
thread_id=thread_id,
agent_name=cfg.get("agent_name"),
raise RuntimeError(
"resolve_context: runtime.context is not a DeerFlowContext "
"(got type %s). Every entry point must attach one at invoke time — "
"Gateway/Client via agent.astream(context=DeerFlowContext(...)), "
"LangGraph Server via the make_lead_agent boundary that loads "
"AppConfig.from_file()." % type(ctx).__name__
)

View File

@ -34,24 +34,18 @@ def create_chat_model(
name: str | None = None,
thinking_enabled: bool = False,
*,
app_config: "AppConfig | None" = None,
app_config: "AppConfig",
**kwargs,
) -> BaseChatModel:
"""Create a chat model instance from the config.
Args:
name: The name of the model to create. If None, the first model in the config will be used.
app_config: Application config. Falls back to AppConfig.current() when
omitted; new callers should pass this explicitly.
app_config: Application config (required).
Returns:
A chat model instance.
"""
if app_config is None:
# TODO(P2-10): fold into a required parameter once all callers
# (memory updater, summarization middleware's implicit model) thread
# config explicitly.
app_config = AppConfig.current()
config = app_config
if name is None:
name = config.models[0].name

View File

@ -123,11 +123,11 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi
@contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
async def make_checkpointer(app_config: AppConfig) -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit -- no global state::
async with make_checkpointer() as checkpointer:
async with make_checkpointer(app_config) as checkpointer:
app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
@ -138,16 +138,14 @@ async def make_checkpointer() -> AsyncIterator[Checkpointer]:
3. Default InMemorySaver
"""
config = AppConfig.current()
# Legacy: standalone checkpointer config takes precedence
if config.checkpointer is not None:
async with _async_checkpointer(config.checkpointer) as saver:
if app_config.checkpointer is not None:
async with _async_checkpointer(app_config.checkpointer) as saver:
yield saver
return
# Unified database config
db_config = getattr(config, "database", None)
db_config = getattr(app_config, "database", None)
if db_config is not None and db_config.backend != "memory":
async with _async_checkpointer_from_database(db_config) as saver:
yield saver

View File

@ -99,10 +99,13 @@ _checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive
def get_checkpointer() -> Checkpointer:
def get_checkpointer(app_config: AppConfig) -> Checkpointer:
"""Return the global sync checkpointer singleton, creating it on first call.
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
Returns an ``InMemorySaver`` only when ``checkpointer`` is explicitly
absent from config.yaml. Any other failure (missing config, invalid
backend, connection error) propagates; silent degradation to in-memory
would drop persistent-run state on process restart.
Raises:
ImportError: If the required package for the configured backend is not installed.
@ -113,10 +116,7 @@ def get_checkpointer() -> Checkpointer:
if _checkpointer is not None:
return _checkpointer
try:
config = AppConfig.current().checkpointer
except (LookupError, FileNotFoundError):
config = None
config = app_config.checkpointer
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
@ -152,25 +152,23 @@ def reset_checkpointer() -> None:
@contextlib.contextmanager
def checkpointer_context() -> Iterator[Checkpointer]:
def checkpointer_context(app_config: AppConfig) -> Iterator[Checkpointer]:
"""Sync context manager that yields a checkpointer and cleans up on exit.
Unlike :func:`get_checkpointer`, this does **not** cache the instance;
each ``with`` block creates and destroys its own connection. Use it in
CLI scripts or tests where you want deterministic cleanup::
with checkpointer_context() as cp:
with checkpointer_context(app_config) as cp:
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = AppConfig.current()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
with _sync_checkpointer_cm(config.checkpointer) as saver:
with _sync_checkpointer_cm(app_config.checkpointer) as saver:
yield saver
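The sync context-manager shape above can be exercised in isolation. This sketch swaps the real savers for stubs (a string fallback when `checkpointer` is None, an open/close resource otherwise) to show the deterministic-cleanup contract:

```python
import contextlib
from dataclasses import dataclass
from typing import Any, Iterator


@dataclass
class AppConfig:
    checkpointer: Any = None  # None means "no checkpointer section in config.yaml"


class StubSaver:
    def __init__(self) -> None:
        self.closed = False


@contextlib.contextmanager
def checkpointer_context(app_config: AppConfig) -> Iterator[Any]:
    if app_config.checkpointer is None:
        # Explicitly-absent config is the only path to the in-memory fallback.
        yield "in-memory"
        return
    saver = StubSaver()
    try:
        yield saver
    finally:
        saver.closed = True  # connection torn down on block exit


with checkpointer_context(AppConfig()) as cp:
    fallback = cp

with checkpointer_context(AppConfig(checkpointer={"backend": "sqlite"})) as cp:
    opened = cp
```

Each `with` block owns its connection for exactly the block's lifetime, which is the property the docstring promises for CLI scripts and tests.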

View File

@ -169,8 +169,10 @@ async def run_agent(
# Construct typed context for the agent run.
# LangGraph's astream(context=...) injects this into Runtime.context
# so middleware/tools can access it via resolve_context().
if ctx.app_config is None:
raise RuntimeError("RunContext.app_config is required — Gateway must populate it via get_run_context")
deer_flow_context = DeerFlowContext(
app_config=ctx.app_config if ctx.app_config is not None else AppConfig.current(),
app_config=ctx.app_config,
thread_id=thread_id,
)

View File

@ -86,7 +86,7 @@ async def _async_store(config) -> AsyncIterator[BaseStore]:
@contextlib.asynccontextmanager
async def make_store() -> AsyncIterator[BaseStore]:
async def make_store(app_config: AppConfig) -> AsyncIterator[BaseStore]:
"""Async context manager that yields a Store whose backend matches the
configured checkpointer.
@ -94,20 +94,18 @@ async def make_store() -> AsyncIterator[BaseStore]:
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
that both singletons always use the same persistence technology::
async with make_store() as store:
async with make_store(app_config) as store:
app.state.store = store
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
``checkpointer`` section is configured (emits a WARNING in that case).
"""
config = AppConfig.current()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
return
async with _async_store(config.checkpointer) as store:
async with _async_store(app_config.checkpointer) as store:
yield store

View File

@ -100,7 +100,7 @@ _store: BaseStore | None = None
_store_ctx = None # open context manager keeping the connection alive
def get_store() -> BaseStore:
def get_store(app_config: AppConfig) -> BaseStore:
"""Return the global sync Store singleton, creating it on first call.
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
@ -115,10 +115,10 @@ def get_store() -> BaseStore:
if _store is not None:
return _store
try:
config = AppConfig.current().checkpointer
except (LookupError, FileNotFoundError):
config = None
# See matching comment in checkpointer/provider.py: a missing config.yaml
# is a deployment error, not a cue to silently pick InMemoryStore. Only
# the explicit "no checkpointer section" path falls through to memory.
config = app_config.checkpointer
if config is None:
from langgraph.store.memory import InMemoryStore
@ -154,26 +154,25 @@ def reset_store() -> None:
@contextlib.contextmanager
def store_context() -> Iterator[BaseStore]:
def store_context(app_config: AppConfig) -> Iterator[BaseStore]:
"""Sync context manager that yields a Store and cleans up on exit.
Unlike :func:`get_store`, this does **not** cache the instance; each
``with`` block creates and destroys its own connection. Use it in CLI
scripts or tests where you want deterministic cleanup::
with store_context() as store:
with store_context(app_config) as store:
store.put(("threads",), thread_id, {...})
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*.
"""
config = AppConfig.current()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
return
with _sync_store_cm(config.checkpointer) as store:
with _sync_store_cm(app_config.checkpointer) as store:
yield store

View File

@ -25,14 +25,13 @@ logger = logging.getLogger(__name__)
@contextlib.asynccontextmanager
async def make_stream_bridge(config=None) -> AsyncIterator[StreamBridge]:
async def make_stream_bridge(app_config: AppConfig) -> AsyncIterator[StreamBridge]:
"""Async context manager that yields a :class:`StreamBridge`.
Falls back to :class:`MemoryStreamBridge` when no configuration is
provided and nothing is set globally.
Falls back to :class:`MemoryStreamBridge` when no ``stream_bridge``
section is configured.
"""
if config is None:
config = AppConfig.current().stream_bridge
config = app_config.stream_bridge
if config is None or config.type == "memory":
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge

View File

@ -1,18 +1,23 @@
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_singleton: LocalSandbox | None = None
class LocalSandboxProvider(SandboxProvider):
def __init__(self):
def __init__(self, app_config: "AppConfig"):
"""Initialize the local sandbox provider with path mappings."""
self._app_config = app_config
self._path_mappings = self._setup_path_mappings()
def _setup_path_mappings(self) -> list[PathMapping]:
@ -29,9 +34,7 @@ class LocalSandboxProvider(SandboxProvider):
# Map skills container path to local skills directory
try:
from deerflow.config.app_config import AppConfig
config = AppConfig.current()
config = self._app_config
skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path

View File

@ -43,8 +43,8 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
super().__init__()
self._lazy_init = lazy_init
def _acquire_sandbox(self, thread_id: str) -> str:
provider = get_sandbox_provider()
def _acquire_sandbox(self, thread_id: str, runtime: Runtime[DeerFlowContext]) -> str:
provider = get_sandbox_provider(runtime.context.app_config)
sandbox_id = provider.acquire(thread_id)
logger.info(f"Acquiring sandbox {sandbox_id}")
return sandbox_id
@ -60,7 +60,7 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
thread_id = runtime.context.thread_id
if not thread_id:
return super().before_agent(state, runtime)
sandbox_id = self._acquire_sandbox(thread_id)
sandbox_id = self._acquire_sandbox(thread_id, runtime)
logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
return {"sandbox": {"sandbox_id": sandbox_id}}
return super().before_agent(state, runtime)
@ -71,7 +71,7 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
if sandbox is not None:
sandbox_id = sandbox["sandbox_id"]
logger.info(f"Releasing sandbox {sandbox_id}")
get_sandbox_provider().release(sandbox_id)
get_sandbox_provider(runtime.context.app_config).release(sandbox_id)
return None
# No sandbox to release

View File

@ -39,23 +39,38 @@ class SandboxProvider(ABC):
_default_sandbox_provider: SandboxProvider | None = None
def get_sandbox_provider(**kwargs) -> SandboxProvider:
def get_sandbox_provider(app_config: AppConfig, **kwargs) -> SandboxProvider:
"""Get the sandbox provider singleton.
Returns a cached singleton instance. Use `reset_sandbox_provider()` to clear
the cache, or `shutdown_sandbox_provider()` to shut down properly and clear.
Args:
app_config: Application config used the first time the singleton is built.
Ignored on subsequent calls; the cached instance is returned
regardless of the config passed.
Returns:
A sandbox provider instance.
"""
global _default_sandbox_provider
if _default_sandbox_provider is None:
config = AppConfig.current()
cls = resolve_class(config.sandbox.use, SandboxProvider)
_default_sandbox_provider = cls(**kwargs)
cls = resolve_class(app_config.sandbox.use, SandboxProvider)
_default_sandbox_provider = cls(app_config=app_config, **kwargs) if _accepts_app_config(cls) else cls(**kwargs)
return _default_sandbox_provider
def _accepts_app_config(cls: type) -> bool:
"""Return True when the provider's __init__ accepts an ``app_config`` kwarg."""
import inspect
try:
sig = inspect.signature(cls.__init__)
except (TypeError, ValueError):
return False
return "app_config" in sig.parameters
def reset_sandbox_provider() -> None:
"""Reset the sandbox provider singleton.

View File

@ -20,11 +20,8 @@ LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE = (
)
def uses_local_sandbox_provider(config=None) -> bool:
def uses_local_sandbox_provider(config: AppConfig) -> bool:
"""Return True when the active sandbox provider is the host-local provider."""
if config is None:
config = AppConfig.current()
sandbox_cfg = getattr(config, "sandbox", None)
sandbox_use = getattr(sandbox_cfg, "use", "")
if sandbox_use in _LOCAL_SANDBOX_PROVIDER_MARKERS:
@ -32,11 +29,8 @@ def uses_local_sandbox_provider(config=None) -> bool:
return sandbox_use.endswith(":LocalSandboxProvider") and "deerflow.sandbox.local" in sandbox_use
def is_host_bash_allowed(config=None) -> bool:
def is_host_bash_allowed(config: AppConfig) -> bool:
"""Return whether host bash execution is explicitly allowed."""
if config is None:
config = AppConfig.current()
sandbox_cfg = getattr(config, "sandbox", None)
if sandbox_cfg is None:
return True
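The detection logic for the local provider can be sketched on the `sandbox.use` string alone; the marker set below is hypothetical (the real `_LOCAL_SANDBOX_PROVIDER_MARKERS` is not shown in this hunk), but the suffix-plus-module check matches the code above:

```python
# Hypothetical marker set; the real _LOCAL_SANDBOX_PROVIDER_MARKERS may differ.
_MARKERS = {"local"}


def uses_local_sandbox_provider(sandbox_use: str) -> bool:
    if sandbox_use in _MARKERS:
        return True
    # Fall back to matching a fully qualified class path under deerflow.sandbox.local.
    return (
        sandbox_use.endswith(":LocalSandboxProvider")
        and "deerflow.sandbox.local" in sandbox_use
    )


assert uses_local_sandbox_provider(
    "deerflow.sandbox.local.local_sandbox_provider:LocalSandboxProvider"
)
assert not uses_local_sandbox_provider("some.pkg:AioSandboxProvider")
```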

View File

@ -40,58 +40,43 @@ _DEFAULT_GREP_MAX_RESULTS = 100
_MAX_GREP_MAX_RESULTS = 500
def _get_skills_container_path() -> str:
"""Get the skills container path from config, with fallback to default.
Result is cached after the first successful config load. If config loading
fails the default is returned *without* caching so that a later call can
pick up the real value once the config is available.
"""
cached = getattr(_get_skills_container_path, "_cached", None)
if cached is not None:
return cached
try:
value = AppConfig.current().skills.container_path
_get_skills_container_path._cached = value # type: ignore[attr-defined]
return value
except Exception:
def _get_skills_container_path(app_config: AppConfig) -> str:
"""Get the skills container path from config, with fallback to default."""
skills_cfg = getattr(app_config, "skills", None)
if skills_cfg is None:
return _DEFAULT_SKILLS_CONTAINER_PATH
return skills_cfg.container_path
def _get_skills_host_path() -> str | None:
def _get_skills_host_path(app_config: AppConfig) -> str | None:
"""Get the skills host filesystem path from config.
Returns None if the skills directory does not exist or config cannot be
loaded. Only successful lookups are cached; failures are retried on the
next call so that a transiently unavailable skills directory does not
permanently disable skills access.
Returns None if the skills directory does not exist or is not configured.
"""
cached = getattr(_get_skills_host_path, "_cached", None)
if cached is not None:
return cached
skills_cfg = getattr(app_config, "skills", None)
if skills_cfg is None:
return None
try:
config = AppConfig.current()
skills_path = config.skills.get_skills_path()
if skills_path.exists():
value = str(skills_path)
_get_skills_host_path._cached = value # type: ignore[attr-defined]
return value
skills_path = skills_cfg.get_skills_path()
except Exception:
pass
return None
if skills_path.exists():
return str(skills_path)
return None
def _is_skills_path(path: str) -> bool:
def _is_skills_path(path: str, app_config: AppConfig) -> bool:
"""Check if a path is under the skills container path."""
skills_prefix = _get_skills_container_path()
skills_prefix = _get_skills_container_path(app_config)
return path == skills_prefix or path.startswith(f"{skills_prefix}/")
def _resolve_skills_path(path: str) -> str:
def _resolve_skills_path(path: str, app_config: AppConfig) -> str:
"""Resolve a virtual skills path to a host filesystem path.
Args:
path: Virtual skills path (e.g. /mnt/skills/public/bootstrap/SKILL.md)
app_config: Resolved application config.
Returns:
Resolved host path.
@ -99,8 +84,8 @@ def _resolve_skills_path(path: str) -> str:
Raises:
FileNotFoundError: If skills directory is not configured or doesn't exist.
"""
skills_container = _get_skills_container_path()
skills_host = _get_skills_host_path()
skills_container = _get_skills_container_path(app_config)
skills_host = _get_skills_host_path(app_config)
if skills_host is None:
raise FileNotFoundError(f"Skills directory not available for path: {path}")
@ -116,46 +101,31 @@ def _is_acp_workspace_path(path: str) -> bool:
return path == _ACP_WORKSPACE_VIRTUAL_PATH or path.startswith(f"{_ACP_WORKSPACE_VIRTUAL_PATH}/")
def _get_custom_mounts():
def _get_custom_mounts(app_config: AppConfig):
"""Get custom volume mounts from sandbox config.
Result is cached after the first successful config load. If config loading
fails an empty list is returned *without* caching so that a later call can
pick up the real value once the config is available.
Only includes mounts whose host_path exists, consistent with
``LocalSandboxProvider._setup_path_mappings()`` which also filters by
``host_path.exists()``.
"""
cached = getattr(_get_custom_mounts, "_cached", None)
if cached is not None:
return cached
try:
from pathlib import Path
config = AppConfig.current()
mounts = []
if config.sandbox and config.sandbox.mounts:
# Only include mounts whose host_path exists, consistent with
# LocalSandboxProvider._setup_path_mappings() which also filters
# by host_path.exists().
mounts = [m for m in config.sandbox.mounts if Path(m.host_path).exists()]
_get_custom_mounts._cached = mounts # type: ignore[attr-defined]
return mounts
except Exception:
# If config loading fails, return an empty list without caching so that
# a later call can retry once the config is available.
sandbox_cfg = getattr(app_config, "sandbox", None)
if sandbox_cfg is None or not sandbox_cfg.mounts:
return []
return [m for m in sandbox_cfg.mounts if Path(m.host_path).exists()]
def _is_custom_mount_path(path: str) -> bool:
def _is_custom_mount_path(path: str, app_config: AppConfig) -> bool:
"""Check if path is under a custom mount container_path."""
for mount in _get_custom_mounts():
for mount in _get_custom_mounts(app_config):
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
return True
return False
def _get_custom_mount_for_path(path: str):
def _get_custom_mount_for_path(path: str, app_config: AppConfig):
"""Get the mount config matching this path (longest prefix first)."""
best = None
for mount in _get_custom_mounts():
for mount in _get_custom_mounts(app_config):
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
if best is None or len(mount.container_path) > len(best.container_path):
best = mount
@ -266,42 +236,40 @@ def _resolve_acp_workspace_path(path: str, thread_id: str | None = None) -> str:
return str(resolved_path)
def _get_mcp_allowed_paths() -> list[str]:
def _get_mcp_allowed_paths(app_config: AppConfig) -> list[str]:
"""Get the list of allowed paths from MCP config for file system server."""
allowed_paths = []
try:
extensions_config = AppConfig.current().extensions
allowed_paths: list[str] = []
extensions_config = getattr(app_config, "extensions", None)
if extensions_config is None:
return allowed_paths
for _, server in extensions_config.mcp_servers.items():
if not server.enabled:
continue
for _, server in extensions_config.mcp_servers.items():
if not server.enabled:
continue
# Only check the filesystem server
args = server.args or []
# Check if args has server-filesystem package
has_filesystem = any("server-filesystem" in arg for arg in args)
if not has_filesystem:
continue
# Unpack the allowed file system paths in config
for arg in args:
if not arg.startswith("-") and arg.startswith("/"):
allowed_paths.append(arg.rstrip("/") + "/")
except Exception:
pass
# Only check the filesystem server
args = server.args or []
# Check if args has server-filesystem package
has_filesystem = any("server-filesystem" in arg for arg in args)
if not has_filesystem:
continue
# Unpack the allowed file system paths in config
for arg in args:
if not arg.startswith("-") and arg.startswith("/"):
allowed_paths.append(arg.rstrip("/") + "/")
return allowed_paths
def _get_tool_config_int(name: str, key: str, default: int) -> int:
def _get_tool_config_int(app_config: AppConfig, name: str, key: str, default: int) -> int:
try:
tool_config = AppConfig.current().get_tool_config(name)
if tool_config is not None and key in tool_config.model_extra:
value = tool_config.model_extra.get(key)
if isinstance(value, int):
return value
tool_config = app_config.get_tool_config(name)
except Exception:
pass
return default
if tool_config is not None and key in tool_config.model_extra:
value = tool_config.model_extra.get(key)
if isinstance(value, int):
return value
return default
@ -311,23 +279,23 @@ def _clamp_max_results(value: int, *, default: int, upper_bound: int) -> int:
return min(value, upper_bound)
def _resolve_max_results(name: str, requested: int, *, default: int, upper_bound: int) -> int:
def _resolve_max_results(app_config: AppConfig, name: str, requested: int, *, default: int, upper_bound: int) -> int:
requested_max_results = _clamp_max_results(requested, default=default, upper_bound=upper_bound)
configured_max_results = _clamp_max_results(
_get_tool_config_int(name, "max_results", default),
_get_tool_config_int(app_config, name, "max_results", default),
default=default,
upper_bound=upper_bound,
)
return min(requested_max_results, configured_max_results)
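The request/config reconciliation above reduces to "whichever limit is stricter wins"; a sketch assuming `_clamp_max_results` substitutes the default for non-positive values before clamping (the hunk shows only its final line, so that branch is an assumption):

```python
def clamp_max_results(value: int, *, default: int, upper_bound: int) -> int:
    # Assumed behavior: non-positive requests fall back to the default.
    if value <= 0:
        return default
    return min(value, upper_bound)


def resolve_max_results(
    configured: int, requested: int, *, default: int, upper_bound: int
) -> int:
    requested_max = clamp_max_results(requested, default=default, upper_bound=upper_bound)
    configured_max = clamp_max_results(configured, default=default, upper_bound=upper_bound)
    # The stricter of the operator's config and the caller's request wins.
    return min(requested_max, configured_max)


# Operator config (50) is stricter than the caller's request (500).
assert resolve_max_results(50, 500, default=200, upper_bound=500) == 50
# A non-positive request falls back to the default before comparison.
assert resolve_max_results(300, 0, default=200, upper_bound=500) == 200
```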
def _resolve_local_read_path(path: str, thread_data: ThreadDataState) -> str:
validate_local_tool_path(path, thread_data, read_only=True)
if _is_skills_path(path):
return _resolve_skills_path(path)
def _resolve_local_read_path(path: str, thread_data: ThreadDataState, app_config: AppConfig) -> str:
validate_local_tool_path(path, thread_data, app_config, read_only=True)
if _is_skills_path(path, app_config):
return _resolve_skills_path(path, app_config)
if _is_acp_workspace_path(path):
return _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
return _resolve_and_validate_user_data_path(path, thread_data)
return _resolve_and_validate_user_data_path(path, thread_data, app_config)
def _format_glob_results(root_path: str, matches: list[str], truncated: bool) -> str:
@ -373,7 +341,11 @@ def _join_path_preserving_style(base: str, relative: str) -> str:
return f"{stripped_base}{separator}{normalized_relative}"
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
def _sanitize_error(
error: Exception,
runtime: "ToolRuntime[ContextT, ThreadState] | None" = None,
app_config: AppConfig | None = None,
) -> str:
"""Sanitize an error message to avoid leaking host filesystem paths.
In local-sandbox mode, resolved host paths in the error string are masked
@ -382,8 +354,12 @@ def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadStat
"""
msg = f"{type(error).__name__}: {error}"
if runtime is not None and is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
msg = mask_local_paths_in_output(msg, thread_data)
if app_config is None:
ctx = getattr(runtime, "context", None)
app_config = getattr(ctx, "app_config", None)
if app_config is not None:
thread_data = get_thread_data(runtime)
msg = mask_local_paths_in_output(msg, thread_data, app_config)
return msg
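The two-step `getattr` fallback in `_sanitize_error` degrades safely when the runtime carries no typed context; a standalone sketch with hypothetical stub classes:

```python
class CtxStub:
    # Hypothetical stand-in for DeerFlowContext.
    def __init__(self, app_config):
        self.app_config = app_config


class RuntimeStub:
    def __init__(self, context):
        self.context = context


def resolve_app_config(runtime, app_config=None):
    # Explicit argument wins; otherwise walk runtime.context.app_config,
    # tolerating a missing context or a context without app_config.
    if app_config is None:
        ctx = getattr(runtime, "context", None)
        app_config = getattr(ctx, "app_config", None)
    return app_config


assert resolve_app_config(RuntimeStub(CtxStub("cfg"))) == "cfg"
assert resolve_app_config(object()) is None
assert resolve_app_config(None, app_config="explicit") == "explicit"
```

Masking is simply skipped when no config can be resolved, so error reporting never fails on account of the sanitizer itself.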
@ -453,7 +429,7 @@ def _thread_actual_to_virtual_mappings(thread_data: ThreadDataState) -> dict[str
return {actual: virtual for virtual, actual in _thread_virtual_to_actual_mappings(thread_data).items()}
def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None) -> str:
def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> str:
"""Mask host absolute paths from local sandbox output using virtual paths.
Handles user-data paths (per-thread), skills paths, and ACP workspace paths (global).
@ -461,8 +437,8 @@ def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None)
result = output
# Mask skills host paths
skills_host = _get_skills_host_path()
skills_container = _get_skills_container_path()
skills_host = _get_skills_host_path(app_config)
skills_container = _get_skills_container_path(app_config)
if skills_host:
raw_base = str(Path(skills_host))
resolved_base = str(Path(skills_host).resolve())
@ -536,7 +512,13 @@ def _reject_path_traversal(path: str) -> None:
raise PermissionError("Access denied: path traversal detected")
def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *, read_only: bool = False) -> None:
def validate_local_tool_path(
path: str,
thread_data: ThreadDataState | None,
app_config: AppConfig,
*,
read_only: bool = False,
) -> None:
"""Validate that a virtual path is allowed for local-sandbox access.
This function is a security gate; it checks whether *path* may be
@ -565,7 +547,7 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
_reject_path_traversal(path)
# Skills paths — read-only access only
if _is_skills_path(path):
if _is_skills_path(path, app_config):
if not read_only:
raise PermissionError(f"Write access to skills path is not allowed: {path}")
return
@ -581,13 +563,13 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
return
# Custom mount paths — respect read_only config
if _is_custom_mount_path(path):
mount = _get_custom_mount_for_path(path)
if _is_custom_mount_path(path, app_config):
mount = _get_custom_mount_for_path(path, app_config)
if mount and mount.read_only and not read_only:
raise PermissionError(f"Write access to read-only mount is not allowed: {path}")
return
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path(app_config)}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")
def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataState) -> None:
@ -618,18 +600,23 @@ def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataSta
raise PermissionError("Access denied: path traversal detected")
def _resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState) -> str:
def _resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState, app_config: AppConfig) -> str:
"""Resolve a /mnt/user-data virtual path and validate it stays in bounds.
Returns the resolved host path string.
``app_config`` is accepted for signature symmetry with the other resolver
helpers; the user-data resolution path itself is fully derivable from
``thread_data``.
"""
_ = app_config # noqa: F841 — kept for interface symmetry with sibling resolvers
resolved_str = replace_virtual_path(path, thread_data)
resolved = Path(resolved_str).resolve()
_validate_resolved_user_data_path(resolved, thread_data)
return str(resolved)
def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState | None) -> None:
def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> None:
"""Validate absolute paths in local-sandbox bash commands.
This validation is only a best-effort guard for the explicit
@ -653,7 +640,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
raise PermissionError(f"Unsafe file:// URL in command: {file_url_match.group()}. Use paths under {VIRTUAL_PATH_PREFIX}")
unsafe_paths: list[str] = []
allowed_paths = _get_mcp_allowed_paths()
allowed_paths = _get_mcp_allowed_paths(app_config)
for absolute_path in _ABSOLUTE_PATH_PATTERN.findall(command):
# Check for MCP filesystem server allowed paths
@ -666,7 +653,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
continue
# Allow skills container path (resolved by tools.py before passing to sandbox)
if _is_skills_path(absolute_path):
if _is_skills_path(absolute_path, app_config):
_reject_path_traversal(absolute_path)
continue
@ -676,7 +663,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
continue
# Allow custom mount container paths
if _is_custom_mount_path(absolute_path):
if _is_custom_mount_path(absolute_path, app_config):
_reject_path_traversal(absolute_path)
continue
@ -690,12 +677,13 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
raise PermissionError(f"Unsafe absolute paths in command: {unsafe}. Use paths under {VIRTUAL_PATH_PREFIX}")
def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState | None) -> str:
def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> str:
"""Replace all virtual paths (/mnt/user-data, /mnt/skills, /mnt/acp-workspace) in a command string.
Args:
command: The command string that may contain virtual paths.
thread_data: The thread data containing actual paths.
app_config: Resolved application config.
Returns:
The command with all virtual paths replaced.
@ -703,13 +691,13 @@ def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState
result = command
# Replace skills paths
skills_container = _get_skills_container_path()
skills_host = _get_skills_host_path()
skills_container = _get_skills_container_path(app_config)
skills_host = _get_skills_host_path(app_config)
if skills_host and skills_container in result:
skills_pattern = re.compile(rf"{re.escape(skills_container)}(/[^\s\"';&|<>()]*)?")
def replace_skills_match(match: re.Match) -> str:
return _resolve_skills_path(match.group(0))
return _resolve_skills_path(match.group(0), app_config)
result = skills_pattern.sub(replace_skills_match, result)
@ -799,7 +787,7 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
sandbox_id = sandbox_state.get("sandbox_id")
if sandbox_id is None:
raise SandboxRuntimeError("Sandbox ID not found in state")
sandbox = get_sandbox_provider().get(sandbox_id)
sandbox = get_sandbox_provider(resolve_context(runtime).app_config).get(sandbox_id)
if sandbox is None:
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
@ -830,12 +818,14 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if runtime.state is None:
raise SandboxRuntimeError("Tool runtime state not available")
app_config = runtime.context.app_config
# Check if sandbox already exists in state
sandbox_state = runtime.state.get("sandbox")
if sandbox_state is not None:
sandbox_id = sandbox_state.get("sandbox_id")
if sandbox_id is not None:
sandbox = get_sandbox_provider().get(sandbox_id)
sandbox = get_sandbox_provider(app_config).get(sandbox_id)
if sandbox is not None:
return sandbox
# Sandbox was released, fall through to acquire new one
@ -845,7 +835,7 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if not thread_id:
raise SandboxRuntimeError("Thread ID not available in runtime context")
provider = get_sandbox_provider()
provider = get_sandbox_provider(app_config)
sandbox_id = provider.acquire(thread_id)
# Update runtime state - this persists across tool calls
@ -985,9 +975,9 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
description: Explain why you are running this command in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
command: The bash command to execute. Always use absolute paths for files and directories.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
app_config = resolve_context(runtime).app_config
sandbox_cfg = app_config.sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
if is_local_sandbox(runtime):
@ -995,11 +985,11 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}"
ensure_thread_directories_exist(runtime)
thread_data = get_thread_data(runtime)
validate_local_bash_command_paths(command, thread_data)
command = replace_virtual_paths_in_command(command, thread_data)
validate_local_bash_command_paths(command, thread_data, app_config)
command = replace_virtual_paths_in_command(command, thread_data, app_config)
command = _apply_cwd_prefix(command, thread_data)
output = sandbox.execute_command(command)
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data, app_config), max_chars)
ensure_thread_directories_exist(runtime)
return _truncate_bash_output(sandbox.execute_command(command), max_chars)
except SandboxError as e:
@ -1007,7 +997,7 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
except PermissionError as e:
return f"Error: {e}"
except Exception as e:
return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime, app_config)}"
@tool("ls", parse_docstring=True)
@ -1018,25 +1008,26 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
description: Explain why you are listing this directory in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
path: The **absolute** path to the directory to list.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data, read_only=True)
if _is_skills_path(path):
path = _resolve_skills_path(path)
validate_local_tool_path(path, thread_data, app_config, read_only=True)
if _is_skills_path(path, app_config):
path = _resolve_skills_path(path, app_config)
elif _is_acp_workspace_path(path):
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
elif not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
elif not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
children = sandbox.list_dir(path)
if not children:
return "(empty)"
output = "\n".join(children)
sandbox_cfg = resolve_context(runtime).app_config.sandbox
sandbox_cfg = app_config.sandbox
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
return _truncate_ls_output(output, max_chars)
except SandboxError as e:
@ -1046,7 +1037,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
except PermissionError:
return f"Error: Permission denied: {requested_path}"
except Exception as e:
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime, app_config)}"
@tool("glob", parse_docstring=True)
@ -1067,11 +1058,13 @@ def glob_tool(
include_dirs: Whether matching directories should also be returned. Default is False.
max_results: Maximum number of paths to return. Default is 200.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
effective_max_results = _resolve_max_results(
app_config,
"glob",
max_results,
default=_DEFAULT_GLOB_MAX_RESULTS,
@ -1082,10 +1075,10 @@ def glob_tool(
thread_data = get_thread_data(runtime)
if thread_data is None:
raise SandboxRuntimeError("Thread data not available for local sandbox")
path = _resolve_local_read_path(path, thread_data)
path = _resolve_local_read_path(path, thread_data, app_config)
matches, truncated = sandbox.glob(path, pattern, include_dirs=include_dirs, max_results=effective_max_results)
if thread_data is not None:
matches = [mask_local_paths_in_output(match, thread_data) for match in matches]
matches = [mask_local_paths_in_output(match, thread_data, app_config) for match in matches]
return _format_glob_results(requested_path, matches, truncated)
except SandboxError as e:
return f"Error: {e}"
@ -1096,7 +1089,7 @@ def glob_tool(
except PermissionError:
return f"Error: Permission denied: {requested_path}"
except Exception as e:
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime, app_config)}"
@tool("grep", parse_docstring=True)
@ -1121,11 +1114,13 @@ def grep_tool(
case_sensitive: Whether matching is case-sensitive. Default is False.
max_results: Maximum number of matching lines to return. Default is 100.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
effective_max_results = _resolve_max_results(
app_config,
"grep",
max_results,
default=_DEFAULT_GREP_MAX_RESULTS,
@ -1136,7 +1131,7 @@ def grep_tool(
thread_data = get_thread_data(runtime)
if thread_data is None:
raise SandboxRuntimeError("Thread data not available for local sandbox")
path = _resolve_local_read_path(path, thread_data)
path = _resolve_local_read_path(path, thread_data, app_config)
matches, truncated = sandbox.grep(
path,
pattern,
@ -1148,7 +1143,7 @@ def grep_tool(
if thread_data is not None:
matches = [
GrepMatch(
path=mask_local_paths_in_output(match.path, thread_data),
path=mask_local_paths_in_output(match.path, thread_data, app_config),
line_number=match.line_number,
line=match.line,
)
@ -1166,7 +1161,7 @@ def grep_tool(
except PermissionError:
return f"Error: Permission denied: {requested_path}"
except Exception as e:
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime, app_config)}"
@tool("read_file", parse_docstring=True)
@ -1185,26 +1180,27 @@ def read_file_tool(
start_line: Optional starting line number (1-indexed, inclusive). Use with end_line to read a specific range.
end_line: Optional ending line number (1-indexed, inclusive). Use with start_line to read a specific range.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data, read_only=True)
if _is_skills_path(path):
path = _resolve_skills_path(path)
validate_local_tool_path(path, thread_data, app_config, read_only=True)
if _is_skills_path(path, app_config):
path = _resolve_skills_path(path, app_config)
elif _is_acp_workspace_path(path):
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
elif not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
elif not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
content = sandbox.read_file(path)
if not content:
return "(empty)"
if start_line is not None and end_line is not None:
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
sandbox_cfg = resolve_context(runtime).app_config.sandbox
sandbox_cfg = app_config.sandbox
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
return _truncate_read_file_output(content, max_chars)
except SandboxError as e:
@ -1216,7 +1212,7 @@ def read_file_tool(
except IsADirectoryError:
return f"Error: Path is a directory, not a file: {requested_path}"
except Exception as e:
return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime, app_config)}"
@tool("write_file", parse_docstring=True)
@ -1234,15 +1230,16 @@ def write_file_tool(
path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND.
content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data)
if not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
validate_local_tool_path(path, thread_data, app_config)
if not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
with get_file_operation_lock(sandbox, path):
sandbox.write_file(path, content, append)
@ -1254,9 +1251,9 @@ def write_file_tool(
except IsADirectoryError:
return f"Error: Path is a directory, not a file: {requested_path}"
except OSError as e:
return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime)}"
return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime, app_config)}"
except Exception as e:
return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime, app_config)}"
@tool("str_replace", parse_docstring=True)
@ -1278,15 +1275,16 @@ def str_replace_tool(
new_str: The new substring. ALWAYS PROVIDE THIS PARAMETER FOURTH.
replace_all: Whether to replace all occurrences of the substring. If False, only the first occurrence will be replaced. Default is False.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data)
if not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
validate_local_tool_path(path, thread_data, app_config)
if not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
with get_file_operation_lock(sandbox, path):
content = sandbox.read_file(path)
@ -1307,4 +1305,4 @@ def str_replace_tool(
except PermissionError:
return f"Error: Permission denied accessing file: {requested_path}"
except Exception as e:
return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime, app_config)}"
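The per-tool pattern in the hunks above can be sketched end to end with stand-in types. All names below are illustrative stand-ins, not the real deerflow classes: resolve the typed context once at the top of the tool body, then thread `app_config` explicitly into every helper.

```python
from dataclasses import dataclass

# Illustrative stand-ins for AppConfig / DeerFlowContext / ToolRuntime.
@dataclass
class AppConfig:
    read_file_output_max_chars: int = 50000

@dataclass
class DeerFlowContext:
    app_config: AppConfig

@dataclass
class Runtime:
    context: object

def resolve_context(runtime: Runtime) -> DeerFlowContext:
    # Phase 2 behaviour: reject dict/None context shapes loudly instead
    # of papering over them with an ambient AppConfig lookup.
    if not isinstance(runtime.context, DeerFlowContext):
        raise RuntimeError("runtime.context must be a DeerFlowContext")
    return runtime.context

def _truncate_read_file_output(content: str, max_chars: int) -> str:
    return content if len(content) <= max_chars else content[:max_chars]

def read_file_tool(runtime: Runtime, content: str) -> str:
    # Capture app_config once, then pass it to every helper explicitly.
    app_config = resolve_context(runtime).app_config
    return _truncate_read_file_output(content, app_config.read_file_output_max_chars)
```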

View File

@ -1,10 +1,14 @@
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING
from .parser import parse_skill_file
from .types import Skill
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@ -22,7 +26,12 @@ def get_skills_root_path() -> Path:
return skills_dir
def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]:
def load_skills(
app_config: "AppConfig | None" = None,
*,
skills_path: Path | None = None,
enabled_only: bool = False,
) -> list[Skill]:
"""
Load all skills from the skills directory.
@ -30,25 +39,19 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable
to extract metadata. The enabled state is determined by the skills_state_config.json file.
Args:
skills_path: Optional custom path to skills directory.
If not provided and use_config is True, uses path from config.
Otherwise defaults to deer-flow/skills
use_config: Whether to load skills path from config (default: True)
app_config: Application config used to resolve the configured skills
directory. Ignored when ``skills_path`` is supplied.
skills_path: Explicit override for the skills directory. When both
``skills_path`` and ``app_config`` are omitted the
default repository layout is used (``deer-flow/skills``).
enabled_only: If True, only return enabled skills (default: False)
Returns:
List of Skill objects, sorted by name
"""
if skills_path is None:
if use_config:
try:
from deerflow.config.app_config import AppConfig
config = AppConfig.current()
skills_path = config.skills.get_skills_path()
except Exception:
# Fallback to default if config fails
skills_path = get_skills_root_path()
if app_config is not None:
skills_path = app_config.skills.get_skills_path()
else:
skills_path = get_skills_root_path()

View File

@ -20,16 +20,17 @@ ALLOWED_SUPPORT_SUBDIRS = {"references", "templates", "scripts", "assets"}
_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
def get_skills_root_dir() -> Path:
return AppConfig.current().skills.get_skills_path()
def get_skills_root_dir(app_config: AppConfig) -> Path:
"""Return the configured skills root."""
return app_config.skills.get_skills_path()
def get_public_skills_dir() -> Path:
return get_skills_root_dir() / "public"
def get_public_skills_dir(app_config: AppConfig) -> Path:
return get_skills_root_dir(app_config) / "public"
def get_custom_skills_dir() -> Path:
path = get_skills_root_dir() / "custom"
def get_custom_skills_dir(app_config: AppConfig) -> Path:
path = get_skills_root_dir(app_config) / "custom"
path.mkdir(parents=True, exist_ok=True)
return path
@ -43,46 +44,46 @@ def validate_skill_name(name: str) -> str:
return normalized
def get_custom_skill_dir(name: str) -> Path:
return get_custom_skills_dir() / validate_skill_name(name)
def get_custom_skill_dir(name: str, app_config: AppConfig) -> Path:
return get_custom_skills_dir(app_config) / validate_skill_name(name)
def get_custom_skill_file(name: str) -> Path:
return get_custom_skill_dir(name) / SKILL_FILE_NAME
def get_custom_skill_file(name: str, app_config: AppConfig) -> Path:
return get_custom_skill_dir(name, app_config) / SKILL_FILE_NAME
def get_custom_skill_history_dir() -> Path:
path = get_custom_skills_dir() / HISTORY_DIR_NAME
def get_custom_skill_history_dir(app_config: AppConfig) -> Path:
path = get_custom_skills_dir(app_config) / HISTORY_DIR_NAME
path.mkdir(parents=True, exist_ok=True)
return path
def get_skill_history_file(name: str) -> Path:
return get_custom_skill_history_dir() / f"{validate_skill_name(name)}.jsonl"
def get_skill_history_file(name: str, app_config: AppConfig) -> Path:
return get_custom_skill_history_dir(app_config) / f"{validate_skill_name(name)}.jsonl"
def get_public_skill_dir(name: str) -> Path:
return get_public_skills_dir() / validate_skill_name(name)
def get_public_skill_dir(name: str, app_config: AppConfig) -> Path:
return get_public_skills_dir(app_config) / validate_skill_name(name)
def custom_skill_exists(name: str) -> bool:
return get_custom_skill_file(name).exists()
def custom_skill_exists(name: str, app_config: AppConfig) -> bool:
return get_custom_skill_file(name, app_config).exists()
def public_skill_exists(name: str) -> bool:
return (get_public_skill_dir(name) / SKILL_FILE_NAME).exists()
def public_skill_exists(name: str, app_config: AppConfig) -> bool:
return (get_public_skill_dir(name, app_config) / SKILL_FILE_NAME).exists()
def ensure_custom_skill_is_editable(name: str) -> None:
if custom_skill_exists(name):
def ensure_custom_skill_is_editable(name: str, app_config: AppConfig) -> None:
if custom_skill_exists(name, app_config):
return
if public_skill_exists(name):
if public_skill_exists(name, app_config):
raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
raise FileNotFoundError(f"Custom skill '{name}' not found.")
def ensure_safe_support_path(name: str, relative_path: str) -> Path:
skill_dir = get_custom_skill_dir(name).resolve()
def ensure_safe_support_path(name: str, relative_path: str, app_config: AppConfig) -> Path:
skill_dir = get_custom_skill_dir(name, app_config).resolve()
if not relative_path or relative_path.endswith("/"):
raise ValueError("Supporting file path must include a filename.")
relative = Path(relative_path)
@ -124,8 +125,8 @@ def atomic_write(path: Path, content: str) -> None:
tmp_path.replace(path)
def append_history(name: str, record: dict[str, Any]) -> None:
history_path = get_skill_history_file(name)
def append_history(name: str, record: dict[str, Any], app_config: AppConfig) -> None:
history_path = get_skill_history_file(name, app_config)
history_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"ts": datetime.now(UTC).isoformat(),
@ -136,8 +137,8 @@ def append_history(name: str, record: dict[str, Any]) -> None:
f.write("\n")
def read_history(name: str) -> list[dict[str, Any]]:
history_path = get_skill_history_file(name)
def read_history(name: str, app_config: AppConfig) -> list[dict[str, Any]]:
history_path = get_skill_history_file(name, app_config)
if not history_path.exists():
return []
records: list[dict[str, Any]] = []
@ -148,12 +149,12 @@ def read_history(name: str) -> list[dict[str, Any]]:
return records
def list_custom_skills() -> list:
return [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
def list_custom_skills(app_config: AppConfig) -> list:
return [skill for skill in load_skills(app_config, enabled_only=False) if skill.category == "custom"]
def read_custom_skill_content(name: str) -> str:
skill_file = get_custom_skill_file(name)
def read_custom_skill_content(name: str, app_config: AppConfig) -> str:
skill_file = get_custom_skill_file(name, app_config)
if not skill_file.exists():
raise FileNotFoundError(f"Custom skill '{name}' not found.")
return skill_file.read_text(encoding="utf-8")
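Because every path helper now takes `app_config`, the whole chain is pure: same config in, same path out, no module-level caching. A minimal sketch with a fake config (class and path names below are assumptions):

```python
import re
from pathlib import Path

SKILL_FILE_NAME = "SKILL.md"
_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")

class FakeSkillsConfig:
    def get_skills_path(self) -> Path:
        return Path("/srv/skills")

class FakeAppConfig:
    skills = FakeSkillsConfig()

def validate_skill_name(name: str) -> str:
    normalized = name.strip().lower()
    if not _SKILL_NAME_PATTERN.match(normalized):
        raise ValueError(f"Invalid skill name: {name!r}")
    return normalized

def get_custom_skill_file(name: str, app_config) -> Path:
    # Pure parameter flow: the skills root comes from the argument,
    # never from module-level state.
    return app_config.skills.get_skills_path() / "custom" / validate_skill_name(name) / SKILL_FILE_NAME
```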

View File

@ -35,7 +35,7 @@ def _extract_json_object(raw: str) -> dict | None:
return None
async def scan_skill_content(content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult:
async def scan_skill_content(app_config: AppConfig, content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult:
"""Screen skill content before it is written to disk."""
rubric = (
"You are a security reviewer for AI agent skills. "
@ -47,9 +47,12 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----"
try:
config = AppConfig.current()
model_name = config.skill_evolution.moderation_model_name
model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False)
model_name = app_config.skill_evolution.moderation_model_name
model = (
create_chat_model(name=model_name, thinking_enabled=False, app_config=app_config)
if model_name
else create_chat_model(thinking_enabled=False, app_config=app_config)
)
response = await model.ainvoke(
[
{"role": "system", "content": rubric},

View File

@ -17,6 +17,7 @@ from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
from deerflow.subagents.config import SubagentConfig
@ -132,24 +133,16 @@ class SubagentExecutor:
self,
config: SubagentConfig,
tools: list[BaseTool],
app_config: AppConfig,
parent_model: str | None = None,
sandbox_state: SandboxState | None = None,
thread_data: ThreadDataState | None = None,
thread_id: str | None = None,
trace_id: str | None = None,
):
"""Initialize the executor.
Args:
config: Subagent configuration.
tools: List of all available tools (will be filtered).
parent_model: The parent agent's model name for inheritance.
sandbox_state: Sandbox state from parent agent.
thread_data: Thread data from parent agent.
thread_id: Thread ID for sandbox operations.
trace_id: Trace ID from parent for distributed tracing.
"""
"""Initialize the executor."""
self.config = config
self.app_config = app_config
self.parent_model = parent_model
self.sandbox_state = sandbox_state
self.thread_data = thread_data
@ -169,7 +162,7 @@ class SubagentExecutor:
def _create_agent(self):
"""Create the agent instance."""
model_name = _get_model_name(self.config, self.parent_model)
model = create_chat_model(name=model_name, thinking_enabled=False)
model = create_chat_model(name=model_name, thinking_enabled=False, app_config=self.app_config)
from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares

View File

@ -3,6 +3,7 @@
import logging
from dataclasses import replace
from deerflow.config.app_config import AppConfig
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.config import SubagentConfig
@ -10,25 +11,15 @@ from deerflow.subagents.config import SubagentConfig
logger = logging.getLogger(__name__)
def get_subagent_config(name: str) -> SubagentConfig | None:
"""Get a subagent configuration by name, with config.yaml overrides applied.
Args:
name: The name of the subagent.
Returns:
SubagentConfig if found (with any config.yaml overrides applied), None otherwise.
"""
def get_subagent_config(name: str, app_config: AppConfig) -> SubagentConfig | None:
"""Get a subagent configuration by name, with config.yaml overrides applied."""
config = BUILTIN_SUBAGENTS.get(name)
if config is None:
return None
# Apply timeout override from config.yaml (lazy import to avoid circular deps)
from deerflow.config.app_config import AppConfig
app_config = AppConfig.current().subagents
effective_timeout = app_config.get_timeout_for(name)
effective_max_turns = app_config.get_max_turns_for(name, config.max_turns)
sub_config = app_config.subagents
effective_timeout = sub_config.get_timeout_for(name)
effective_max_turns = sub_config.get_max_turns_for(name, config.max_turns)
overrides = {}
if effective_timeout != config.timeout_seconds:
@ -53,13 +44,13 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
return config
def list_subagents() -> list[SubagentConfig]:
def list_subagents(app_config: AppConfig) -> list[SubagentConfig]:
"""List all available subagent configurations (with config.yaml overrides applied).
Returns:
List of all registered SubagentConfig instances.
"""
return [get_subagent_config(name) for name in BUILTIN_SUBAGENTS]
return [get_subagent_config(name, app_config) for name in BUILTIN_SUBAGENTS]
def get_subagent_names() -> list[str]:
@ -71,7 +62,7 @@ def get_subagent_names() -> list[str]:
return list(BUILTIN_SUBAGENTS.keys())
def get_available_subagent_names() -> list[str]:
def get_available_subagent_names(app_config: AppConfig) -> list[str]:
"""Get subagent names that should be exposed to the active runtime.
Returns:
@ -79,7 +70,7 @@ def get_available_subagent_names() -> list[str]:
"""
names = list(BUILTIN_SUBAGENTS.keys())
try:
host_bash_allowed = is_host_bash_allowed()
host_bash_allowed = is_host_bash_allowed(app_config)
except Exception:
logger.debug("Could not determine host bash availability; exposing all built-in subagents")
return names
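The filtering behaviour can be sketched with a toy registry. The `allow_host_bash` attribute and the helper body below are assumptions, not the real security API:

```python
import logging

logger = logging.getLogger(__name__)

# Toy registry standing in for BUILTIN_SUBAGENTS.
BUILTIN_SUBAGENTS = {"research": object(), "coder": object(), "bash": object()}

def is_host_bash_allowed(app_config) -> bool:
    # Stand-in for the real check in deerflow.sandbox.security.
    return bool(getattr(app_config, "allow_host_bash", False))

def get_available_subagent_names(app_config) -> list[str]:
    names = list(BUILTIN_SUBAGENTS)
    try:
        host_bash_allowed = is_host_bash_allowed(app_config)
    except Exception:
        logger.debug("Could not determine host bash availability; exposing all built-in subagents")
        return names
    if not host_bash_allowed:
        # Hide the bash subagent when the host does not permit it.
        names = [n for n in names if n != "bash"]
    return names
```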

View File

@ -60,20 +60,21 @@ async def task_tool(
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.
"""
available_subagent_names = get_available_subagent_names()
ctx = resolve_context(runtime)
available_subagent_names = get_available_subagent_names(ctx.app_config)
# Get subagent configuration
config = get_subagent_config(subagent_type)
config = get_subagent_config(subagent_type, ctx.app_config)
if config is None:
available = ", ".join(available_subagent_names)
return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
if subagent_type == "bash" and not is_host_bash_allowed(resolve_context(runtime).app_config):
if subagent_type == "bash" and not is_host_bash_allowed(ctx.app_config):
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
# Build config overrides
overrides: dict = {}
skills_section = get_skills_prompt_section()
skills_section = get_skills_prompt_section(ctx.app_config)
if skills_section:
overrides["system_prompt"] = config.system_prompt + "\n\n" + skills_section
@ -107,12 +108,13 @@ async def task_tool(
from deerflow.tools import get_available_tools
# Subagents should not have subagent tools enabled (prevent recursive nesting)
tools = get_available_tools(model_name=parent_model, subagent_enabled=False)
tools = get_available_tools(model_name=parent_model, subagent_enabled=False, app_config=ctx.app_config)
# Create executor
executor = SubagentExecutor(
config=config,
tools=tools,
app_config=ctx.app_config,
parent_model=parent_model,
sandbox_state=sandbox_state,
thread_data=thread_data,

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio
import logging
import shutil
from typing import Any
from typing import TYPE_CHECKING, Any
from weakref import WeakValueDictionary
from langchain.tools import ToolRuntime, tool
@ -13,6 +13,9 @@ from langgraph.typing import ContextT
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.agents.thread_state import ThreadState
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
from deerflow.mcp.tools import _make_sync_tool_wrapper
from deerflow.skills.manager import (
append_history,
@ -60,8 +63,8 @@ def _history_record(*, action: str, file_path: str, prev_content: str | None, ne
}
async def _scan_or_raise(content: str, *, executable: bool, location: str) -> dict[str, str]:
result = await scan_skill_content(content, executable=executable, location=location)
async def _scan_or_raise(app_config: "AppConfig", content: str, *, executable: bool, location: str) -> dict[str, str]:
result = await scan_skill_content(app_config, content, executable=executable, location=location)
if result.decision == "block":
raise ValueError(f"Security scan blocked the write: {result.reason}")
if executable and result.decision != "allow":
@ -94,50 +97,55 @@ async def _skill_manage_impl(
replace: Replacement text for patch.
expected_count: Optional expected number of replacements for patch.
"""
from deerflow.config.deer_flow_context import resolve_context
name = validate_skill_name(name)
lock = _get_lock(name)
thread_id = _get_thread_id(runtime)
app_config = resolve_context(runtime).app_config
async with lock:
if action == "create":
if await _to_thread(custom_skill_exists, name):
if await _to_thread(custom_skill_exists, name, app_config):
raise ValueError(f"Custom skill '{name}' already exists.")
if content is None:
raise ValueError("content is required for create.")
await _to_thread(validate_skill_markdown_content, name, content)
scan = await _scan_or_raise(content, executable=False, location=f"{name}/SKILL.md")
skill_file = await _to_thread(get_custom_skill_file, name)
scan = await _scan_or_raise(app_config, content, executable=False, location=f"{name}/SKILL.md")
skill_file = await _to_thread(get_custom_skill_file, name, app_config)
await _to_thread(atomic_write, skill_file, content)
await _to_thread(
append_history,
name,
_history_record(action="create", file_path="SKILL.md", prev_content=None, new_content=content, thread_id=thread_id, scanner=scan),
app_config,
)
await refresh_skills_system_prompt_cache_async()
await refresh_skills_system_prompt_cache_async(app_config)
return f"Created custom skill '{name}'."
if action == "edit":
await _to_thread(ensure_custom_skill_is_editable, name)
await _to_thread(ensure_custom_skill_is_editable, name, app_config)
if content is None:
raise ValueError("content is required for edit.")
await _to_thread(validate_skill_markdown_content, name, content)
scan = await _scan_or_raise(content, executable=False, location=f"{name}/SKILL.md")
skill_file = await _to_thread(get_custom_skill_file, name)
scan = await _scan_or_raise(app_config, content, executable=False, location=f"{name}/SKILL.md")
skill_file = await _to_thread(get_custom_skill_file, name, app_config)
prev_content = await _to_thread(skill_file.read_text, encoding="utf-8")
await _to_thread(atomic_write, skill_file, content)
await _to_thread(
append_history,
name,
_history_record(action="edit", file_path="SKILL.md", prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
app_config,
)
await refresh_skills_system_prompt_cache_async()
await refresh_skills_system_prompt_cache_async(app_config)
return f"Updated custom skill '{name}'."
if action == "patch":
await _to_thread(ensure_custom_skill_is_editable, name)
await _to_thread(ensure_custom_skill_is_editable, name, app_config)
if find is None or replace is None:
raise ValueError("find and replace are required for patch.")
skill_file = await _to_thread(get_custom_skill_file, name)
skill_file = await _to_thread(get_custom_skill_file, name, app_config)
prev_content = await _to_thread(skill_file.read_text, encoding="utf-8")
occurrences = prev_content.count(find)
if occurrences == 0:
@ -147,51 +155,54 @@ async def _skill_manage_impl(
replacement_count = expected_count if expected_count is not None else 1
new_content = prev_content.replace(find, replace, replacement_count)
await _to_thread(validate_skill_markdown_content, name, new_content)
scan = await _scan_or_raise(new_content, executable=False, location=f"{name}/SKILL.md")
scan = await _scan_or_raise(app_config, new_content, executable=False, location=f"{name}/SKILL.md")
await _to_thread(atomic_write, skill_file, new_content)
await _to_thread(
append_history,
name,
_history_record(action="patch", file_path="SKILL.md", prev_content=prev_content, new_content=new_content, thread_id=thread_id, scanner=scan),
app_config,
)
await refresh_skills_system_prompt_cache_async()
await refresh_skills_system_prompt_cache_async(app_config)
return f"Patched custom skill '{name}' ({replacement_count} replacement(s) applied, {occurrences} match(es) found)."
if action == "delete":
await _to_thread(ensure_custom_skill_is_editable, name)
skill_dir = await _to_thread(get_custom_skill_dir, name)
prev_content = await _to_thread(read_custom_skill_content, name)
await _to_thread(ensure_custom_skill_is_editable, name, app_config)
skill_dir = await _to_thread(get_custom_skill_dir, name, app_config)
prev_content = await _to_thread(read_custom_skill_content, name, app_config)
await _to_thread(
append_history,
name,
_history_record(action="delete", file_path="SKILL.md", prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
app_config,
)
await _to_thread(shutil.rmtree, skill_dir)
await refresh_skills_system_prompt_cache_async()
await refresh_skills_system_prompt_cache_async(app_config)
return f"Deleted custom skill '{name}'."
if action == "write_file":
await _to_thread(ensure_custom_skill_is_editable, name)
await _to_thread(ensure_custom_skill_is_editable, name, app_config)
if path is None or content is None:
raise ValueError("path and content are required for write_file.")
target = await _to_thread(ensure_safe_support_path, name, path)
target = await _to_thread(ensure_safe_support_path, name, path, app_config)
exists = await _to_thread(target.exists)
prev_content = await _to_thread(target.read_text, encoding="utf-8") if exists else None
executable = "scripts/" in path or path.startswith("scripts/")
scan = await _scan_or_raise(content, executable=executable, location=f"{name}/{path}")
scan = await _scan_or_raise(app_config, content, executable=executable, location=f"{name}/{path}")
await _to_thread(atomic_write, target, content)
await _to_thread(
append_history,
name,
_history_record(action="write_file", file_path=path, prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
app_config,
)
return f"Wrote '{path}' for custom skill '{name}'."
if action == "remove_file":
await _to_thread(ensure_custom_skill_is_editable, name)
await _to_thread(ensure_custom_skill_is_editable, name, app_config)
if path is None:
raise ValueError("path is required for remove_file.")
target = await _to_thread(ensure_safe_support_path, name, path)
target = await _to_thread(ensure_safe_support_path, name, path, app_config)
if not await _to_thread(target.exists):
raise FileNotFoundError(f"Supporting file '{path}' not found for skill '{name}'.")
prev_content = await _to_thread(target.read_text, encoding="utf-8")
@ -200,10 +211,11 @@ async def _skill_manage_impl(
append_history,
name,
_history_record(action="remove_file", file_path=path, prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
app_config,
)
return f"Removed '{path}' from custom skill '{name}'."
if await _to_thread(public_skill_exists, name):
if await _to_thread(public_skill_exists, name, app_config):
raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
raise ValueError(f"Unsupported action '{action}'.")
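The `_to_thread(func, *args)` calls above forward `app_config` to the sync helpers across the thread boundary; `asyncio.to_thread` passes extra positional arguments straight through, so no closure capture is needed. A minimal demonstration (the helper below is a stand-in for the real filesystem check):

```python
import asyncio

class Cfg:
    # Assumed attribute purely for this sketch.
    known_skills = ("data-viz",)

def custom_skill_exists(name: str, app_config) -> bool:
    # Stand-in sync helper; the real one checks the filesystem.
    return name in getattr(app_config, "known_skills", ())

async def main() -> bool:
    # app_config rides along as a plain positional argument, so the sync
    # helper receives it even though it runs on a worker thread.
    return await asyncio.to_thread(custom_skill_exists, "data-viz", Cfg())
```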

View File

@ -38,7 +38,7 @@ def get_available_tools(
model_name: str | None = None,
subagent_enabled: bool = False,
*,
app_config: AppConfig | None = None,
app_config: AppConfig,
) -> list[BaseTool]:
"""Get all available tools from config.
@ -50,16 +50,11 @@ def get_available_tools(
include_mcp: Whether to include tools from MCP servers (default: True).
model_name: Optional model name to determine if vision tools should be included.
subagent_enabled: Whether to include subagent tools (task, task_status).
app_config: Explicit application config. Falls back to AppConfig.current()
when omitted; new callers should pass this explicitly.
app_config: The application config. Required; callers must pass it explicitly.
Returns:
List of available tools.
"""
if app_config is None:
# TODO(P2-10): fold into a required parameter once all callers thread
# config explicitly (community tool factories, subagent registry, etc.).
app_config = AppConfig.current()
config = app_config
tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups]

View File

@ -133,7 +133,7 @@ def _do_convert(file_path: Path, pdf_converter: str) -> str:
return _convert_with_markitdown(file_path)
async def convert_file_to_markdown(file_path: Path) -> Path | None:
async def convert_file_to_markdown(file_path: Path, app_config: object | None = None) -> Path | None:
"""Convert a supported document file to Markdown.
PDF files are handled with a two-converter strategy (see module docstring).
@ -142,12 +142,14 @@ async def convert_file_to_markdown(file_path: Path) -> Path | None:
Args:
file_path: Path to the file to convert.
app_config: Optional AppConfig used to read the ``pdf_converter``
preference. When omitted, conversion defaults to ``auto``.
Returns:
Path to the generated .md file, or None if conversion failed.
"""
try:
pdf_converter = _get_pdf_converter()
pdf_converter = _get_pdf_converter(app_config)
file_size = file_path.stat().st_size
if file_size > _ASYNC_THRESHOLD_BYTES:
@ -286,24 +288,20 @@ def extract_outline(md_path: Path) -> list[dict]:
return outline
def _get_pdf_converter() -> str:
def _get_pdf_converter(app_config: object | None) -> str:
"""Read pdf_converter setting from app config, defaulting to 'auto'.
Normalizes the value to lowercase and validates it against the allowed set
so that values like 'AUTO' or 'MarkItDown' from config.yaml don't silently
fall through to unexpected behaviour.
"""
try:
from deerflow.config.app_config import AppConfig
cfg = AppConfig.current()
uploads_cfg = getattr(cfg, "uploads", None)
if uploads_cfg is not None:
raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
if raw not in _ALLOWED_PDF_CONVERTERS:
logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
return "auto"
return raw
except Exception:
pass
return "auto"
if app_config is None:
return "auto"
uploads_cfg = getattr(app_config, "uploads", None)
if uploads_cfg is None:
return "auto"
raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
if raw not in _ALLOWED_PDF_CONVERTERS:
logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
return "auto"
return raw
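Because the rewritten function reads the setting through a defensive `getattr` chain, any object exposing `uploads.pdf_converter` works, including test doubles. A runnable restatement of the same shape (the allowed set below is an assumption):

```python
from types import SimpleNamespace

_ALLOWED_PDF_CONVERTERS = {"auto", "markitdown"}  # assumed allowed set

def get_pdf_converter(app_config) -> str:
    # Same shape as _get_pdf_converter above: None config and missing
    # uploads section both fall back to 'auto'; values are normalized
    # to lowercase and validated against the allowed set.
    if app_config is None:
        return "auto"
    uploads_cfg = getattr(app_config, "uploads", None)
    if uploads_cfg is None:
        return "auto"
    raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
    return raw if raw in _ALLOWED_PDF_CONVERTERS else "auto"
```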

View File

@ -69,12 +69,18 @@ def provisioner_module():
@pytest.fixture(autouse=True)
def _auto_app_config():
"""Initialize a minimal AppConfig for tests so ``AppConfig.current()`` never tries to auto-load config.yaml.
def _auto_app_config_from_file(monkeypatch, request):
"""Replace ``AppConfig.from_file`` with a minimal factory so tests that
(directly or indirectly, e.g. via the LangGraph Server bootstrap path in
``make_lead_agent``) load AppConfig from disk do not need a real
``config.yaml`` on the filesystem.
Individual tests can still override via ``patch.object(AppConfig, "current", ...)``
or by calling ``AppConfig.init()`` with a different config.
Tests that want to verify the real ``from_file`` behaviour should mark
themselves with ``@pytest.mark.real_from_file``.
"""
if request.node.get_closest_marker("real_from_file"):
yield
return
try:
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
@ -82,12 +88,11 @@ def _auto_app_config():
yield
return
previous_global = AppConfig._global
AppConfig._global = AppConfig(sandbox=SandboxConfig(use="test"))
try:
yield
finally:
AppConfig._global = previous_global
def _fake_from_file(config_path: str | None = None) -> AppConfig: # noqa: ARG001
return AppConfig(sandbox=SandboxConfig(use="test"))
monkeypatch.setattr(AppConfig, "from_file", _fake_from_file)
yield
@pytest.fixture(autouse=True)

View File

@ -2,8 +2,11 @@
import json
import pytest
import pytest
import yaml
pytestmark = pytest.mark.real_from_file
from pydantic import ValidationError
from deerflow.config.acp_config import ACPAgentConfig

View File

@ -3,10 +3,13 @@ from __future__ import annotations
import json
from pathlib import Path
import pytest
import yaml
from deerflow.config.app_config import AppConfig
pytestmark = pytest.mark.real_from_file
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
path.write_text(
@ -31,7 +34,11 @@ def _write_extensions_config(path: Path) -> None:
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
def test_init_then_get(tmp_path, monkeypatch):
def test_from_file_reads_model_name(tmp_path, monkeypatch):
"""``AppConfig.from_file`` is the only lifecycle method now; there is no
process-global ``init/current``. Each consumer holds its own captured
AppConfig instance.
"""
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
@ -41,14 +48,13 @@ def test_init_then_get(tmp_path, monkeypatch):
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
config = AppConfig.from_file(str(config_path))
AppConfig.init(config)
result = AppConfig.current()
assert result is config
assert result.models[0].name == "test-model"
assert config.models[0].name == "test-model"
def test_init_replaces_previous(tmp_path, monkeypatch):
def test_from_file_each_call_returns_fresh_instance(tmp_path, monkeypatch):
"""Two reads of the same file produce separate AppConfig instances —
no hidden singleton, no memoization. Callers decide when to re-read.
"""
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
@ -58,14 +64,12 @@ def test_init_replaces_previous(tmp_path, monkeypatch):
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
config_a = AppConfig.from_file(str(config_path))
AppConfig.init(config_a)
assert AppConfig.current().models[0].name == "model-a"
assert config_a.models[0].name == "model-a"
_write_config(config_path, model_name="model-b", supports_thinking=True)
config_b = AppConfig.from_file(str(config_path))
AppConfig.init(config_b)
assert AppConfig.current().models[0].name == "model-b"
assert AppConfig.current() is config_b
assert config_b.models[0].name == "model-b"
assert config_a is not config_b
def test_config_version_check(tmp_path, monkeypatch):

View File

@ -64,39 +64,42 @@ class TestGetCheckpointer:
from langgraph.checkpoint.memory import InMemorySaver
with patch.object(AppConfig, "current", return_value=_make_config()):
cp = get_checkpointer()
cp = get_checkpointer(AppConfig.current())
assert cp is not None
assert isinstance(cp, InMemorySaver)
def test_returns_in_memory_saver_when_config_not_found(self):
from langgraph.checkpoint.memory import InMemorySaver
def test_raises_when_config_file_missing(self):
"""A missing config.yaml is a deployment error, not a cue to degrade to InMemorySaver.
Silent degradation would drop persistent-run state on process restart.
`get_checkpointer` only falls back to InMemorySaver for the explicit
`checkpointer: null` opt-in (test above), not for I/O failure.
"""
with patch.object(AppConfig, "current", side_effect=FileNotFoundError):
cp = get_checkpointer()
assert cp is not None
assert isinstance(cp, InMemorySaver)
with pytest.raises(FileNotFoundError):
get_checkpointer(AppConfig.current())
def test_memory_returns_in_memory_saver(self):
from langgraph.checkpoint.memory import InMemorySaver
cfg = _make_config(CheckpointerConfig(type="memory"))
with patch.object(AppConfig, "current", return_value=cfg):
cp = get_checkpointer()
cp = get_checkpointer(AppConfig.current())
assert isinstance(cp, InMemorySaver)
def test_memory_singleton(self):
cfg = _make_config(CheckpointerConfig(type="memory"))
with patch.object(AppConfig, "current", return_value=cfg):
cp1 = get_checkpointer()
cp2 = get_checkpointer()
cp1 = get_checkpointer(AppConfig.current())
cp2 = get_checkpointer(AppConfig.current())
assert cp1 is cp2
def test_reset_clears_singleton(self):
cfg = _make_config(CheckpointerConfig(type="memory"))
with patch.object(AppConfig, "current", return_value=cfg):
cp1 = get_checkpointer()
cp1 = get_checkpointer(AppConfig.current())
reset_checkpointer()
cp2 = get_checkpointer()
cp2 = get_checkpointer(AppConfig.current())
assert cp1 is not cp2
def test_sqlite_raises_when_package_missing(self):
@ -107,7 +110,7 @@ class TestGetCheckpointer:
):
reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
get_checkpointer()
get_checkpointer(AppConfig.current())
def test_postgres_raises_when_package_missing(self):
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
@ -117,7 +120,7 @@ class TestGetCheckpointer:
):
reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
get_checkpointer()
get_checkpointer(AppConfig.current())
def test_postgres_raises_when_connection_string_missing(self):
cfg = _make_config(CheckpointerConfig(type="postgres"))
@ -130,7 +133,7 @@ class TestGetCheckpointer:
):
reset_checkpointer()
with pytest.raises(ValueError, match="connection_string is required"):
get_checkpointer()
get_checkpointer(AppConfig.current())
def test_sqlite_creates_saver(self):
"""SQLite checkpointer is created when package is available."""
@ -152,7 +155,7 @@ class TestGetCheckpointer:
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
):
reset_checkpointer()
cp = get_checkpointer()
cp = get_checkpointer(AppConfig.current())
assert cp is mock_saver_instance
mock_saver_cls.from_conn_string.assert_called_once()
@ -178,7 +181,7 @@ class TestGetCheckpointer:
patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}),
):
reset_checkpointer()
cp = get_checkpointer()
cp = get_checkpointer(AppConfig.current())
assert cp is mock_saver_instance
mock_saver_cls.from_conn_string.assert_called_once_with("postgresql://localhost/db")
@ -214,7 +217,7 @@ class TestAsyncCheckpointer:
return_value="/tmp/resolved/test.db",
),
):
async with make_checkpointer() as saver:
async with make_checkpointer(AppConfig.current()) as saver:
assert saver is mock_saver
mock_to_thread.assert_awaited_once()
@ -245,7 +248,7 @@ class TestAppConfigLoadsCheckpointer:
class TestClientCheckpointerFallback:
def test_client_uses_config_checkpointer_when_none_provided(self):
"""DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
"""DeerFlowClient._ensure_agent falls back to get_checkpointer(AppConfig.current()) when checkpointer=None."""
# This is a structural test — verifying the fallback path exists.
cfg = _make_config(CheckpointerConfig(type="memory"))
assert cfg.checkpointer is not None

View File

@ -22,7 +22,7 @@ class TestCheckpointerNoneFix:
mock_config.database = None
with patch.object(AppConfig, "current", return_value=mock_config):
async with make_checkpointer() as checkpointer:
async with make_checkpointer(AppConfig.current()) as checkpointer:
# Should return InMemorySaver, not None
assert checkpointer is not None
assert isinstance(checkpointer, InMemorySaver)
@ -45,7 +45,7 @@ class TestCheckpointerNoneFix:
mock_config.checkpointer = None
with patch.object(AppConfig, "current", return_value=mock_config):
with checkpointer_context() as checkpointer:
with checkpointer_context(AppConfig.current()) as checkpointer:
# Should return InMemorySaver, not None
assert checkpointer is not None
assert isinstance(checkpointer, InMemorySaver)

View File

@ -44,9 +44,12 @@ def mock_app_config():
@pytest.fixture
def client(mock_app_config):
"""Create a DeerFlowClient with mocked config loading."""
with patch.object(AppConfig, "current", return_value=mock_app_config):
return DeerFlowClient()
"""Create a DeerFlowClient holding the mocked config directly.
Passing ``config=`` is the documented post-refactor way to inject a
test AppConfig; nothing relies on process-global state.
"""
return DeerFlowClient(config=mock_app_config)
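The fixture change is plain constructor injection; a stand-in sketch of the pattern (toy classes with hypothetical fields, not the real client):

```python
class AppConfig:  # minimal stand-in
    def __init__(self, models=()):
        self.models = list(models)

class DeerFlowClient:
    def __init__(self, config: AppConfig):
        # Captured once at construction; no AppConfig.current() lookup later.
        self._app_config = config

    def list_models(self):
        return {"models": list(self._app_config.models)}

client = DeerFlowClient(config=AppConfig(models=["test-model"]))
print(client.list_models())
```

Because the config is an ordinary constructor argument, tests inject a mock directly and nothing needs to be patched or unwound at teardown.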
# ---------------------------------------------------------------------------
@ -124,7 +127,7 @@ class TestConfigQueries:
with patch("deerflow.skills.loader.load_skills", return_value=[skill]) as mock_load:
result = client.list_skills()
mock_load.assert_called_once_with(enabled_only=False)
mock_load.assert_called_once_with(client._app_config, enabled_only=False)
assert "skills" in result
assert len(result["skills"]) == 1
@ -139,7 +142,7 @@ class TestConfigQueries:
def test_list_skills_enabled_only(self, client):
with patch("deerflow.skills.loader.load_skills", return_value=[]) as mock_load:
client.list_skills(enabled_only=True)
mock_load.assert_called_once_with(enabled_only=True)
mock_load.assert_called_once_with(client._app_config, enabled_only=True)
def test_get_memory(self, client):
memory = {"version": "1.0", "facts": []}
@ -1244,7 +1247,7 @@ class TestMemoryManagement:
assert mock_import.call_count == 1
call_args = mock_import.call_args
assert call_args.args == (imported,)
assert call_args.args == (client._app_config.memory, imported)
assert "user_id" in call_args.kwargs
assert result == imported
@ -1269,6 +1272,7 @@ class TestMemoryManagement:
confidence=0.88,
)
create_fact.assert_called_once_with(
client._app_config.memory,
content="User prefers concise code reviews.",
category="preference",
confidence=0.88,
@ -1279,7 +1283,7 @@ class TestMemoryManagement:
data = {"version": "1.0", "facts": []}
with patch("deerflow.agents.memory.updater.delete_memory_fact", return_value=data) as delete_fact:
result = client.delete_memory_fact("fact_123")
delete_fact.assert_called_once_with("fact_123")
delete_fact.assert_called_once_with(client._app_config.memory, "fact_123")
assert result == data
def test_update_memory_fact(self, client):
@ -1292,6 +1296,7 @@ class TestMemoryManagement:
confidence=0.91,
)
update_fact.assert_called_once_with(
client._app_config.memory,
fact_id="fact_123",
content="User prefers spaces",
category="workflow",
@ -1307,6 +1312,7 @@ class TestMemoryManagement:
"User prefers spaces",
)
update_fact.assert_called_once_with(
client._app_config.memory,
fact_id="fact_123",
content="User prefers spaces",
category=None,
@ -1802,7 +1808,7 @@ class TestScenarioConfigManagement:
reloaded_config.mcp_servers = {"my-mcp": reloaded_server}
client._agent = MagicMock() # Simulate existing agent
AppConfig.init(MagicMock(extensions=current_config))
client._app_config = MagicMock(extensions=current_config)
with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
@ -2220,8 +2226,7 @@ class TestGatewayConformance:
model.supports_thinking = False
mock_app_config.models = [model]
with patch.object(AppConfig, "current", return_value=mock_app_config):
client = DeerFlowClient()
client = DeerFlowClient(config=mock_app_config)
result = client.list_models()
parsed = ModelsListResponse(**result)
@ -2239,8 +2244,7 @@ class TestGatewayConformance:
mock_app_config.models = [model]
mock_app_config.get_model_config.return_value = model
with patch.object(AppConfig, "current", return_value=mock_app_config):
client = DeerFlowClient()
client = DeerFlowClient(config=mock_app_config)
result = client.get_model("test-model")
assert result is not None
@ -3076,7 +3080,7 @@ class TestBugAgentInvalidationInconsistency:
config_file = Path(tmp) / "ext.json"
config_file.write_text("{}")
AppConfig.init(MagicMock(extensions=current_config))
client._app_config = MagicMock(extensions=current_config)
with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded)),

View File

@ -71,3 +71,25 @@ def test_config_model_rejects_mutation(model_cls: type[BaseModel]):
with pytest.raises(ValidationError):
setattr(instance, first_field, "MUTATED")
def test_extensions_nested_dict_mutation_is_not_blocked_by_pydantic():
"""Regression guard: Pydantic `frozen=True` does NOT deep-freeze container fields.
This test documents the trap: callers MUST compose a new dict, persist it,
and reload AppConfig instead of reaching into `extensions.skills[x]` in place.
If you need the dict to be truly immutable, wrap it in a read-only Mapping
(e.g. types.MappingProxyType) or a frozendict.
"""
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
ext = ExtensionsConfig(mcp_servers={}, skills={"a": SkillStateConfig(enabled=True)})
# This is the pre-refactor anti-pattern: Pydantic lets it through because
# the outer model is frozen but the inner dict is a plain builtin. No error.
ext.skills["a"] = SkillStateConfig(enabled=False)
ext.skills["b"] = SkillStateConfig(enabled=True)
# The test asserts the leak exists so a future "add deep-freeze" change
# flips this expectation and forces call-site review.
assert ext.skills["a"].enabled is False
assert "b" in ext.skills

View File

@ -10,6 +10,9 @@ import yaml
from fastapi.testclient import TestClient
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
_TEST_MEMORY_CONFIG = MemoryConfig()
# ---------------------------------------------------------------------------
# Helpers
@ -335,7 +338,7 @@ class TestMemoryFilePath:
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
):
storage = FileMemoryStorage()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
path = storage._get_memory_file_path(None)
assert path == tmp_path / "memory.json"
@ -348,7 +351,7 @@ class TestMemoryFilePath:
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
):
storage = FileMemoryStorage()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
path = storage._get_memory_file_path("code-reviewer")
assert path == tmp_path / "agents" / "code-reviewer" / "memory.json"
@ -360,7 +363,7 @@ class TestMemoryFilePath:
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
):
storage = FileMemoryStorage()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
path_global = storage._get_memory_file_path(None)
path_a = storage._get_memory_file_path("agent-a")
path_b = storage._get_memory_file_path("agent-b")

View File

@ -47,40 +47,16 @@ class TestResolveContext:
runtime.context = ctx
assert resolve_context(runtime) is ctx
def test_fallback_from_configurable(self):
"""LangGraph Server path: runtime.context is None → construct from ContextVar + configurable."""
def test_raises_on_none_context(self):
"""Without a typed DeerFlowContext, resolve_context refuses to guess."""
runtime = MagicMock()
runtime.context = None
config = _make_config()
with (
patch.object(AppConfig, "current", return_value=config),
patch("langgraph.config.get_config", return_value={"configurable": {"thread_id": "t2", "agent_name": "ag"}}),
):
ctx = resolve_context(runtime)
assert ctx.thread_id == "t2"
assert ctx.agent_name == "ag"
assert ctx.app_config is config
with pytest.raises(RuntimeError, match="resolve_context: runtime.context is not a DeerFlowContext"):
resolve_context(runtime)
def test_fallback_empty_configurable(self):
"""LangGraph Server path with no thread_id in configurable → empty string."""
runtime = MagicMock()
runtime.context = None
config = _make_config()
with (
patch.object(AppConfig, "current", return_value=config),
patch("langgraph.config.get_config", return_value={"configurable": {}}),
):
ctx = resolve_context(runtime)
assert ctx.thread_id == ""
assert ctx.agent_name is None
def test_fallback_from_dict_context(self):
"""Legacy path: runtime.context is a dict → extract from dict directly."""
def test_raises_on_dict_context(self):
"""Legacy dict shape is no longer supported — we raise instead of lazily loading AppConfig."""
runtime = MagicMock()
runtime.context = {"thread_id": "old-dict", "agent_name": "from-dict"}
config = _make_config()
with patch.object(AppConfig, "current", return_value=config):
ctx = resolve_context(runtime)
assert ctx.thread_id == "old-dict"
assert ctx.agent_name == "from-dict"
assert ctx.app_config is config
with pytest.raises(RuntimeError, match="resolve_context: runtime.context is not a DeerFlowContext"):
resolve_context(runtime)

View File

@ -7,6 +7,31 @@ import pytest
from deerflow.config.app_config import AppConfig
# --- Phase 2 test helper: injected runtime for community tools ---
from types import SimpleNamespace as _P2NS
from deerflow.config.app_config import AppConfig as _P2AppConfig
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
def _runtime_with_config(config):
"""Build a runtime carrying a custom (possibly mocked) app_config.
``DeerFlowContext`` is a frozen dataclass whose ``app_config`` field is
annotated as ``AppConfig``, but dataclasses don't enforce annotations at
runtime, so handing a Mock through lets tests exercise the tool's
``get_tool_config`` lookup without going via ``AppConfig.current``.
"""
ctx = _P2Ctx.__new__(_P2Ctx)
object.__setattr__(ctx, "app_config", config)
object.__setattr__(ctx, "thread_id", "test-thread")
object.__setattr__(ctx, "agent_name", None)
return _P2NS(context=ctx)
# -------------------------------------------------------------------
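Why the helper needs ``__new__`` plus ``object.__setattr__``: frozen dataclasses override ``__setattr__`` to raise, and the base-class hook is the documented way around it. A toy demonstration:

```python
from dataclasses import FrozenInstanceError, dataclass

@dataclass(frozen=True)
class Ctx:  # stand-in for a frozen context dataclass
    app_config: object
    thread_id: str

c = Ctx.__new__(Ctx)                 # allocate without running __init__
try:
    c.thread_id = "test-thread"      # the frozen guard blocks normal assignment
except FrozenInstanceError:
    print("normal assignment blocked")

object.__setattr__(c, "app_config", object())  # the base-class hook bypasses the guard
object.__setattr__(c, "thread_id", "test-thread")
print(c.thread_id)
```

Dataclasses never type-check field values at runtime, so the same trick happily installs a ``MagicMock`` where the annotation says ``AppConfig``.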
@pytest.fixture
def mock_app_config():
@ -51,7 +76,7 @@ class TestWebSearchTool:
from deerflow.community.exa.tools import web_search_tool
result = web_search_tool.invoke({"query": "test query"})
result = web_search_tool.func(query="test query", runtime=_P2_RUNTIME)
parsed = json.loads(result)
assert len(parsed) == 2
@ -69,30 +94,30 @@ class TestWebSearchTool:
def test_search_with_custom_config(self, mock_exa_client):
"""Test search respects custom configuration values."""
with patch.object(AppConfig, "current") as mock_config:
tool_config = MagicMock()
tool_config.model_extra = {
"max_results": 10,
"search_type": "neural",
"contents_max_characters": 2000,
"api_key": "test-key",
}
mock_config.return_value.get_tool_config.return_value = tool_config
tool_config = MagicMock()
tool_config.model_extra = {
"max_results": 10,
"search_type": "neural",
"contents_max_characters": 2000,
"api_key": "test-key",
}
fake_config = MagicMock()
fake_config.get_tool_config.return_value = tool_config
mock_response = MagicMock()
mock_response.results = []
mock_exa_client.search.return_value = mock_response
mock_response = MagicMock()
mock_response.results = []
mock_exa_client.search.return_value = mock_response
from deerflow.community.exa.tools import web_search_tool
from deerflow.community.exa.tools import web_search_tool
web_search_tool.invoke({"query": "neural search"})
web_search_tool.func(query="neural search", runtime=_runtime_with_config(fake_config))
mock_exa_client.search.assert_called_once_with(
"neural search",
type="neural",
num_results=10,
contents={"highlights": {"max_characters": 2000}},
)
mock_exa_client.search.assert_called_once_with(
"neural search",
type="neural",
num_results=10,
contents={"highlights": {"max_characters": 2000}},
)
def test_search_with_no_highlights(self, mock_app_config, mock_exa_client):
"""Test search handles results with no highlights."""
@ -107,7 +132,7 @@ class TestWebSearchTool:
from deerflow.community.exa.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
result = web_search_tool.func(query="test", runtime=_P2_RUNTIME)
parsed = json.loads(result)
assert parsed[0]["snippet"] == ""
@ -120,7 +145,7 @@ class TestWebSearchTool:
from deerflow.community.exa.tools import web_search_tool
result = web_search_tool.invoke({"query": "nothing"})
result = web_search_tool.func(query="nothing", runtime=_P2_RUNTIME)
parsed = json.loads(result)
assert parsed == []
@ -131,7 +156,7 @@ class TestWebSearchTool:
from deerflow.community.exa.tools import web_search_tool
result = web_search_tool.invoke({"query": "error"})
result = web_search_tool.func(query="error", runtime=_P2_RUNTIME)
assert result == "Error: API rate limit exceeded"
@ -149,7 +174,7 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com"})
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
assert result == "# Fetched Page\n\nThis is the page content."
mock_exa_client.get_contents.assert_called_once_with(
@ -169,7 +194,7 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com"})
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
assert result.startswith("# Untitled\n\n")
@ -181,7 +206,7 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com/404"})
result = web_fetch_tool.func(url="https://example.com/404", runtime=_P2_RUNTIME)
assert result == "Error: No results found"
@ -191,16 +216,44 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com"})
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
assert result == "Error: Connection timeout"
def test_fetch_reads_web_fetch_config(self, mock_exa_client):
"""Test that web_fetch_tool reads 'web_fetch' config, not 'web_search'."""
with patch.object(AppConfig, "current") as mock_config:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "exa-fetch-key"}
mock_config.return_value.get_tool_config.return_value = tool_config
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "exa-fetch-key"}
fake_config = MagicMock()
fake_config.get_tool_config.return_value = tool_config
mock_result = MagicMock()
mock_result.title = "Page"
mock_result.text = "Content."
mock_response = MagicMock()
mock_response.results = [mock_result]
mock_exa_client.get_contents.return_value = mock_response
from deerflow.community.exa.tools import web_fetch_tool
web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
fake_config.get_tool_config.assert_any_call("web_fetch")
def test_fetch_uses_independent_api_key(self, mock_exa_client):
"""Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
mock_exa_cls.return_value = mock_exa_client
fetch_config = MagicMock()
fetch_config.model_extra = {"api_key": "exa-fetch-key"}
def get_tool_config(name):
if name == "web_fetch":
return fetch_config
return None
fake_config = MagicMock()
fake_config.get_tool_config.side_effect = get_tool_config
mock_result = MagicMock()
mock_result.title = "Page"
@ -211,37 +264,9 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
web_fetch_tool.invoke({"url": "https://example.com"})
web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
mock_config.return_value.get_tool_config.assert_any_call("web_fetch")
def test_fetch_uses_independent_api_key(self, mock_exa_client):
"""Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
with patch.object(AppConfig, "current") as mock_config:
with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
mock_exa_cls.return_value = mock_exa_client
fetch_config = MagicMock()
fetch_config.model_extra = {"api_key": "exa-fetch-key"}
def get_tool_config(name):
if name == "web_fetch":
return fetch_config
return None
mock_config.return_value.get_tool_config.side_effect = get_tool_config
mock_result = MagicMock()
mock_result.title = "Page"
mock_result.text = "Content."
mock_response = MagicMock()
mock_response.results = [mock_result]
mock_exa_client.get_contents.return_value = mock_response
from deerflow.community.exa.tools import web_fetch_tool
web_fetch_tool.invoke({"url": "https://example.com"})
mock_exa_cls.assert_called_once_with(api_key="exa-fetch-key")
mock_exa_cls.assert_called_once_with(api_key="exa-fetch-key")
def test_fetch_truncates_long_content(self, mock_app_config, mock_exa_client):
"""Test fetch truncates content to 4096 characters."""
@ -255,7 +280,7 @@ class TestWebFetchTool:
from deerflow.community.exa.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com"})
result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
# "# Long Page\n\n" is 14 chars, content truncated to 4096
content_after_header = result.split("\n\n", 1)[1]

View File

@ -3,16 +3,31 @@
import json
from unittest.mock import MagicMock, patch
from deerflow.config.app_config import AppConfig
from types import SimpleNamespace as _P2NS
from deerflow.config.app_config import AppConfig as _P2AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
def _runtime_with_config(config):
ctx = _P2Ctx.__new__(_P2Ctx)
object.__setattr__(ctx, "app_config", config)
object.__setattr__(ctx, "thread_id", "test-thread")
object.__setattr__(ctx, "agent_name", None)
return _P2NS(context=ctx)
class TestWebSearchTool:
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
@patch.object(AppConfig, "current")
def test_search_uses_web_search_config(self, mock_get_app_config, mock_firecrawl_cls):
def test_search_uses_web_search_config(self, mock_firecrawl_cls):
search_config = MagicMock()
search_config.model_extra = {"api_key": "firecrawl-search-key", "max_results": 7}
mock_get_app_config.return_value.get_tool_config.return_value = search_config
fake_config = MagicMock()
fake_config.get_tool_config.return_value = search_config
mock_result = MagicMock()
mock_result.web = [
@ -22,7 +37,7 @@ class TestWebSearchTool:
from deerflow.community.firecrawl.tools import web_search_tool
result = web_search_tool.invoke({"query": "test query"})
result = web_search_tool.func(query="test query", runtime=_runtime_with_config(fake_config))
assert json.loads(result) == [
{
@ -31,15 +46,14 @@ class TestWebSearchTool:
"snippet": "Snippet",
}
]
mock_get_app_config.return_value.get_tool_config.assert_called_with("web_search")
fake_config.get_tool_config.assert_called_with("web_search")
mock_firecrawl_cls.assert_called_once_with(api_key="firecrawl-search-key")
mock_firecrawl_cls.return_value.search.assert_called_once_with("test query", limit=7)
class TestWebFetchTool:
@patch("deerflow.community.firecrawl.tools.FirecrawlApp")
@patch.object(AppConfig, "current")
def test_fetch_uses_web_fetch_config(self, mock_get_app_config, mock_firecrawl_cls):
def test_fetch_uses_web_fetch_config(self, mock_firecrawl_cls):
fetch_config = MagicMock()
fetch_config.model_extra = {"api_key": "firecrawl-fetch-key"}
@ -48,7 +62,8 @@ class TestWebFetchTool:
return fetch_config
return None
mock_get_app_config.return_value.get_tool_config.side_effect = get_tool_config
fake_config = MagicMock()
fake_config.get_tool_config.side_effect = get_tool_config
mock_scrape_result = MagicMock()
mock_scrape_result.markdown = "Fetched markdown"
@ -57,10 +72,10 @@ class TestWebFetchTool:
from deerflow.community.firecrawl.tools import web_fetch_tool
result = web_fetch_tool.invoke({"url": "https://example.com"})
result = web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
assert result == "# Fetched Page\n\nFetched markdown"
mock_get_app_config.return_value.get_tool_config.assert_any_call("web_fetch")
fake_config.get_tool_config.assert_any_call("web_fetch")
mock_firecrawl_cls.assert_called_once_with(api_key="firecrawl-fetch-key")
mock_firecrawl_cls.return_value.scrape.assert_called_once_with(
"https://example.com",

View File

@ -7,6 +7,16 @@ from deerflow.community.infoquest import tools
from deerflow.community.infoquest.infoquest_client import InfoQuestClient
from deerflow.config.app_config import AppConfig
# --- Phase 2 test helper: injected runtime for community tools ---
from types import SimpleNamespace as _P2NS
from deerflow.config.app_config import AppConfig as _P2AppConfig
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
# -------------------------------------------------------------------
class TestInfoQuestClient:
def test_infoquest_client_initialization(self):
@ -131,7 +141,7 @@ class TestInfoQuestClient:
mock_client.web_search.return_value = json.dumps([])
mock_get_client.return_value = mock_client
result = tools.web_search_tool.run("test query")
result = tools.web_search_tool.func(query="test query", runtime=_P2_RUNTIME)
assert result == json.dumps([])
mock_get_client.assert_called_once()
@ -144,7 +154,7 @@ class TestInfoQuestClient:
mock_client.fetch.return_value = "<html><body>Test content</body></html>"
mock_get_client.return_value = mock_client
result = tools.web_fetch_tool.run("https://example.com")
result = tools.web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
assert result == "# Untitled\n\nTest content"
mock_get_client.assert_called_once()
@ -162,7 +172,7 @@ class TestInfoQuestClient:
]
mock_get.return_value = mock_config
client = tools._get_infoquest_client()
client = tools._get_infoquest_client(mock_config)
assert client.search_time_range == 24
assert client.fetch_time == 10
@ -322,7 +332,7 @@ class TestImageSearch:
mock_client.image_search.return_value = json.dumps([{"image_url": "https://example.com/image1.jpg"}])
mock_get_client.return_value = mock_client
result = tools.image_search_tool.run({"query": "test query"})
result = tools.image_search_tool.func(query="test query", runtime=_P2_RUNTIME)
# Check if result is a valid JSON string
result_data = json.loads(result)
@ -341,7 +351,7 @@ class TestImageSearch:
mock_get_client.return_value = mock_client
# Pass all parameters as a dictionary (extra parameters will be ignored)
tools.image_search_tool.run({"query": "sunset", "time_range": 30, "site": "unsplash.com", "image_size": "l"})
tools.image_search_tool.func(query="sunset", runtime=_P2_RUNTIME)
mock_get_client.assert_called_once()
# image_search_tool only passes query to client.image_search

View File

@ -685,5 +685,5 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo
classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})),
)
tools = get_available_tools(include_mcp=True, subagent_enabled=False)
tools = get_available_tools(include_mcp=True, subagent_enabled=False, app_config=AppConfig.current())
assert "invoke_acp_agent" in [tool.name for tool in tools]

View File

@ -11,6 +11,16 @@ from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.community.jina_ai.tools import web_fetch_tool
from deerflow.config.app_config import AppConfig
# --- Phase 2 test helper: injected runtime for community tools ---
from types import SimpleNamespace as _P2NS
from deerflow.config.app_config import AppConfig as _P2AppConfig
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
# -------------------------------------------------------------------
@pytest.fixture
def jina_client():
@ -157,7 +167,7 @@ async def test_web_fetch_tool_returns_error_on_crawl_failure(monkeypatch):
mock_config.get_tool_config.return_value = None
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: mock_config))
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
result = await web_fetch_tool.ainvoke("https://example.com")
result = await web_fetch_tool.coroutine(url="https://example.com", runtime=_P2_RUNTIME)
assert result.startswith("Error:")
assert "429" in result
@ -173,6 +183,6 @@ async def test_web_fetch_tool_returns_markdown_on_success(monkeypatch):
mock_config.get_tool_config.return_value = None
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: mock_config))
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
result = await web_fetch_tool.ainvoke("https://example.com")
result = await web_fetch_tool.coroutine(url="https://example.com", runtime=_P2_RUNTIME)
assert "Hello world" in result
assert not result.startswith("Error:")

View File

@ -98,7 +98,8 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
"is_plan_mode": False,
"subagent_enabled": False,
}
}
},
app_config=app_config,
)
assert captured["name"] == "safe-model"
@ -122,7 +123,6 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
]
)
AppConfig.init(app_config)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda _ac: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
@ -137,7 +137,6 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
app_config = _make_app_config([_make_model("default", supports_thinking=False)])
patched = app_config.model_copy(update={"summarization": SummarizationConfig(enabled=True, model_name="model-masswork")})
AppConfig.init(patched)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: patched))
from unittest.mock import MagicMock

View File

@ -8,22 +8,19 @@ from deerflow.config.app_config import AppConfig
from deerflow.skills.types import Skill
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
def test_build_custom_mounts_section_returns_empty_when_no_mounts():
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
assert prompt_module._build_custom_mounts_section() == ""
assert prompt_module._build_custom_mounts_section(config) == ""
def test_build_custom_mounts_section_lists_configured_mounts(monkeypatch):
def test_build_custom_mounts_section_lists_configured_mounts():
mounts = [
SimpleNamespace(container_path="/home/user/shared", read_only=False),
SimpleNamespace(container_path="/mnt/reference", read_only=True),
]
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=mounts))
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
section = prompt_module._build_custom_mounts_section()
section = prompt_module._build_custom_mounts_section(config)
assert "**Custom Mounted Directories:**" in section
assert "`/home/user/shared`" in section
@ -37,15 +34,15 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
config = SimpleNamespace(
sandbox=SimpleNamespace(mounts=mounts),
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda *a, **k: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda app_config: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda app_config: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda app_config, agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
prompt = prompt_module.apply_prompt_template()
prompt = prompt_module.apply_prompt_template(config)
assert "`/home/user/shared`" in prompt
assert "Custom Mounted Directories" in prompt
@ -55,15 +52,15 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
config = SimpleNamespace(
sandbox=SimpleNamespace(mounts=[]),
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda *a, **k: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda app_config: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda app_config: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda app_config, agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
prompt = prompt_module.apply_prompt_template()
prompt = prompt_module.apply_prompt_template(config)
assert "Treat `/mnt/user-data/workspace` as your default current working directory" in prompt
assert "`hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`" in prompt
@ -84,7 +81,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
)
state = {"skills": [make_skill("first-skill")]}
monkeypatch.setattr(prompt_module, "load_skills", lambda enabled_only=True: list(state["skills"]))
monkeypatch.setattr(prompt_module, "load_skills", lambda *a, **kwargs: list(state["skills"]))
prompt_module._reset_skills_system_prompt_cache_state()
try:
@ -120,7 +117,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
enabled=True,
)
def fake_load_skills(enabled_only=True):
def fake_load_skills(*a, **kwargs):
nonlocal active_loads, max_active_loads, call_count
with lock:
active_loads += 1
@ -157,7 +154,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
def test_warm_enabled_skills_cache_logs_on_timeout(monkeypatch, caplog):
event = threading.Event()
monkeypatch.setattr(prompt_module, "_ensure_enabled_skills_cache", lambda: event)
monkeypatch.setattr(prompt_module, "_ensure_enabled_skills_cache", lambda *a, **k: event)
with caplog.at_level("WARNING"):
warmed = prompt_module.warm_enabled_skills_cache(timeout_seconds=0.01)

View File

@ -20,27 +20,40 @@ def _make_skill(name: str) -> Skill:
)
_DEFAULT_SKILLS_CONFIG = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
def _evolution_enabled_config() -> SimpleNamespace:
return SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
)
def test_get_skills_prompt_section_returns_empty_when_no_skills_match(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
result = get_skills_prompt_section(available_skills={"non_existent_skill"})
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills={"non_existent_skill"})
assert result == ""
def test_get_skills_prompt_section_returns_empty_when_available_skills_empty(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
result = get_skills_prompt_section(available_skills=set())
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills=set())
assert result == ""
def test_get_skills_prompt_section_returns_skills(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
result = get_skills_prompt_section(available_skills={"skill1"})
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills={"skill1"})
assert "skill1" in result
assert "skill2" not in result
assert "[built-in]" in result
@ -48,56 +61,41 @@ def test_get_skills_prompt_section_returns_skills(monkeypatch):
def test_get_skills_prompt_section_returns_all_when_available_skills_is_none(monkeypatch):
skills = [_make_skill("skill1"), _make_skill("skill2")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
result = get_skills_prompt_section(available_skills=None)
result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills=None)
assert "skill1" in result
assert "skill2" in result
def test_get_skills_prompt_section_includes_self_evolution_rules(monkeypatch):
skills = [_make_skill("skill1")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
monkeypatch.setattr(
AppConfig, "current",
staticmethod(lambda: SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
)),
)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
result = get_skills_prompt_section(available_skills=None)
result = get_skills_prompt_section(_evolution_enabled_config(), available_skills=None)
assert "Skill Self-Evolution" in result
def test_get_skills_prompt_section_includes_self_evolution_rules_without_skills(monkeypatch):
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [])
monkeypatch.setattr(
AppConfig, "current",
staticmethod(lambda: SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
)),
)
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: [])
result = get_skills_prompt_section(available_skills=None)
result = get_skills_prompt_section(_evolution_enabled_config(), available_skills=None)
assert "Skill Self-Evolution" in result
def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeypatch):
skills = [_make_skill("skill1")]
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=True),
)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
config = _evolution_enabled_config()
enabled_result = get_skills_prompt_section(available_skills=None)
enabled_result = get_skills_prompt_section(config, available_skills=None)
assert "Skill Self-Evolution" in enabled_result
config.skill_evolution.enabled = False
disabled_result = get_skills_prompt_section(available_skills=None)
disabled_config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
disabled_result = get_skills_prompt_section(disabled_config, available_skills=None)
assert "Skill Self-Evolution" not in disabled_result
@ -123,7 +121,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
captured_skills = []
def mock_apply_prompt_template(**kwargs):
def mock_apply_prompt_template(_app_config, *args, **kwargs):
captured_skills.append(kwargs.get("available_skills"))
return "mock_prompt"
@ -131,15 +129,15 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
# Case 1: Empty skills list
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=[]))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
assert captured_skills[-1] == set()
# Case 2: None skills list
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
assert captured_skills[-1] is None
# Case 3: Some skills list
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
assert captured_skills[-1] == {"skill1"}

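Note why the toggle test above builds a fresh `disabled_config` instead of flipping `config.skill_evolution.enabled` in place: a prompt cache keyed on config-derived fields would otherwise serve the stale enabled-state entry. A sketch of that keying idea (field names are illustrative, not the real `get_skills_prompt_section` internals):

```python
_cache: dict[tuple, str] = {}

def skills_prompt_section(config, skill_names: tuple[str, ...]) -> str:
    # Key the cache on the config fields that affect the output, so passing a
    # different config (e.g. skill_evolution toggled off) misses the cache
    # instead of returning a stale section.
    key = (skill_names, config.skill_evolution_enabled)
    if key not in _cache:
        body = "\n".join(f"- {name}" for name in skill_names)
        if config.skill_evolution_enabled:
            body += "\n\n## Skill Self-Evolution"
        _cache[key] = body
    return _cache[key]
```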
View File

@ -28,7 +28,7 @@ def test_get_available_tools_hides_bash_for_default_local_sandbox(monkeypatch):
lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
)
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=AppConfig.current())]
assert "bash" not in names
assert "ls" in names
@ -41,7 +41,7 @@ def test_get_available_tools_keeps_bash_when_explicitly_enabled(monkeypatch):
lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
)
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=AppConfig.current())]
assert "bash" in names
assert "ls" in names
@ -58,7 +58,7 @@ def test_get_available_tools_hides_renamed_host_bash_alias(monkeypatch):
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
)
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=AppConfig.current())]
assert "bash" not in names
assert "shell" not in names
@ -76,7 +76,7 @@ def test_get_available_tools_keeps_bash_for_aio_sandbox(monkeypatch):
lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
)
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=AppConfig.current())]
assert "bash" in names
assert "ls" in names

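The four tests above pin down the bash-visibility rule: hidden for the default local sandbox (including renamed host aliases), kept when explicitly enabled or when the aio sandbox is in use. A pure-function sketch of that filter, with assumed parameter names (the real logic lives inside `get_available_tools`):

```python
def visible_tool_names(tool_names: list[str], sandbox_use: str,
                       bash_explicitly_enabled: bool = False) -> list[str]:
    """Hide host-shell tools when the default local sandbox is active.

    The `sandbox_use` values and the exact hiding rule are inferred from the
    tests above, not copied from get_available_tools.
    """
    hide_bash = sandbox_use == "local" and not bash_explicitly_enabled
    if not hide_bash:
        return list(tool_names)
    return [n for n in tool_names if n not in ("bash", "shell")]
```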
View File

@ -314,7 +314,7 @@ class TestLocalSandboxProviderMounts:
)
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
provider = LocalSandboxProvider(app_config=config)
assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]
@ -336,7 +336,7 @@ class TestLocalSandboxProviderMounts:
)
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
provider = LocalSandboxProvider(app_config=config)
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
@ -360,7 +360,7 @@ class TestLocalSandboxProviderMounts:
)
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
provider = LocalSandboxProvider(app_config=config)
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
@ -384,6 +384,6 @@ class TestLocalSandboxProviderMounts:
)
with patch.object(AppConfig, "current", return_value=config):
provider = LocalSandboxProvider()
provider = LocalSandboxProvider(app_config=config)
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]

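The provider diffs above replace `patch.object(AppConfig, "current", ...)` with plain constructor injection. A minimal sketch of why that simplifies the tests: the config becomes an ordinary argument, so a test builds an object instead of patching a classmethod (types and field names here are illustrative):

```python
from dataclasses import dataclass, field

@dataclass
class Mount:
    container_path: str
    read_only: bool = False

@dataclass
class SandboxSettings:
    mounts: list[Mount] = field(default_factory=list)
    skills_path: str = "/mnt/skills"

class LocalProvider:
    """Constructor injection: config is a parameter, not an ambient global."""

    def __init__(self, app_config: SandboxSettings):
        # Skills mount first, then any custom mounts from the config.
        self._path_mappings = [Mount(app_config.skills_path)] + list(app_config.mounts)
```

Usage mirrors the last test above: `LocalProvider(SandboxSettings(mounts=[Mount("/mnt/data")]))` yields mappings for `/mnt/skills` and `/mnt/data`.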
View File

@ -6,12 +6,24 @@ from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
def _make_config(**memory_overrides) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(**memory_overrides))
def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
with (
patch.object(AppConfig, "current", return_value=_make_config(enabled=True)),
@ -26,7 +38,7 @@ def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
def test_process_queue_forwards_correction_flag_to_updater() -> None:
queue = MemoryUpdateQueue()
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
queue._queue = [
ConversationContext(
thread_id="thread-1",
@ -52,7 +64,7 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
with (
patch.object(AppConfig, "current", return_value=_make_config(enabled=True)),
@ -67,7 +79,7 @@ def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> No
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
queue = MemoryUpdateQueue()
queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
queue._queue = [
ConversationContext(
thread_id="thread-1",

View File

@ -1,3 +1,15 @@
"""Tests for user_id propagation through memory queue."""

# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites. Kept below the module
# docstring so the docstring remains the file's first statement.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
from unittest.mock import MagicMock, patch
@ -27,7 +39,7 @@ def test_conversation_context_user_id_default_none():
def test_queue_add_stores_user_id():
q = MemoryUpdateQueue()
q = MemoryUpdateQueue(_TEST_APP_CONFIG)
with patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
assert len(q._queue) == 1
@ -36,7 +48,7 @@ def test_queue_add_stores_user_id():
def test_queue_process_passes_user_id_to_updater():
q = MemoryUpdateQueue()
q = MemoryUpdateQueue(_TEST_APP_CONFIG)
with patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")

View File

@ -4,6 +4,18 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.gateway.routers import memory
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
_TEST_APP_CONFIG = AppConfig(sandbox=SandboxConfig(use="test"))
def _make_app() -> FastAPI:
"""Build a memory-router app pre-populated with a minimal AppConfig."""
app = FastAPI()
app.state.config = _TEST_APP_CONFIG
app.include_router(memory.router)
return app
def _sample_memory(facts: list[dict] | None = None) -> dict:
@ -25,8 +37,7 @@ def _sample_memory(facts: list[dict] | None = None) -> dict:
def test_export_memory_route_returns_current_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
exported_memory = _sample_memory(
facts=[
{
@ -49,8 +60,7 @@ def test_export_memory_route_returns_current_memory() -> None:
def test_import_memory_route_returns_imported_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
imported_memory = _sample_memory(
facts=[
{
@ -73,8 +83,7 @@ def test_import_memory_route_returns_imported_memory() -> None:
def test_export_memory_route_preserves_source_error() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
exported_memory = _sample_memory(
facts=[
{
@ -98,8 +107,7 @@ def test_export_memory_route_preserves_source_error() -> None:
def test_import_memory_route_preserves_source_error() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
imported_memory = _sample_memory(
facts=[
{
@ -123,8 +131,7 @@ def test_import_memory_route_preserves_source_error() -> None:
def test_clear_memory_route_returns_cleared_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
with patch("app.gateway.routers.memory.clear_memory_data", return_value=_sample_memory()):
with TestClient(app) as client:
@ -135,8 +142,7 @@ def test_clear_memory_route_returns_cleared_memory() -> None:
def test_create_memory_fact_route_returns_updated_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
updated_memory = _sample_memory(
facts=[
{
@ -166,8 +172,7 @@ def test_create_memory_fact_route_returns_updated_memory() -> None:
def test_delete_memory_fact_route_returns_updated_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
updated_memory = _sample_memory(
facts=[
{
@ -190,8 +195,7 @@ def test_delete_memory_fact_route_returns_updated_memory() -> None:
def test_delete_memory_fact_route_returns_404_for_missing_fact() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
with patch("app.gateway.routers.memory.delete_memory_fact", side_effect=KeyError("fact_missing")):
with TestClient(app) as client:
@ -202,8 +206,7 @@ def test_delete_memory_fact_route_returns_404_for_missing_fact() -> None:
def test_update_memory_fact_route_returns_updated_memory() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
updated_memory = _sample_memory(
facts=[
{
@ -233,8 +236,7 @@ def test_update_memory_fact_route_returns_updated_memory() -> None:
def test_update_memory_fact_route_preserves_omitted_fields() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
updated_memory = _sample_memory(
facts=[
{
@ -269,8 +271,7 @@ def test_update_memory_fact_route_preserves_omitted_fields() -> None:
def test_update_memory_fact_route_returns_404_for_missing_fact() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
with patch("app.gateway.routers.memory.update_memory_fact", side_effect=KeyError("fact_missing")):
with TestClient(app) as client:
@ -288,8 +289,7 @@ def test_update_memory_fact_route_returns_404_for_missing_fact() -> None:
def test_update_memory_fact_route_returns_specific_error_for_invalid_confidence() -> None:
app = FastAPI()
app.include_router(memory.router)
app = _make_app()
with patch("app.gateway.routers.memory.update_memory_fact", side_effect=ValueError("confidence")):
with TestClient(app) as client:

View File

@ -1,3 +1,15 @@
"""Tests for memory storage providers."""

# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites. Kept below the module
# docstring so the docstring remains the file's first statement.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
import threading
@ -60,7 +72,7 @@ class TestFileMemoryStorage:
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
path = storage._get_memory_file_path(None)
assert path == tmp_path / "memory.json"
@ -73,14 +85,14 @@ class TestFileMemoryStorage:
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
storage = FileMemoryStorage()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
path = storage._get_memory_file_path("test-agent")
assert path == tmp_path / "agents" / "test-agent" / "memory.json"
@pytest.mark.parametrize("invalid_name", ["", "../etc/passwd", "agent/name", "agent\\name", "agent name", "agent@123", "agent_name"])
def test_validate_agent_name_invalid(self, invalid_name):
"""Should raise ValueError for invalid agent names that don't match the pattern."""
storage = FileMemoryStorage()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
with pytest.raises(ValueError, match="Invalid agent name|Agent name must be a non-empty string"):
storage._validate_agent_name(invalid_name)
@ -94,7 +106,7 @@ class TestFileMemoryStorage:
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = storage.load()
assert isinstance(memory, dict)
assert memory["version"] == "1.0"
@ -110,7 +122,7 @@ class TestFileMemoryStorage:
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
test_memory = {"version": "1.0", "facts": [{"content": "test fact"}]}
result = storage.save(test_memory)
assert result is True
@ -129,7 +141,7 @@ class TestFileMemoryStorage:
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
storage = FileMemoryStorage()
storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
# First load
memory1 = storage.load()
assert memory1["facts"][0]["content"] == "initial fact"
@ -157,20 +169,20 @@ class TestGetMemoryStorage:
def test_returns_file_memory_storage_by_default(self):
"""Should return FileMemoryStorage by default."""
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
storage = get_memory_storage()
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
assert isinstance(storage, FileMemoryStorage)
def test_falls_back_to_file_memory_storage_on_error(self):
"""Should fall back to FileMemoryStorage if configured storage fails to load."""
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="non.existent.StorageClass")):
storage = get_memory_storage()
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
assert isinstance(storage, FileMemoryStorage)
def test_returns_singleton_instance(self):
"""Should return the same instance on subsequent calls."""
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
storage1 = get_memory_storage()
storage2 = get_memory_storage()
storage1 = get_memory_storage(_TEST_MEMORY_CONFIG)
storage2 = get_memory_storage(_TEST_MEMORY_CONFIG)
assert storage1 is storage2
def test_get_memory_storage_thread_safety(self):
@ -181,7 +193,7 @@ class TestGetMemoryStorage:
# get_memory_storage is called concurrently from multiple threads while
# AppConfig.get is patched once around thread creation. This verifies
# AppConfig.current is patched once around thread creation. This verifies
# that the singleton initialization remains thread-safe.
results.append(get_memory_storage())
results.append(get_memory_storage(_TEST_MEMORY_CONFIG))
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
threads = [threading.Thread(target=get_storage) for _ in range(10)]
@ -198,12 +210,12 @@ class TestGetMemoryStorage:
"""Should fall back to FileMemoryStorage if the configured class is not actually a class."""
# Using a built-in function instead of a class
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="os.path.join")):
storage = get_memory_storage()
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
assert isinstance(storage, FileMemoryStorage)
def test_get_memory_storage_non_subclass_fallback(self):
"""Should fall back to FileMemoryStorage if the configured class is not a subclass of MemoryStorage."""
# Using 'dict' as a class that is not a MemoryStorage subclass
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="builtins.dict")):
storage = get_memory_storage()
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
assert isinstance(storage, FileMemoryStorage)

View File

@ -1,3 +1,15 @@
"""Tests for per-user memory storage isolation."""

# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites. Kept below the module
# docstring so the docstring remains the file's first statement.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
import pytest
from pathlib import Path
@ -21,7 +33,7 @@ def base_dir(tmp_path: Path) -> Path:
@pytest.fixture
def storage() -> FileMemoryStorage:
return FileMemoryStorage()
return FileMemoryStorage(_TEST_MEMORY_CONFIG)
@pytest.fixture(autouse=True)
@ -57,7 +69,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = create_empty_memory()
s.save(memory, user_id="alice")
expected_path = base_dir / "users" / "alice" / "memory.json"
@ -68,7 +80,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory_a = create_empty_memory()
memory_a["user"]["workContext"]["summary"] = "A"
s.save(memory_a, user_id="alice")
@ -85,7 +97,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = create_empty_memory()
s.save(memory, user_id=None)
expected_path = base_dir / "memory.json"
@ -97,7 +109,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
legacy_mem = create_empty_memory()
legacy_mem["user"]["workContext"]["summary"] = "legacy"
@ -116,7 +128,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = create_empty_memory()
memory["user"]["workContext"]["summary"] = "agent scoped"
s.save(memory, "test-agent", user_id="alice")
@ -129,7 +141,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = create_empty_memory()
s.save(memory, user_id="alice")
# After save, cache should have tuple key
@ -141,7 +153,7 @@ class TestUserIsolatedStorage:
paths = Paths(base_dir)
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
s = FileMemoryStorage()
s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
memory = create_empty_memory()
memory["user"]["workContext"]["summary"] = "initial"
s.save(memory, user_id="alice")
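The storage tests above pass `MemoryConfig` into `FileMemoryStorage` directly; the same commit captures `MemoryConfig` at enqueue time so the Timer closure survives the ContextVar boundary. A minimal standalone sketch of why that capture must happen at enqueue time (all names here are hypothetical, not the real deerflow API):

```python
import contextvars
import threading

app_config_var = contextvars.ContextVar("app_config")
applied = []

def apply_update(config, update):
    applied.append((config["max_facts"], update))

def enqueue_update(update):
    # Capture the config NOW, while the request's ContextVar is live.
    # threading.Timer fires its callback in a fresh thread whose context
    # is empty, so app_config_var.get() inside the callback would raise
    # LookupError there.
    config = app_config_var.get()
    timer = threading.Timer(0.0, apply_update, args=(config, update))
    timer.start()
    return timer
```

Reading the ContextVar eagerly and closing over the plain value is what lets the deferred work run without any ambient lookup.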

View File

@ -6,6 +6,17 @@ the in-memory LangGraph Store backend used when database.backend=memory.
from __future__ import annotations
# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
from types import SimpleNamespace
import pytest

View File

@ -15,6 +15,18 @@ from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig
# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, object]:
return {
"version": "1.0",
@ -38,7 +50,7 @@ def _memory_config(**overrides: object) -> AppConfig:
def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None:
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory(
facts=[
{
@ -65,18 +77,14 @@ def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None
{"content": "User likes Python", "category": "preference", "confidence": 0.95},
],
}
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
assert [fact["content"] for fact in result["facts"]] == ["User likes Python"]
assert all(fact["id"] != "fact_remove" for fact in result["facts"])
def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -> None:
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory()
update_data = {
"newFacts": [
@ -85,11 +93,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
{"content": "User works on DeerFlow", "category": "context", "confidence": 0.87},
],
}
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
assert [fact["content"] for fact in result["facts"]] == [
"User prefers dark mode",
@ -100,7 +104,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=2, fact_confidence_threshold=0.7))
current_memory = _make_memory(
facts=[
{
@ -128,11 +132,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
{"content": "User likes noisy logs", "category": "behavior", "confidence": 0.6},
],
}
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=2, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
assert [fact["content"] for fact in result["facts"]] == [
"User likes Python",
@ -143,7 +143,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
def test_apply_updates_preserves_source_error() -> None:
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory()
update_data = {
"newFacts": [
@ -155,18 +155,14 @@ def test_apply_updates_preserves_source_error() -> None:
}
]
}
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
assert result["facts"][0]["sourceError"] == "The agent previously suggested npm start."
assert result["facts"][0]["category"] == "correction"
def test_apply_updates_ignores_empty_source_error() -> None:
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory()
update_data = {
"newFacts": [
@ -178,18 +174,14 @@ def test_apply_updates_ignores_empty_source_error() -> None:
}
]
}
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
assert "sourceError" not in result["facts"][0]
def test_clear_memory_data_resets_all_sections() -> None:
with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True):
result = clear_memory_data()
result = clear_memory_data(_TEST_MEMORY_CONFIG)
assert result["version"] == "1.0"
assert result["facts"] == []
@ -223,7 +215,7 @@ def test_delete_memory_fact_removes_only_matching_fact() -> None:
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
):
result = delete_memory_fact("fact_delete")
result = delete_memory_fact(_TEST_MEMORY_CONFIG, "fact_delete")
assert [fact["id"] for fact in result["facts"]] == ["fact_keep"]
@ -233,7 +225,7 @@ def test_create_memory_fact_appends_manual_fact() -> None:
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
):
result = create_memory_fact(
result = create_memory_fact(_TEST_MEMORY_CONFIG,
content=" User prefers concise code reviews. ",
category="preference",
confidence=0.88,
@ -248,7 +240,7 @@ def test_create_memory_fact_appends_manual_fact() -> None:
def test_create_memory_fact_rejects_empty_content() -> None:
try:
create_memory_fact(content=" ")
create_memory_fact(_TEST_MEMORY_CONFIG, content=" ")
except ValueError as exc:
assert exc.args == ("content",)
else:
@ -258,7 +250,7 @@ def test_create_memory_fact_rejects_empty_content() -> None:
def test_create_memory_fact_rejects_invalid_confidence() -> None:
for confidence in (-0.1, 1.1, float("nan"), float("inf"), float("-inf")):
try:
create_memory_fact(content="User likes tests", confidence=confidence)
create_memory_fact(_TEST_MEMORY_CONFIG, content="User likes tests", confidence=confidence)
except ValueError as exc:
assert exc.args == ("confidence",)
else:
@ -268,7 +260,7 @@ def test_create_memory_fact_rejects_invalid_confidence() -> None:
def test_delete_memory_fact_raises_for_unknown_id() -> None:
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()):
try:
delete_memory_fact("fact_missing")
delete_memory_fact(_TEST_MEMORY_CONFIG, "fact_missing")
except KeyError as exc:
assert exc.args == ("fact_missing",)
else:
@ -293,7 +285,7 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
mock_storage.load.return_value = imported_memory
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
result = import_memory_data(imported_memory)
result = import_memory_data(_TEST_MEMORY_CONFIG, imported_memory)
mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
mock_storage.load.assert_called_once_with(None, user_id=None)
@ -326,7 +318,7 @@ def test_update_memory_fact_updates_only_matching_fact() -> None:
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
):
result = update_memory_fact(
result = update_memory_fact(_TEST_MEMORY_CONFIG,
fact_id="fact_edit",
content="User prefers spaces",
category="workflow",
@ -359,7 +351,7 @@ def test_update_memory_fact_preserves_omitted_fields() -> None:
patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
):
result = update_memory_fact(
result = update_memory_fact(_TEST_MEMORY_CONFIG,
fact_id="fact_edit",
content="User prefers spaces",
)
@ -372,7 +364,7 @@ def test_update_memory_fact_preserves_omitted_fields() -> None:
def test_update_memory_fact_raises_for_unknown_id() -> None:
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()):
try:
update_memory_fact(
update_memory_fact(_TEST_MEMORY_CONFIG,
fact_id="fact_missing",
content="User prefers concise code reviews.",
category="preference",
@ -404,7 +396,7 @@ def test_update_memory_fact_rejects_invalid_confidence() -> None:
return_value=current_memory,
):
try:
update_memory_fact(
update_memory_fact(_TEST_MEMORY_CONFIG,
fact_id="fact_edit",
content="User prefers spaces",
confidence=confidence,
@ -521,7 +513,7 @@ class TestUpdateMemoryStructuredResponse:
return model
def test_string_response_parses(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
with (
@ -543,7 +535,7 @@ class TestUpdateMemoryStructuredResponse:
def test_list_content_response_parses(self):
"""LLM response as list-of-blocks should be extracted, not repr'd."""
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
list_content = [{"type": "text", "text": valid_json}]
@ -565,7 +557,7 @@ class TestUpdateMemoryStructuredResponse:
assert result is True
def test_correction_hint_injected_when_detected(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
@ -590,7 +582,7 @@ class TestUpdateMemoryStructuredResponse:
assert "Explicit correction signals were detected" in prompt
def test_correction_hint_empty_when_not_detected(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
@ -619,7 +611,7 @@ class TestFactDeduplicationCaseInsensitive:
"""Tests that fact deduplication is case-insensitive."""
def test_duplicate_fact_different_case_not_stored(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory(
facts=[
{
@ -639,18 +631,14 @@ class TestFactDeduplicationCaseInsensitive:
{"content": "user prefers python", "category": "preference", "confidence": 0.95},
],
}
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
# Should still have only 1 fact (duplicate rejected)
assert len(result["facts"]) == 1
assert result["facts"][0]["content"] == "User prefers Python"
def test_unique_fact_different_case_and_content_stored(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
current_memory = _make_memory(
facts=[
{
@ -669,11 +657,7 @@ class TestFactDeduplicationCaseInsensitive:
{"content": "User prefers Go", "category": "preference", "confidence": 0.85},
],
}
with patch.object(AppConfig, "current",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
assert len(result["facts"]) == 2
@ -690,7 +674,7 @@ class TestReinforcementHint:
return model
def test_reinforcement_hint_injected_when_detected(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
@ -715,7 +699,7 @@ class TestReinforcementHint:
assert "Positive reinforcement signals were detected" in prompt
def test_reinforcement_hint_absent_when_not_detected(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
@ -740,7 +724,7 @@ class TestReinforcementHint:
assert "Positive reinforcement signals were detected" not in prompt
def test_both_hints_present_when_both_detected(self):
updater = MemoryUpdater()
updater = MemoryUpdater(_TEST_APP_CONFIG)
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
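The updater migrations above replace every `patch.object(AppConfig, "current", ...)` wrapper with a config passed to the constructor. A simplified stand-in (not the real `MemoryUpdater`; field names and trimming logic are illustrative only) shows why the `with patch.object(...)` blocks become unnecessary:

```python
from dataclasses import dataclass

@dataclass
class FakeMemoryConfig:
    max_facts: int = 100
    fact_confidence_threshold: float = 0.7

class FakeMemoryUpdater:
    def __init__(self, config: FakeMemoryConfig) -> None:
        # Thresholds come from the injected config; there is no ambient
        # lookup left to patch around in tests.
        self.config = config

    def apply_updates(self, facts, new_facts):
        kept = [
            f for f in new_facts
            if f["confidence"] >= self.config.fact_confidence_threshold
        ]
        merged = facts + kept
        # Keep only the most recent max_facts entries, mirroring the
        # trimming behaviour the tests above assert on.
        return merged[-self.config.max_facts:]
```

Each test can now build exactly the thresholds it needs at construction time, which is the shape `MemoryUpdater(_memory_config(max_facts=2, fact_confidence_threshold=0.7))` takes in the diffs above.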

View File

@ -1,3 +1,15 @@
# --- Phase 2 config-refactor test helper ---
# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
# minimal config once and reuse it across call sites.
from deerflow.config.app_config import AppConfig as _TestAppConfig
from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
# -------------------------------------------
"""Tests for user_id propagation in memory updater."""
from unittest.mock import MagicMock, patch
@ -8,7 +20,7 @@ def test_get_memory_data_passes_user_id():
mock_storage = MagicMock()
mock_storage.load.return_value = {"version": "1.0"}
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
get_memory_data(user_id="alice")
get_memory_data(_TEST_MEMORY_CONFIG, user_id="alice")
mock_storage.load.assert_called_once_with(None, user_id="alice")
@ -16,7 +28,7 @@ def test_save_memory_passes_user_id():
mock_storage = MagicMock()
mock_storage.save.return_value = True
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
_save_memory_to_file({"version": "1.0"}, user_id="bob")
_save_memory_to_file(_TEST_MEMORY_CONFIG, {"version": "1.0"}, user_id="bob")
mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
@ -24,6 +36,6 @@ def test_clear_memory_data_passes_user_id():
mock_storage = MagicMock()
mock_storage.save.return_value = True
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
clear_memory_data(user_id="charlie")
clear_memory_data(_TEST_MEMORY_CONFIG, user_id="charlie")
# Verify save was called with user_id
assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"

View File

@ -88,7 +88,7 @@ def test_uses_first_model_when_name_is_none(monkeypatch):
_patch_factory(monkeypatch, cfg)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name=None)
factory_module.create_chat_model(name=None, app_config=AppConfig.current())
# resolve_class is called — if we reach here without ValueError, the correct model was used
assert FakeChatModel.captured_kwargs.get("model") == "alpha"
@ -100,7 +100,7 @@ def test_raises_when_model_not_found(monkeypatch):
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
with pytest.raises(ValueError, match="ghost-model"):
factory_module.create_chat_model(name="ghost-model")
factory_module.create_chat_model(name="ghost-model", app_config=AppConfig.current())
def test_appends_all_tracing_callbacks(monkeypatch):
@ -109,7 +109,7 @@ def test_appends_all_tracing_callbacks(monkeypatch):
monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: ["smith-callback", "langfuse-callback"])
FakeChatModel.captured_kwargs = {}
model = factory_module.create_chat_model(name="alpha")
model = factory_module.create_chat_model(name="alpha", app_config=AppConfig.current())
assert model.callbacks == ["smith-callback", "langfuse-callback"]
@ -127,7 +127,7 @@ def test_thinking_enabled_raises_when_not_supported_but_when_thinking_enabled_is
_patch_factory(monkeypatch, cfg)
with pytest.raises(ValueError, match="does not support thinking"):
factory_module.create_chat_model(name="no-think", thinking_enabled=True)
factory_module.create_chat_model(name="no-think", thinking_enabled=True, app_config=AppConfig.current())
def test_thinking_enabled_raises_for_empty_when_thinking_enabled_explicitly_set(monkeypatch):
@ -138,7 +138,7 @@ def test_thinking_enabled_raises_for_empty_when_thinking_enabled_explicitly_set(
_patch_factory(monkeypatch, cfg)
with pytest.raises(ValueError, match="does not support thinking"):
factory_module.create_chat_model(name="no-think-empty", thinking_enabled=True)
factory_module.create_chat_model(name="no-think-empty", thinking_enabled=True, app_config=AppConfig.current())
def test_thinking_enabled_merges_when_thinking_enabled_settings(monkeypatch):
@ -147,7 +147,7 @@ def test_thinking_enabled_merges_when_thinking_enabled_settings(monkeypatch):
_patch_factory(monkeypatch, cfg)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name="thinker", thinking_enabled=True)
factory_module.create_chat_model(name="thinker", thinking_enabled=True, app_config=AppConfig.current())
assert FakeChatModel.captured_kwargs.get("temperature") == 1.0
assert FakeChatModel.captured_kwargs.get("max_tokens") == 16000
@ -183,7 +183,7 @@ def test_thinking_disabled_openai_gateway_format(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="openai-gw", thinking_enabled=False)
factory_module.create_chat_model(name="openai-gw", thinking_enabled=False, app_config=AppConfig.current())
assert captured.get("extra_body") == {"thinking": {"type": "disabled"}}
assert captured.get("reasoning_effort") == "minimal"
@ -216,7 +216,7 @@ def test_thinking_disabled_langchain_anthropic_format(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="anthropic-native", thinking_enabled=False)
factory_module.create_chat_model(name="anthropic-native", thinking_enabled=False, app_config=AppConfig.current())
assert captured.get("thinking") == {"type": "disabled"}
assert "extra_body" not in captured
@ -238,7 +238,7 @@ def test_thinking_disabled_no_when_thinking_enabled_does_nothing(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="plain", thinking_enabled=False)
factory_module.create_chat_model(name="plain", thinking_enabled=False, app_config=AppConfig.current())
assert "extra_body" not in captured
assert "thinking" not in captured
@ -278,7 +278,7 @@ def test_when_thinking_disabled_takes_precedence_over_hardcoded_disable(monkeypa
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="custom-disable", thinking_enabled=False)
factory_module.create_chat_model(name="custom-disable", thinking_enabled=False, app_config=AppConfig.current())
assert captured.get("extra_body") == {"thinking": {"type": "disabled"}}
# User overrode the hardcoded "minimal" with "low"
@ -310,7 +310,7 @@ def test_when_thinking_disabled_not_used_when_thinking_enabled(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="wtd-ignored", thinking_enabled=True)
factory_module.create_chat_model(name="wtd-ignored", thinking_enabled=True, app_config=AppConfig.current())
# when_thinking_enabled should apply, NOT when_thinking_disabled
assert captured.get("extra_body") == {"thinking": {"type": "enabled"}}
@ -339,7 +339,7 @@ def test_when_thinking_disabled_without_when_thinking_enabled_still_applies(monk
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="wtd-only", thinking_enabled=False)
factory_module.create_chat_model(name="wtd-only", thinking_enabled=False, app_config=AppConfig.current())
# when_thinking_disabled is now gated independently of has_thinking_settings
assert captured.get("reasoning_effort") == "low"
@ -370,7 +370,7 @@ def test_when_thinking_disabled_excluded_from_model_dump(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="no-leak-wtd", thinking_enabled=True)
factory_module.create_chat_model(name="no-leak-wtd", thinking_enabled=True, app_config=AppConfig.current())
# when_thinking_disabled value must NOT appear as a raw key
assert "when_thinking_disabled" not in captured
@ -394,7 +394,7 @@ def test_reasoning_effort_cleared_when_not_supported(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="no-effort", thinking_enabled=False)
factory_module.create_chat_model(name="no-effort", thinking_enabled=False, app_config=AppConfig.current())
assert captured.get("reasoning_effort") is None
@ -422,7 +422,7 @@ def test_reasoning_effort_preserved_when_supported(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="effort-model", thinking_enabled=False)
factory_module.create_chat_model(name="effort-model", thinking_enabled=False, app_config=AppConfig.current())
# When supports_reasoning_effort=True, it should NOT be cleared to None
# The disable path sets it to "minimal"; supports_reasoning_effort=True keeps it
@ -458,7 +458,7 @@ def test_thinking_shortcut_enables_thinking_when_thinking_enabled(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="shortcut-model", thinking_enabled=True)
factory_module.create_chat_model(name="shortcut-model", thinking_enabled=True, app_config=AppConfig.current())
assert captured.get("thinking") == thinking_settings
@ -488,7 +488,7 @@ def test_thinking_shortcut_disables_thinking_when_thinking_disabled(monkeypatch)
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="shortcut-disable", thinking_enabled=False)
factory_module.create_chat_model(name="shortcut-disable", thinking_enabled=False, app_config=AppConfig.current())
assert captured.get("thinking") == {"type": "disabled"}
assert "extra_body" not in captured
@ -520,7 +520,7 @@ def test_thinking_shortcut_merges_with_when_thinking_enabled(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="merge-model", thinking_enabled=True)
factory_module.create_chat_model(name="merge-model", thinking_enabled=True, app_config=AppConfig.current())
# Both the thinking shortcut and when_thinking_enabled settings should be applied
assert captured.get("thinking") == thinking_settings
@ -552,7 +552,7 @@ def test_thinking_shortcut_not_leaked_into_model_when_disabled(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="no-leak", thinking_enabled=False)
factory_module.create_chat_model(name="no-leak", thinking_enabled=False, app_config=AppConfig.current())
# The disable path should have set thinking to disabled (not the raw enabled shortcut)
assert captured.get("thinking") == {"type": "disabled"}
@ -590,7 +590,7 @@ def test_openai_compatible_provider_passes_base_url(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="minimax-m2.5")
factory_module.create_chat_model(name="minimax-m2.5", app_config=AppConfig.current())
assert captured.get("model") == "MiniMax-M2.5"
assert captured.get("base_url") == "https://api.minimax.io/v1"
@ -638,11 +638,11 @@ def test_openai_compatible_provider_multiple_models(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
# Create first model
factory_module.create_chat_model(name="minimax-m2.5")
factory_module.create_chat_model(name="minimax-m2.5", app_config=AppConfig.current())
assert captured.get("model") == "MiniMax-M2.5"
# Create second model
factory_module.create_chat_model(name="minimax-m2.5-highspeed")
factory_module.create_chat_model(name="minimax-m2.5-highspeed", app_config=AppConfig.current())
assert captured.get("model") == "MiniMax-M2.5-highspeed"
@ -670,7 +670,7 @@ def test_codex_provider_disables_reasoning_when_thinking_disabled(monkeypatch):
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name="codex", thinking_enabled=False)
factory_module.create_chat_model(name="codex", thinking_enabled=False, app_config=AppConfig.current())
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "none"
@ -690,7 +690,7 @@ def test_codex_provider_preserves_explicit_reasoning_effort(monkeypatch):
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name="codex", thinking_enabled=True, reasoning_effort="high")
factory_module.create_chat_model(name="codex", thinking_enabled=True, reasoning_effort="high", app_config=AppConfig.current())
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "high"
@ -710,7 +710,7 @@ def test_codex_provider_defaults_reasoning_effort_to_medium(monkeypatch):
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name="codex", thinking_enabled=True)
factory_module.create_chat_model(name="codex", thinking_enabled=True, app_config=AppConfig.current())
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "medium"
@ -731,7 +731,7 @@ def test_codex_provider_strips_unsupported_max_tokens(monkeypatch):
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
FakeChatModel.captured_kwargs = {}
factory_module.create_chat_model(name="codex", thinking_enabled=True)
factory_module.create_chat_model(name="codex", thinking_enabled=True, app_config=AppConfig.current())
assert "max_tokens" not in FakeChatModel.captured_kwargs
@ -757,7 +757,7 @@ def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False)
factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False, app_config=AppConfig.current())
assert captured.get("extra_body") == {"top_k": 20, "chat_template_kwargs": {"thinking": False}}
assert captured.get("reasoning_effort") is None
@ -784,7 +784,7 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False)
factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False, app_config=AppConfig.current())
assert captured.get("extra_body") == {
"top_k": 20,
@ -818,7 +818,7 @@ def test_stream_usage_injected_for_openai_compatible_model(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="deepseek")
factory_module.create_chat_model(name="deepseek", app_config=AppConfig.current())
assert captured.get("stream_usage") is True
@ -837,7 +837,7 @@ def test_stream_usage_not_injected_for_non_openai_model(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="claude")
factory_module.create_chat_model(name="claude", app_config=AppConfig.current())
assert "stream_usage" not in captured
@ -867,7 +867,7 @@ def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="deepseek")
factory_module.create_chat_model(name="deepseek", app_config=AppConfig.current())
assert captured.get("stream_usage") is False
@ -897,7 +897,7 @@ def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="gpt-5-responses")
factory_module.create_chat_model(name="gpt-5-responses", app_config=AppConfig.current())
assert captured.get("use_responses_api") is True
assert captured.get("output_version") == "responses/v1"
@ -938,7 +938,7 @@ def test_no_duplicate_kwarg_when_reasoning_effort_in_config_and_thinking_disable
_patch_factory(monkeypatch, cfg, model_class=CapturingModel)
# Must not raise TypeError
factory_module.create_chat_model(name="doubao-model", thinking_enabled=False)
factory_module.create_chat_model(name="doubao-model", thinking_enabled=False, app_config=AppConfig.current())
# Runtime kwargs take precedence: the thinking-disabled path sets reasoning_effort=minimal
assert captured.get("reasoning_effort") == "minimal"
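These factory tests still route through a patched `AppConfig.current()`, which works because the slot is kept as a deprecated stub. A self-contained sketch of that design (this is a toy class, not the real `AppConfig`): the attribute survives so `patch.object` in legacy tests does not hit AttributeError at teardown, while any unpatched call fails loudly.

```python
from unittest.mock import patch

class AppConfig:
    @staticmethod
    def current():
        raise RuntimeError(
            "AppConfig.current() is deprecated; pass app_config explicitly"
        )

# Legacy tests can still patch the slot without AttributeError...
with patch.object(AppConfig, "current", return_value="patched-config"):
    assert AppConfig.current() == "patched-config"

# ...but an unpatched call never silently returns an ambient config.
try:
    AppConfig.current()
    failed_loudly = False
except RuntimeError:
    failed_loudly = True
assert failed_loudly
```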

View File
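The factory hunks above replace the ambient `AppConfig.current()` lookup with an explicit `app_config` keyword argument. A minimal sketch of the pattern under test — the `AppConfig` and `create_chat_model` shapes here are simplified stand-ins, not the real signatures:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AppConfig:
    # Simplified stand-in: the real AppConfig carries model/sandbox/skills config.
    stream_usage: bool = True


def create_chat_model(*, name: str, app_config: AppConfig) -> dict:
    # Config arrives as a parameter, so there is no hidden global to patch:
    # tests construct an AppConfig and pass it straight in.
    kwargs = {"model": name}
    if app_config.stream_usage:
        kwargs["stream_usage"] = True
    return kwargs


kwargs = create_chat_model(name="deepseek", app_config=AppConfig())
```

With the parameter threaded explicitly, the tests no longer need `patch.object(AppConfig, "current", ...)` — they just build a config value and pass it in.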

@@ -15,6 +15,10 @@ def _make_runtime(tmp_path):
workspace.mkdir()
uploads.mkdir()
outputs.mkdir()
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
return SimpleNamespace(
state={
"sandbox": {"sandbox_id": "local"},
@@ -24,7 +28,10 @@ def _make_runtime(tmp_path):
"outputs_path": str(outputs),
},
},
context={"thread_id": "thread-1"},
context=DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
thread_id="thread-1",
),
)

View File

@@ -35,6 +35,53 @@ _THREAD_DATA = {
}
def _make_app_config(
*,
skills_container_path: str = "/mnt/skills",
skills_host_path: str | None = None,
mounts=None,
mcp_servers=None,
tool_config_map=None,
) -> SimpleNamespace:
"""Build a lightweight AppConfig stand-in used by tests.
Only the attributes accessed by the helpers under test are populated;
everything else is omitted to keep the fake minimal and explicit.
"""
skills_path = Path(skills_host_path) if skills_host_path is not None else None
skills_cfg = SimpleNamespace(
container_path=skills_container_path,
get_skills_path=lambda: skills_path if skills_path is not None else Path("/nonexistent-skills-root-12345"),
)
sandbox_cfg = SimpleNamespace(mounts=list(mounts) if mounts else [], bash_output_max_chars=20000)
extensions_cfg = SimpleNamespace(mcp_servers=dict(mcp_servers) if mcp_servers else {})
tool_config_map = dict(tool_config_map or {})
return SimpleNamespace(
skills=skills_cfg,
sandbox=sandbox_cfg,
extensions=extensions_cfg,
get_tool_config=lambda name: tool_config_map.get(name),
)
_DEFAULT_APP_CONFIG = _make_app_config()
def _make_ctx(thread_id: str = "thread-1", *, app_config=_DEFAULT_APP_CONFIG, sandbox_key: str | None = None):
"""Build a DeerFlowContext-like object with extra attributes allowed.
``resolve_context`` only checks ``isinstance(ctx, DeerFlowContext)``; tests
that need extra attributes (``sandbox_key``) get them attached via
``object.__setattr__``, which bypasses any frozen-model guard on the context.
"""
from deerflow.config.deer_flow_context import DeerFlowContext as _DFC
ctx = _DFC(app_config=app_config, thread_id=thread_id)
if sandbox_key is not None:
object.__setattr__(ctx, "sandbox_key", sandbox_key)
return ctx
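`object.__setattr__` is the usual escape hatch for attaching an extra attribute when the context type may be a frozen dataclass, which blocks ordinary assignment. A self-contained illustration — the `Ctx` class is hypothetical:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Ctx:
    thread_id: str


ctx = Ctx(thread_id="t1")
try:
    ctx.sandbox_key = "sb-1"  # frozen dataclass rejects normal assignment
    raised = False
except AttributeError:  # dataclasses.FrozenInstanceError subclasses AttributeError
    raised = True
# object.__setattr__ bypasses the guard that @dataclass(frozen=True) installs
object.__setattr__(ctx, "sandbox_key", "sb-1")
```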
# ---------- replace_virtual_path ----------
@@ -86,7 +133,7 @@ def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing
def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
"""Trailing slash on a virtual path inside a command must be preserved."""
cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"
@@ -95,7 +142,7 @@ def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
def test_mask_local_paths_in_output_hides_host_paths() -> None:
output = "Created: /tmp/deer-flow/threads/t1/user-data/workspace/result.txt"
masked = mask_local_paths_in_output(output, _THREAD_DATA)
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/tmp/deer-flow/threads/t1/user-data" not in masked
assert "/mnt/user-data/workspace/result.txt" in masked
@@ -108,7 +155,7 @@ def test_mask_local_paths_in_output_hides_skills_host_paths() -> None:
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
):
output = "Reading: /home/user/deer-flow/skills/public/bootstrap/SKILL.md"
masked = mask_local_paths_in_output(output, _THREAD_DATA)
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/home/user/deer-flow/skills" not in masked
assert "/mnt/skills/public/bootstrap/SKILL.md" in masked
@@ -144,12 +191,12 @@ def test_reject_path_traversal_allows_normal_paths() -> None:
def test_validate_local_tool_path_rejects_non_virtual_path() -> None:
with pytest.raises(PermissionError, match="Only paths under"):
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_rejects_non_virtual_path_mentions_configured_mounts() -> None:
with pytest.raises(PermissionError, match="configured mount paths"):
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -> None:
@@ -159,42 +206,41 @@ def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -
VolumeMountConfig(host_path="/tmp/host-user-data", container_path=VIRTUAL_PATH_PREFIX, read_only=False),
]
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=True)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, read_only=True)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
def test_validate_local_tool_path_rejects_bare_virtual_root() -> None:
"""The bare /mnt/user-data root without trailing slash is not a valid sub-path."""
with pytest.raises(PermissionError, match="Only paths under"):
validate_local_tool_path(VIRTUAL_PATH_PREFIX, _THREAD_DATA)
validate_local_tool_path(VIRTUAL_PATH_PREFIX, _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_allows_user_data_paths() -> None:
# Should not raise — user-data paths are always allowed
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/uploads/doc.pdf", _THREAD_DATA)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/outputs/result.csv", _THREAD_DATA)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/uploads/doc.pdf", _THREAD_DATA, _DEFAULT_APP_CONFIG)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/outputs/result.csv", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_allows_user_data_write() -> None:
# read_only=False (default) should still work for user-data paths
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=False)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
def test_validate_local_tool_path_rejects_traversal_in_user_data() -> None:
"""Path traversal via .. in user-data paths must be rejected."""
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_rejects_traversal_in_skills() -> None:
"""Path traversal via .. in skills paths must be rejected."""
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path("/mnt/skills/../../etc/passwd", _THREAD_DATA, read_only=True)
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path("/mnt/skills/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
def test_validate_local_tool_path_rejects_none_thread_data() -> None:
@@ -202,7 +248,7 @@ def test_validate_local_tool_path_rejects_none_thread_data() -> None:
from deerflow.sandbox.exceptions import SandboxRuntimeError
with pytest.raises(SandboxRuntimeError):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", None)
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", None, _DEFAULT_APP_CONFIG)
# ---------- _resolve_skills_path ----------
@@ -210,32 +256,26 @@ def test_validate_local_tool_path_rejects_none_thread_data() -> None:
def test_resolve_skills_path_resolves_correctly() -> None:
"""Skills virtual path should resolve to host path."""
with (
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
):
resolved = _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md")
assert resolved == "/home/user/deer-flow/skills/public/bootstrap/SKILL.md"
cfg = _make_app_config(skills_host_path="/home/user/deer-flow/skills")
# Force get_skills_path().exists() to be True without touching the FS
cfg.skills.get_skills_path = lambda: type("_P", (), {"exists": lambda self: True, "__str__": lambda self: "/home/user/deer-flow/skills"})()
resolved = _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md", cfg)
assert resolved == "/home/user/deer-flow/skills/public/bootstrap/SKILL.md"
def test_resolve_skills_path_resolves_root() -> None:
"""Skills container root should resolve to host skills directory."""
with (
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
):
resolved = _resolve_skills_path("/mnt/skills")
assert resolved == "/home/user/deer-flow/skills"
cfg = _make_app_config(skills_host_path="/home/user/deer-flow/skills")
cfg.skills.get_skills_path = lambda: type("_P", (), {"exists": lambda self: True, "__str__": lambda self: "/home/user/deer-flow/skills"})()
resolved = _resolve_skills_path("/mnt/skills", cfg)
assert resolved == "/home/user/deer-flow/skills"
def test_resolve_skills_path_raises_when_not_configured() -> None:
"""Should raise FileNotFoundError when skills directory is not available."""
with (
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
patch("deerflow.sandbox.tools._get_skills_host_path", return_value=None),
):
with pytest.raises(FileNotFoundError, match="Skills directory not available"):
_resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md")
# Default app config points get_skills_path at a nonexistent directory → _get_skills_host_path yields None
with pytest.raises(FileNotFoundError, match="Skills directory not available"):
_resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md", _DEFAULT_APP_CONFIG)
# ---------- _resolve_and_validate_user_data_path ----------
@@ -250,7 +290,7 @@ def test_resolve_and_validate_user_data_path_resolves_correctly(tmp_path: Path)
"uploads_path": str(tmp_path / "uploads"),
"outputs_path": str(tmp_path / "outputs"),
}
resolved = _resolve_and_validate_user_data_path("/mnt/user-data/workspace/hello.txt", thread_data)
resolved = _resolve_and_validate_user_data_path("/mnt/user-data/workspace/hello.txt", thread_data, _DEFAULT_APP_CONFIG)
assert resolved == str(workspace / "hello.txt")
@@ -265,7 +305,7 @@ def test_resolve_and_validate_user_data_path_blocks_traversal(tmp_path: Path) ->
}
# This path resolves outside the allowed roots
with pytest.raises(PermissionError):
_resolve_and_validate_user_data_path("/mnt/user-data/workspace/../../../etc/passwd", thread_data)
_resolve_and_validate_user_data_path("/mnt/user-data/workspace/../../../etc/passwd", thread_data, _DEFAULT_APP_CONFIG)
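The traversal tests above all hinge on normalising `..` before comparing the resolved path against the allowed root. A sketch of that check, assuming a hypothetical helper name and pure `os.path` string handling (no filesystem access):

```python
import os.path


def resolve_under_root(virtual: str, prefix: str, root: str) -> str:
    """Map a virtual path under `prefix` onto `root`, rejecting `..` escapes."""
    if not virtual.startswith(prefix + "/"):
        raise PermissionError(f"Only paths under {prefix} are allowed")
    # Normalise first, then verify the result still lives under the root.
    candidate = os.path.normpath(os.path.join(root, virtual[len(prefix) + 1:]))
    if os.path.commonpath([root, candidate]) != root:
        raise PermissionError("path traversal detected")
    return candidate
```

Checking the prefix on the raw string but the root on the normalised result is what lets `/mnt/user-data/workspace/../../etc/passwd` pass the first guard yet fail the second.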
# ---------- replace_virtual_paths_in_command ----------
@@ -278,7 +318,7 @@ def test_replace_virtual_paths_in_command_replaces_skills_paths() -> None:
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
):
cmd = "cat /mnt/skills/public/bootstrap/SKILL.md"
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/mnt/skills" not in result
assert "/home/user/deer-flow/skills/public/bootstrap/SKILL.md" in result
@@ -290,7 +330,7 @@ def test_replace_virtual_paths_in_command_replaces_both() -> None:
patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/skills"),
):
cmd = "cat /mnt/skills/public/SKILL.md > /mnt/user-data/workspace/out.txt"
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/mnt/skills" not in result
assert "/mnt/user-data" not in result
assert "/home/user/skills/public/SKILL.md" in result
@@ -302,30 +342,27 @@ def test_replace_virtual_paths_in_command_replaces_both() -> None:
def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_allows_https_urls() -> None:
"""URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
validate_local_bash_command_paths(
"cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_allows_http_urls() -> None:
"""HTTP URLs must not be flagged as unsafe absolute paths."""
validate_local_bash_command_paths(
"curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
validate_local_bash_command_paths(
"/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_traversal_in_user_data() -> None:
@@ -333,8 +370,7 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_user_data() -> No
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths(
"cat /mnt/user-data/workspace/../../etc/passwd",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_traversal_in_skills() -> None:
@@ -343,14 +379,13 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_skills() -> None:
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths(
"cat /mnt/skills/../../etc/passwd",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_bash_tool_rejects_host_bash_when_local_sandbox_default(monkeypatch) -> None:
runtime = SimpleNamespace(
state={"sandbox": {"sandbox_id": "local"}, "thread_data": _THREAD_DATA.copy()},
context={"thread_id": "thread-1"},
context=_make_ctx("thread-1"),
)
monkeypatch.setattr(
@@ -372,33 +407,32 @@ def test_bash_tool_rejects_host_bash_when_local_sandbox_default(monkeypatch) ->
def test_is_skills_path_recognises_default_prefix() -> None:
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
assert _is_skills_path("/mnt/skills") is True
assert _is_skills_path("/mnt/skills/public/bootstrap/SKILL.md") is True
assert _is_skills_path("/mnt/skills-extra/foo") is False
assert _is_skills_path("/mnt/user-data/workspace") is False
assert _is_skills_path("/mnt/skills", _DEFAULT_APP_CONFIG) is True
assert _is_skills_path("/mnt/skills/public/bootstrap/SKILL.md", _DEFAULT_APP_CONFIG) is True
assert _is_skills_path("/mnt/skills-extra/foo", _DEFAULT_APP_CONFIG) is False
assert _is_skills_path("/mnt/user-data/workspace", _DEFAULT_APP_CONFIG) is False
def test_validate_local_tool_path_allows_skills_read_only() -> None:
"""read_file / ls should be able to access /mnt/skills paths."""
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
# Should not raise
validate_local_tool_path(
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
read_only=True,
)
# Should not raise — default app config uses /mnt/skills as container path
validate_local_tool_path(
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
_DEFAULT_APP_CONFIG,
read_only=True,
)
def test_validate_local_tool_path_blocks_skills_write() -> None:
"""write_file / str_replace must NOT write to skills paths."""
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
with pytest.raises(PermissionError, match="Write access to skills path is not allowed"):
validate_local_tool_path(
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
read_only=False,
)
with pytest.raises(PermissionError, match="Write access to skills path is not allowed"):
validate_local_tool_path(
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
_DEFAULT_APP_CONFIG,
read_only=False,
)
def test_validate_local_bash_command_paths_allows_skills_path() -> None:
@@ -406,8 +440,7 @@ def test_validate_local_bash_command_paths_allows_skills_path() -> None:
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
validate_local_bash_command_paths(
"cat /mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_allows_urls() -> None:
@@ -415,40 +448,35 @@ def test_validate_local_bash_command_paths_allows_urls() -> None:
# HTTPS URLs
validate_local_bash_command_paths(
"curl -X POST https://example.com/api/v1/risk/check",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
# HTTP URLs
validate_local_bash_command_paths(
"curl http://localhost:8080/health",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
# URLs with query strings
validate_local_bash_command_paths(
"curl https://api.example.com/v2/search?q=test",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
# FTP URLs
validate_local_bash_command_paths(
"curl ftp://ftp.example.com/pub/file.tar.gz",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
# URL mixed with valid virtual path
validate_local_bash_command_paths(
"curl https://example.com/data -o /mnt/user-data/workspace/data.json",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_file_urls() -> None:
"""file:// URLs should be treated as unsafe and blocked."""
with pytest.raises(PermissionError):
validate_local_bash_command_paths("curl file:///etc/passwd", _THREAD_DATA)
validate_local_bash_command_paths("curl file:///etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_file_urls_case_insensitive() -> None:
"""file:// URL detection should be case-insensitive."""
with pytest.raises(PermissionError):
validate_local_bash_command_paths("curl FILE:///etc/shadow", _THREAD_DATA)
validate_local_bash_command_paths("curl FILE:///etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_file_urls_mixed_with_valid() -> None:
@@ -456,35 +484,36 @@ def test_validate_local_bash_command_paths_blocks_file_urls_mixed_with_valid() -
with pytest.raises(PermissionError):
validate_local_bash_command_paths(
"curl file:///etc/passwd -o /mnt/user-data/workspace/out.txt",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_still_blocks_other_paths() -> None:
"""Paths outside virtual and system prefixes must still be blocked."""
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_tool_path_skills_custom_container_path() -> None:
"""Skills with a custom container_path in config should also work."""
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/custom/skills"):
# Should not raise
custom_config = _make_app_config(skills_container_path="/custom/skills")
# Should not raise
validate_local_tool_path(
"/custom/skills/public/my-skill/SKILL.md",
_THREAD_DATA,
custom_config,
read_only=True,
)
# The default /mnt/skills should not match since container path is /custom/skills
with pytest.raises(PermissionError, match="Only paths under"):
validate_local_tool_path(
"/custom/skills/public/my-skill/SKILL.md",
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
custom_config,
read_only=True,
)
# The default /mnt/skills should not match since container path is /custom/skills
with pytest.raises(PermissionError, match="Only paths under"):
validate_local_tool_path(
"/mnt/skills/public/bootstrap/SKILL.md",
_THREAD_DATA,
read_only=True,
)
# ---------- ACP workspace path tests ----------
@@ -501,6 +530,7 @@ def test_validate_local_tool_path_allows_acp_workspace_read_only() -> None:
validate_local_tool_path(
"/mnt/acp-workspace/hello_world.py",
_THREAD_DATA,
_DEFAULT_APP_CONFIG,
read_only=True,
)
@@ -511,6 +541,7 @@ def test_validate_local_tool_path_blocks_acp_workspace_write() -> None:
validate_local_tool_path(
"/mnt/acp-workspace/hello_world.py",
_THREAD_DATA,
_DEFAULT_APP_CONFIG,
read_only=False,
)
@@ -519,8 +550,7 @@ def test_validate_local_bash_command_paths_allows_acp_workspace() -> None:
"""bash commands referencing /mnt/acp-workspace should be allowed."""
validate_local_bash_command_paths(
"cp /mnt/acp-workspace/hello_world.py /mnt/user-data/outputs/hello_world.py",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_traversal_in_acp_workspace() -> None:
@@ -528,8 +558,7 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_acp_workspace() -
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths(
"cat /mnt/acp-workspace/../../etc/passwd",
_THREAD_DATA,
)
_THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_resolve_acp_workspace_path_resolves_correctly(tmp_path: Path) -> None:
@@ -571,7 +600,7 @@ def test_replace_virtual_paths_in_command_replaces_acp_workspace() -> None:
acp_host = "/home/user/.deer-flow/acp-workspace"
with patch("deerflow.sandbox.tools._get_acp_workspace_host_path", return_value=acp_host):
cmd = "cp /mnt/acp-workspace/hello.py /mnt/user-data/outputs/hello.py"
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert "/mnt/acp-workspace" not in result
assert f"{acp_host}/hello.py" in result
assert "/tmp/deer-flow/threads/t1/user-data/outputs/hello.py" in result
@@ -582,7 +611,7 @@ def test_mask_local_paths_in_output_hides_acp_workspace_host_paths() -> None:
acp_host = "/home/user/.deer-flow/acp-workspace"
with patch("deerflow.sandbox.tools._get_acp_workspace_host_path", return_value=acp_host):
output = f"Copied: {acp_host}/hello.py"
masked = mask_local_paths_in_output(output, _THREAD_DATA)
masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
assert acp_host not in masked
assert "/mnt/acp-workspace/hello.py" in masked
@@ -622,7 +651,7 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
from deerflow.config.sandbox_config import SandboxConfig
def _make_app_config(enabled: bool) -> AppConfig:
def _mcp_app_config(enabled: bool) -> AppConfig:
return AppConfig(
sandbox=SandboxConfig(use="test"),
extensions=ExtensionsConfig(
@ -636,19 +665,19 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
),
)
with patch.object(AppConfig, "current", return_value=_make_app_config(True)):
# Should not raise - MCP filesystem paths are allowed
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA)
enabled_cfg = _mcp_app_config(True)
# Should not raise - MCP filesystem paths are allowed
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA, enabled_cfg)
validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA, enabled_cfg)
# Path traversal should still be blocked
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA)
# Path traversal should still be blocked
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA, enabled_cfg)
# Disabled servers should not expose paths
with patch.object(AppConfig, "current", return_value=_make_app_config(False)):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
disabled_cfg = _mcp_app_config(False)
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA, disabled_cfg)
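The test above exercises config-gated path exposure: only MCP servers flagged as enabled contribute filesystem prefixes to the bash allowlist. The gist, with hypothetical simplified types in place of `McpServerConfig`:

```python
from dataclasses import dataclass, field


@dataclass
class McpServer:
    enabled: bool
    filesystem_paths: list = field(default_factory=list)


def allowed_mcp_prefixes(servers: dict) -> list:
    # Disabled servers expose nothing; enabled ones contribute every configured path.
    return [p for s in servers.values() if s.enabled for p in s.filesystem_paths]


on = {"fs": McpServer(enabled=True, filesystem_paths=["/mnt/d/workspace"])}
off = {"fs": McpServer(enabled=False, filesystem_paths=["/mnt/d/workspace"])}
```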
# ---------- Custom mount path tests ----------
@@ -666,12 +695,12 @@ def _mock_custom_mounts():
def test_is_custom_mount_path_recognises_configured_mounts() -> None:
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
assert _is_custom_mount_path("/mnt/code-read") is True
assert _is_custom_mount_path("/mnt/code-read/src/main.py") is True
assert _is_custom_mount_path("/mnt/data") is True
assert _is_custom_mount_path("/mnt/data/file.txt") is True
assert _is_custom_mount_path("/mnt/code-read-extra/foo") is False
assert _is_custom_mount_path("/mnt/other") is False
assert _is_custom_mount_path("/mnt/code-read", _DEFAULT_APP_CONFIG) is True
assert _is_custom_mount_path("/mnt/code-read/src/main.py", _DEFAULT_APP_CONFIG) is True
assert _is_custom_mount_path("/mnt/data", _DEFAULT_APP_CONFIG) is True
assert _is_custom_mount_path("/mnt/data/file.txt", _DEFAULT_APP_CONFIG) is True
assert _is_custom_mount_path("/mnt/code-read-extra/foo", _DEFAULT_APP_CONFIG) is False
assert _is_custom_mount_path("/mnt/other", _DEFAULT_APP_CONFIG) is False
def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
@@ -682,7 +711,7 @@ def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
VolumeMountConfig(host_path="/home/user/code", container_path="/mnt/code", read_only=True),
]
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
mount = _get_custom_mount_for_path("/mnt/code/file.py")
mount = _get_custom_mount_for_path("/mnt/code/file.py", _DEFAULT_APP_CONFIG)
assert mount is not None
assert mount.container_path == "/mnt/code"
@@ -690,90 +719,72 @@ def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
def test_validate_local_tool_path_allows_custom_mount_read() -> None:
"""read_file / ls should be able to access custom mount paths."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=True)
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=True)
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
def test_validate_local_tool_path_blocks_read_only_mount_write() -> None:
"""write_file / str_replace must NOT write to read-only custom mounts."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="Write access to read-only mount is not allowed"):
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=False)
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
def test_validate_local_tool_path_allows_writable_mount_write() -> None:
"""write_file / str_replace should succeed on writable custom mounts."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=False)
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
def test_validate_local_tool_path_blocks_traversal_in_custom_mount() -> None:
"""Path traversal via .. in custom mount paths must be rejected."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, read_only=True)
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
def test_validate_local_bash_command_paths_allows_custom_mount() -> None:
"""bash commands referencing custom mount paths should be allowed."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA)
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA)
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG)
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_blocks_traversal_in_custom_mount() -> None:
"""Bash commands with traversal in custom mount paths should be blocked."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA)
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_validate_local_bash_command_paths_still_blocks_non_mount_paths() -> None:
"""Paths not matching any custom mount should still be blocked."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
"""_get_custom_mounts should cache after first successful load."""
# Clear any existing cache
if hasattr(_get_custom_mounts, "_cached"):
monkeypatch.delattr(_get_custom_mounts, "_cached")
# Use real directories so host_path.exists() filtering passes
def test_get_custom_mounts_reads_from_app_config(tmp_path) -> None:
"""_get_custom_mounts should read directly from the supplied AppConfig."""
dir_a = tmp_path / "code-read"
dir_a.mkdir()
dir_b = tmp_path / "data"
dir_b.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
from deerflow.config.sandbox_config import VolumeMountConfig
mounts = [
VolumeMountConfig(host_path=str(dir_a), container_path="/mnt/code-read", read_only=True),
VolumeMountConfig(host_path=str(dir_b), container_path="/mnt/data", read_only=False),
]
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch.object(AppConfig, "current", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 2
# After caching, should return cached value even without mock
assert hasattr(_get_custom_mounts, "_cached")
assert len(_get_custom_mounts()) == 2
# Cleanup
monkeypatch.delattr(_get_custom_mounts, "_cached")
cfg = _make_app_config(mounts=mounts)
result = _get_custom_mounts(cfg)
assert len(result) == 2
def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path) -> None:
def test_get_custom_mounts_filters_nonexistent_host_path(tmp_path) -> None:
"""_get_custom_mounts should only return mounts whose host_path exists."""
if hasattr(_get_custom_mounts, "_cached"):
monkeypatch.delattr(_get_custom_mounts, "_cached")
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
from deerflow.config.sandbox_config import VolumeMountConfig
existing_dir = tmp_path / "existing"
existing_dir.mkdir()
@@ -782,22 +793,16 @@ def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path)
VolumeMountConfig(host_path=str(existing_dir), container_path="/mnt/existing", read_only=True),
VolumeMountConfig(host_path="/nonexistent/path/12345", container_path="/mnt/ghost", read_only=False),
]
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch.object(AppConfig, "current", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 1
assert result[0].container_path == "/mnt/existing"
# Cleanup
monkeypatch.delattr(_get_custom_mounts, "_cached")
cfg = _make_app_config(mounts=mounts)
result = _get_custom_mounts(cfg)
assert len(result) == 1
assert result[0].container_path == "/mnt/existing"
def test_get_custom_mount_for_path_boundary_no_false_prefix_match() -> None:
"""_get_custom_mount_for_path must not match /mnt/code-read-extra for /mnt/code-read."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo")
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo", _DEFAULT_APP_CONFIG)
assert mount is None
@@ -828,8 +833,8 @@ def test_str_replace_parallel_updates_should_preserve_both_edits(monkeypatch) ->
sandbox = SharedSandbox()
runtimes = [
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
]
failures: list[BaseException] = []
@@ -904,14 +909,14 @@ def test_str_replace_parallel_updates_in_isolated_sandboxes_should_not_share_pat
"sandbox-b": IsolatedSandbox("sandbox-b", shared_state),
}
runtimes = [
SimpleNamespace(state={}, context={"thread_id": "thread-1", "sandbox_key": "sandbox-a"}, config={}),
SimpleNamespace(state={}, context={"thread_id": "thread-2", "sandbox_key": "sandbox-b"}, config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-1", sandbox_key="sandbox-a"), config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-2", sandbox_key="sandbox-b"), config={}),
]
failures: list[BaseException] = []
monkeypatch.setattr(
"deerflow.sandbox.tools.ensure_sandbox_initialized",
lambda runtime: sandboxes[runtime.context["sandbox_key"]],
lambda runtime: sandboxes[runtime.context.sandbox_key],
)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
@@ -971,8 +976,8 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey
sandbox = SharedSandbox()
runtimes = [
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
]
failures: list[BaseException] = []
@@ -12,7 +12,7 @@ async def test_scan_skill_content_blocks_when_model_unavailable(monkeypatch):
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
monkeypatch.setattr("deerflow.skills.security_scanner.create_chat_model", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
result = await scan_skill_content("---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
result = await scan_skill_content(config, "---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
assert result.decision == "block"
assert "manual review required" in result.reason
@@ -11,9 +11,9 @@ from deerflow.config.sandbox_config import SandboxConfig
skill_manage_module = importlib.import_module("deerflow.tools.skill_manage_tool")
def _make_context(thread_id: str) -> DeerFlowContext:
def _make_context(thread_id: str, app_config: object | None = None) -> DeerFlowContext:
return DeerFlowContext(
app_config=AppConfig(sandbox=SandboxConfig(use="test")),
app_config=app_config if app_config is not None else AppConfig(sandbox=SandboxConfig(use="test")),
thread_id=thread_id,
)
@@ -37,13 +37,13 @@ def test_skill_manage_create_and_patch(monkeypatch, tmp_path):
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
refresh_calls = []
async def _refresh():
async def _refresh(*a, **k):
refresh_calls.append("refresh")
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
result = anyio.run(
skill_manage_module.skill_manage_tool.coroutine,
@@ -78,13 +78,13 @@ def test_skill_manage_patch_replaces_single_occurrence_by_default(monkeypatch, t
)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
async def _refresh():
async def _refresh(*a, **k):
return None
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
content = _skill_content("demo-skill", "Demo skill") + "\nRepeated: Demo skill\n"
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", content)
@@ -116,7 +116,7 @@ def test_skill_manage_rejects_public_skill_patch(monkeypatch, tmp_path):
)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
runtime = SimpleNamespace(context=_make_context(""), config={"configurable": {}})
runtime = SimpleNamespace(context=_make_context("", config), config={"configurable": {}})
with pytest.raises(ValueError, match="built-in skill"):
anyio.run(
@@ -140,13 +140,13 @@ def test_skill_manage_sync_wrapper_supported(monkeypatch, tmp_path):
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
refresh_calls = []
async def _refresh():
async def _refresh(*a, **k):
refresh_calls.append("refresh")
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context=_make_context("thread-sync"), config={"configurable": {"thread_id": "thread-sync"}})
runtime = SimpleNamespace(context=_make_context("thread-sync", config), config={"configurable": {"thread_id": "thread-sync"}})
result = skill_manage_module.skill_manage_tool.func(
runtime=runtime,
action="create",
@@ -166,13 +166,13 @@ def test_skill_manage_rejects_support_path_traversal(monkeypatch, tmp_path):
)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
async def _refresh():
async def _refresh(*a, **k):
return None
monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", _skill_content("demo-skill"))
with pytest.raises(ValueError, match="parent-directory traversal|selected support directory"):
@@ -50,12 +50,13 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
refresh_calls = []
async def _refresh():
async def _refresh(*a, **k):
refresh_calls.append("refresh")
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.state.config = config
app.include_router(skills_router.router)
with TestClient(app) as client:
@@ -96,12 +97,12 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
get_skill_history_file("demo-skill").write_text(
get_skill_history_file("demo-skill", config).write_text(
'{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n",
encoding="utf-8",
)
async def _refresh():
async def _refresh(*a, **k):
return None
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
@@ -114,6 +115,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", _scan)
app = FastAPI()
app.state.config = config
app.include_router(skills_router.router)
with TestClient(app) as client:
@@ -140,12 +142,13 @@ def test_custom_skill_delete_preserves_history_and_allows_restore(monkeypatch, t
monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
refresh_calls = []
async def _refresh():
async def _refresh(*a, **k):
refresh_calls.append("refresh")
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
app = FastAPI()
app.state.config = config
app.include_router(skills_router.router)
with TestClient(app) as client:
@@ -169,13 +172,13 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
enabled_state = {"value": True}
refresh_calls = []
def _load_skills(*, enabled_only: bool):
def _load_skills(*a, enabled_only: bool = False, **k):
skill = _make_skill("demo-skill", enabled=enabled_state["value"])
if enabled_only and not skill.enabled:
return []
return [skill]
async def _refresh():
async def _refresh(*a, **k):
refresh_calls.append("refresh")
enabled_state["value"] = False
@@ -183,7 +186,6 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
monkeypatch.setattr("app.gateway.routers.skills.load_skills", _load_skills)
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: _app_cfg))
monkeypatch.setattr(AppConfig, "init", staticmethod(lambda _cfg: None))
monkeypatch.setattr(AppConfig, "from_file", staticmethod(lambda: _app_cfg))
monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path))
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
@@ -27,7 +27,7 @@ def test_load_skills_discovers_nested_skills_and_sets_container_paths(tmp_path:
_write_skill(skills_root / "public" / "parent" / "child-skill", "child-skill", "Child skill")
_write_skill(skills_root / "custom" / "team" / "helper", "team-helper", "Team helper")
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
skills = load_skills(skills_path=skills_root, enabled_only=False)
by_name = {skill.name: skill for skill in skills}
assert {"root-skill", "child-skill", "team-helper"} <= set(by_name)
@@ -57,7 +57,7 @@ def test_load_skills_skips_hidden_directories(tmp_path: Path):
"Hidden skill",
)
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
skills = load_skills(skills_path=skills_root, enabled_only=False)
names = {skill.name for skill in skills}
assert "ok-skill" in names
@@ -69,7 +69,7 @@ def test_load_skills_prefers_custom_over_public_with_same_name(tmp_path: Path):
_write_skill(skills_root / "public" / "shared-skill", "shared-skill", "Public version")
_write_skill(skills_root / "custom" / "shared-skill", "shared-skill", "Custom version")
skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
skills = load_skills(skills_path=skills_root, enabled_only=False)
shared = next(skill for skill in skills if skill.name == "shared-skill")
assert shared.category == "custom"
@@ -6,6 +6,7 @@ import re
import anyio
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge
# ---------------------------------------------------------------------------
@@ -331,6 +332,9 @@ async def test_concurrent_tasks_end_sentinel():
@pytest.mark.anyio
async def test_make_stream_bridge_defaults():
"""make_stream_bridge() with no config yields a MemoryStreamBridge."""
async with make_stream_bridge() as bridge:
"""make_stream_bridge with a config lacking stream_bridge yields a MemoryStreamBridge."""
from deerflow.config.sandbox_config import SandboxConfig
config = AppConfig(sandbox=SandboxConfig(use="test"))
async with make_stream_bridge(config) as bridge:
assert isinstance(bridge, MemoryStreamBridge)
@@ -21,6 +21,8 @@ from unittest.mock import MagicMock, patch
import pytest
_TEST_APP_CONFIG = MagicMock(name="TestAppConfig")
# Module names that need to be mocked to break circular imports
_MOCKED_MODULE_NAMES = [
"deerflow.agents",
@@ -203,7 +205,7 @@ class TestAsyncExecutionPath:
config=base_config,
tools=[],
thread_id="test-thread",
trace_id="test-trace",
trace_id="test-trace", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -232,7 +234,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -259,7 +261,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -285,7 +287,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -306,7 +308,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -327,7 +329,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -348,7 +350,7 @@ class TestAsyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -384,7 +386,7 @@ class TestSyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -419,7 +421,7 @@ class TestSyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -456,7 +458,7 @@ class TestSyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -477,7 +479,7 @@ class TestSyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_aexecute") as mock_aexecute:
@@ -511,7 +513,7 @@ class TestSyncExecutionPath:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -565,7 +567,7 @@ class TestAsyncToolSupport:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -602,7 +604,7 @@ class TestAsyncToolSupport:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -648,7 +650,7 @@ class TestThreadSafety:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id=f"thread-{task_id}",
thread_id=f"thread-{task_id}", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -858,7 +860,7 @@ class TestCooperativeCancellation:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -898,7 +900,7 @@ class TestCooperativeCancellation:
executor = SubagentExecutor(
config=base_config,
tools=[],
thread_id="test-thread",
thread_id="test-thread", app_config=_TEST_APP_CONFIG,
)
with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -977,7 +979,7 @@ class TestCooperativeCancellation:
config=short_config,
tools=[],
thread_id="test-thread",
trace_id="test-trace",
trace_id="test-trace", app_config=_TEST_APP_CONFIG,
)
# Wrap _scheduler_pool.submit so we know when run_task finishes
@@ -1,29 +1,35 @@
"""Tests for subagent availability and prompt exposure under local bash hardening."""
from deerflow.agents.lead_agent import prompt as prompt_module
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.subagents import registry as registry_module
def test_get_available_subagent_names_hides_bash_when_host_bash_disabled(monkeypatch) -> None:
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda: False)
def _config() -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"))
names = registry_module.get_available_subagent_names()
def test_get_available_subagent_names_hides_bash_when_host_bash_disabled(monkeypatch) -> None:
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda *a, **k: False)
names = registry_module.get_available_subagent_names(_config())
assert names == ["general-purpose"]
def test_get_available_subagent_names_keeps_bash_when_allowed(monkeypatch) -> None:
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda: True)
monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda *a, **k: True)
names = registry_module.get_available_subagent_names()
names = registry_module.get_available_subagent_names(_config())
assert names == ["general-purpose", "bash"]
def test_build_subagent_section_hides_bash_examples_when_unavailable(monkeypatch) -> None:
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda: ["general-purpose"])
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose"])
section = prompt_module._build_subagent_section(3)
section = prompt_module._build_subagent_section(3, _config())
assert "Not available in the current sandbox configuration" in section
assert 'bash("npm test")' not in section
@@ -32,9 +38,9 @@ def test_build_subagent_section_hides_bash_examples_when_unavailabl
def test_build_subagent_section_includes_bash_when_available(monkeypatch) -> None:
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda: ["general-purpose", "bash"])
monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose", "bash"])
section = prompt_module._build_subagent_section(3)
section = prompt_module._build_subagent_section(3, _config())
assert "For command execution (git, build, test, deploy operations)" in section
assert 'bash("npm test")' in section
@@ -93,11 +93,11 @@ class _DummyScheduledTask:
def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: None)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose"])
result = _run_task_tool(
runtime=None,
runtime=_make_runtime(),
description="执行任务",
prompt="do work",
subagent_type="general-purpose",
@@ -108,7 +108,7 @@ def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
def test_task_tool_rejects_bash_subagent_when_host_bash_disabled(monkeypatch):
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: _make_subagent_config())
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: _make_subagent_config())
monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", lambda *a, **k: False)
result = _run_task_tool(
@@ -152,9 +152,9 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "Skills Appendix")
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "Skills Appendix")
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda *a, **k: next(responses))
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
# task_tool lazily imports from deerflow.tools at call time, so patch that module-level function.
@@ -177,7 +177,9 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
assert captured["executor_kwargs"]["config"].max_turns == 7
assert "Skills Appendix" in captured["executor_kwargs"]["config"].system_prompt
get_available_tools.assert_called_once_with(model_name="ark-model", subagent_enabled=False)
from unittest.mock import ANY
get_available_tools.assert_called_once_with(model_name="ark-model", subagent_enabled=False, app_config=ANY)
event_types = [e["type"] for e in events]
assert event_types == ["task_started", "task_running", "task_running", "task_completed"]
@@ -194,12 +196,12 @@ def test_task_tool_returns_failed_message(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
lambda *a, **k: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -228,12 +230,12 @@ def test_task_tool_returns_timed_out_message(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
lambda *a, **k: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -264,12 +266,12 @@ def test_task_tool_polling_safety_timeout(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -300,12 +302,12 @@ def test_cleanup_called_on_completed(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
lambda *a, **k: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -340,12 +342,12 @@ def test_cleanup_called_on_failed(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.FAILED, error="error"),
lambda *a, **k: _make_result(FakeSubagentStatus.FAILED, error="error"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -380,12 +382,12 @@ def test_cleanup_called_on_timed_out(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
lambda *a, **k: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -427,12 +429,12 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -480,8 +482,8 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
@@ -531,12 +533,12 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
@@ -586,12 +588,12 @@ def test_cancellation_calls_request_cancel(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
@@ -644,9 +646,9 @@ def test_task_tool_returns_cancelled_message(monkeypatch):
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda *a, **k: next(responses))
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])


@@ -77,11 +77,13 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="好的,先确认需求"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=12)))
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=12))))
title = result["title"]
assert title == "短标题"
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
from unittest.mock import ANY
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False, app_config=ANY)
model.ainvoke.assert_awaited_once()
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
@@ -97,7 +99,7 @@ class TestTitleMiddlewareCoreLogic:
]
}
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=20))))
title = result["title"]
assert title == "请帮我总结这段代码"
@@ -114,7 +116,7 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="收到"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=20))))
title = result["title"]
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.


@@ -270,27 +270,27 @@ class TestDeferredToolsPromptSection:
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
# tool_search.enabled defaults to False
section = get_deferred_tools_prompt_section()
section = get_deferred_tools_prompt_section(AppConfig.current())
assert section == ""
def test_empty_when_enabled_but_no_registry(self, monkeypatch):
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
AppConfig.current().tool_search = ToolSearchConfig(enabled=True)
section = get_deferred_tools_prompt_section()
section = get_deferred_tools_prompt_section(AppConfig.current())
assert section == ""
def test_empty_when_enabled_but_empty_registry(self, monkeypatch):
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
AppConfig.current().tool_search = ToolSearchConfig(enabled=True)
set_deferred_registry(DeferredToolRegistry())
section = get_deferred_tools_prompt_section()
section = get_deferred_tools_prompt_section(AppConfig.current())
assert section == ""
def test_lists_tool_names(self, registry, monkeypatch):
from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
AppConfig.current().tool_search = ToolSearchConfig(enabled=True)
set_deferred_registry(registry)
section = get_deferred_tools_prompt_section()
section = get_deferred_tools_prompt_section(AppConfig.current())
assert "<available-deferred-tools>" in section
assert "</available-deferred-tools>" in section
assert "github_create_issue" in section

backend/uv.lock generated

File diff suppressed because it is too large.

docs/CONFIG_DESIGN.zh.md (new file)

@@ -0,0 +1,301 @@
# DeerFlow Configuration System Design

> Corresponding implementation: [PR #2271](https://github.com/bytedance/deer-flow/pull/2271) · RFC [#1811](https://github.com/bytedance/deer-flow/issues/1811) · archived spec: [config-refactor-design](./plans/2026-04-12-config-refactor-design.md)

## 1. Why refactor

Before the refactor, `deerflow/config/` had three structural problems that together formed the classic anti-pattern of "global mutable state + side-effect coupling":

| Problem | Symptom |
|------|----------|
| Dual sources of truth | Every sub-config was simultaneously an `AppConfig` field **and** a module-level global (`_memory_config` / `_title_config`). Consumers didn't know which one to trust |
| Side-effect coupling | `AppConfig.from_file()` also mutated the globals of 8 sub-modules (via `load_*_from_dict()`) |
| Incomplete isolation | The existing `ContextVar` only covered the `AppConfig` object itself; the 8 sub-config globals leaked outside it |

From a type-theory perspective, config should be a pure **value object**: constructed once, immutable, copyable. The design above turned it into a "live object with global state", so test mutations, async boundaries, and hot updates all polluted each other.
## 2. Core design principle

> **Config is a value object, not live shared state.**
> Constructed once, immutable, no reload. New config = new object + rebuilt agent.

Every later decision follows from this one principle:

- All config models are `frozen=True` → illegal states are unrepresentable
- `from_file()` is a pure function → no side effects
- No "hot reload" semantics → changing config means "getting a new object"; the caller decides whether to swap the process global
## 3. The four layers

```mermaid
graph TB
subgraph L1 ["Layer 1: Data models — frozen ADTs"]
direction LR
AppConfig["AppConfig frozen=True"]
Sub["MemoryConfig TitleConfig SummarizationConfig ... all frozen"]
AppConfig --> Sub
end
subgraph L2 ["Layer 2: Lifecycle — AppConfig.current"]
direction LR
Override["_override ContextVar per-context"]
Global["_global ClassVar process-singleton"]
Auto["auto-load from file with warning"]
Override --> Global
Global --> Auto
end
subgraph L3 ["Layer 3: Per-invocation context — DeerFlowContext"]
direction LR
Ctx["frozen dataclass app_config thread_id agent_name"]
Resolve["resolve_context legacy bridge"]
Ctx --> Resolve
end
subgraph L4 ["Layer 4: Access patterns — split by caller type"]
direction LR
Typed["typed middleware runtime.context.app_config.xxx"]
Legacy["dict-legacy resolve_context runtime"]
NonAgent["non-agent paths AppConfig.current"]
end
L1 --> L2
L2 --> L3
L3 --> L4
classDef morandiBlue fill:#B5C4D1,stroke:#6A7A8C,color:#2E3A47
classDef morandiGreen fill:#C4D1B5,stroke:#7A8C6A,color:#2E3A47
classDef morandiPurple fill:#C9BED1,stroke:#7E6A8C,color:#2E3A47
classDef morandiGrey fill:#CFCFCF,stroke:#7A7A7A,color:#2E3A47
class L1 morandiBlue
class L2 morandiGreen
class L3 morandiPurple
class L4 morandiGrey
```
### 3.1 Layer 1: frozen ADTs

All config models are Pydantic `frozen=True`:

```python
class MemoryConfig(BaseModel):
    model_config = ConfigDict(frozen=True)
    enabled: bool = True
    storage_path: str | None = None
    ...

class AppConfig(BaseModel):
    model_config = ConfigDict(extra="allow", frozen=True)
    memory: MemoryConfig
    title: TitleConfig
    ...
```

Config changes use copy-on-write:

```python
new_config = config.model_copy(update={"memory": new_memory_config})
```

**From a type-theory perspective**: this is a product type (a record). All the fields together form one complete `AppConfig`. Frozen means `AppConfig` is **referentially transparent**: the same inputs always yield the same object.
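The frozen-value contract can be demonstrated with a minimal stdlib sketch. These are stand-ins for the real Pydantic models, with `dataclasses.replace` playing the role of `model_copy(update=…)`:

```python
from dataclasses import dataclass, replace, FrozenInstanceError

# Stand-ins for the real Pydantic models; frozen=True gives the same contract.
@dataclass(frozen=True)
class MemoryConfig:
    enabled: bool = True

@dataclass(frozen=True)
class AppConfig:
    memory: MemoryConfig = MemoryConfig()

config = AppConfig()
mutated = True
try:
    config.memory = MemoryConfig(enabled=False)  # in-place mutation is rejected
except FrozenInstanceError:
    mutated = False
assert not mutated

# Copy-on-write: the original is untouched, the copy carries the change.
new_config = replace(config, memory=MemoryConfig(enabled=False))
assert config.memory.enabled is True
assert new_config.memory.enabled is False
```

With Pydantic v2, assigning to a field of a `frozen=True` model raises `ValidationError` instead of `FrozenInstanceError`; the shape of the guarantee is the same.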
### 3.2 Layer 2: Lifecycle — `AppConfig.current()`

This layer is the most interesting part of the design. It is not a single `ContextVar` but a **three-tier fallback**:

```python
class AppConfig(BaseModel):
    ...
    # Process-level singleton. A single pointer swap is atomic under the GIL; no lock needed.
    _global: ClassVar[AppConfig | None] = None
    # Per-context override, for test isolation and multi-client scenarios.
    _override: ClassVar[ContextVar[AppConfig]] = ContextVar("deerflow_app_config_override")

    @classmethod
    def init(cls, config: AppConfig) -> None:
        """Set the process global. Visible to all subsequent async tasks."""
        cls._global = config

    @classmethod
    def set_override(cls, config: AppConfig) -> Token[AppConfig]:
        """Per-context override. Returns a Token for reset_override()."""
        return cls._override.set(config)

    @classmethod
    def reset_override(cls, token: Token[AppConfig]) -> None:
        cls._override.reset(token)

    @classmethod
    def current(cls) -> AppConfig:
        """Priority: per-context override > process global > auto-load from file (with warning)."""
        try:
            return cls._override.get()
        except LookupError:
            pass
        if cls._global is not None:
            return cls._global
        logger.warning("AppConfig.current() called before init(); auto-loading from file. ...")
        config = cls.from_file()
        cls._global = config
        return config
```
**Why three tiers instead of one?**

| Reason | Explanation |
|------|------|
| A single ContextVar doesn't work | When the Gateway receives `PUT /mcp/config` and reloads config, the next request runs in a **brand-new async context**, so the ContextVar's value cannot carry over. Only a process-level variable can |
| Keep the ContextVar override | Tests need per-test scoped config; `Token`-based reset guarantees clean restoration. It also covers multi-client scenarios if they ever materialize |
| Auto-load fallback | Some call sites historically never called `init()` (internal scripts, tests triggered at import time). The warning keeps the signal without hard-crashing |

**Mapping to a Scala perspective**:

- `_global` = a process-level `var`: dirty, but there is no alternative
- `_override` = a reader-monad layer in `Option[ContextVar]` form
- `current()` = the fallback chain `override.orElse(global).orElse(autoLoad)`, the same idea as `Option.orElse`

**Why is `_global` not protected by a lock?**

Because both the read and the write are a single pointer assignment (assignment of a class attribute), which is atomic under CPython's GIL. If it ever becomes a read-modify-write ("init to X if not yet initialized"), add a `threading.Lock` then. It is unlocked now for one reason: it doesn't need a lock.
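The three-tier lookup can be sketched as a self-contained miniature. This is not the DeerFlow class; the auto-load tier is replaced by a `RuntimeError` to keep the sketch file-free:

```python
from contextvars import ContextVar, Token

class Config:
    _global: "Config | None" = None
    _override: ContextVar["Config"] = ContextVar("config_override")

    def __init__(self, label: str) -> None:
        self.label = label

    @classmethod
    def init(cls, config: "Config") -> None:
        cls._global = config  # single pointer swap — atomic under the GIL

    @classmethod
    def set_override(cls, config: "Config") -> Token:
        return cls._override.set(config)

    @classmethod
    def reset_override(cls, token: Token) -> None:
        cls._override.reset(token)

    @classmethod
    def current(cls) -> "Config":
        try:
            return cls._override.get()   # 1. per-context override
        except LookupError:
            pass
        if cls._global is not None:      # 2. process global
            return cls._global
        # 3. real code auto-loads from file here, with a warning
        raise RuntimeError("not initialized")

Config.init(Config("global"))
assert Config.current().label == "global"

token = Config.set_override(Config("override"))
assert Config.current().label == "override"  # override wins
Config.reset_override(token)
assert Config.current().label == "global"    # cleanly restored
```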
### 3.3 Layer 3: `DeerFlowContext` — per-invocation typed context

```python
# deerflow/config/deer_flow_context.py
@dataclass(frozen=True)
class DeerFlowContext:
    """Typed, immutable, per-invocation context injected via LangGraph Runtime"""
    app_config: AppConfig
    thread_id: str
    agent_name: str | None = None
```

Why not put `thread_id` into `AppConfig` as well?

- `AppConfig` is **configuration**: fixed at process start, shared by all requests
- `thread_id` is **per-invocation runtime identity**: it must vary with every call

They belong to different categories; mixing them couples static configuration with dynamic identity.
**Injection paths**:

```python
# Gateway worker (main path)
deer_flow_context = DeerFlowContext(
    app_config=AppConfig.current(),
    thread_id=thread_id,
)
agent.astream(input, config=config, context=deer_flow_context)

# DeerFlowClient
AppConfig.init(AppConfig.from_file(config_path))
context = DeerFlowContext(app_config=AppConfig.current(), thread_id=thread_id)
agent.stream(input, config=config, context=context)
```

LangGraph's `Runtime` injects the `context=...` value into `Runtime[DeerFlowContext].context`, so middleware receives a typed `DeerFlowContext`.

**What stays out of the context**: `sandbox_id`. It is **mutable runtime state** acquired mid-execution, and its proper home is `ThreadState.sandbox` (a state channel with a reducer), not the context. All three `runtime.context["sandbox_id"] = ...` writes in the old `sandbox/tools.py` were removed.
### 3.4 Layer 4: access patterns split by caller type

Three caller types, three patterns:

| Caller type | Access pattern | Examples |
|-------------|----------|------|
| Typed middleware (signature `Runtime[DeerFlowContext]`) | read `runtime.context.app_config.xxx` directly, no wrapper | `memory_middleware` / `title_middleware` / `thread_data_middleware`, etc. |
| Tools that may receive a dict context | `resolve_context(runtime).xxx` | `sandbox/tools.py` (dict-legacy path) / `task_tool.py` (bash subagent gate) |
| Non-agent paths (Gateway routers, CLI, factories) | `AppConfig.current().xxx` | `app/gateway/routers/*` / `reset_admin.py` / `models/factory.py` |

**Key simplification** (commit `a934a822`): originally every middleware went through `resolve_context()`; once the signatures were already `Runtime[DeerFlowContext]`, the wrapper turned out to be redundant defense, and `runtime.context.app_config.xxx` suffices. The same commit also removed every `title_config=None` fallback from the helpers in `title_middleware`: **a required parameter gets no default**, letting the type system force callers to pass the right thing.

This matches two Scala / FP creeds:

- **Make illegal states unrepresentable** (`Option[TitleConfig]` became a required `TitleConfig`)
- **Let it crash** (a config parsing failure is a real bug; surfacing it beats swallowing it into degraded behavior)
## 4. The three branches of `resolve_context()`

`resolve_context()` itself remains, handling three shapes of `runtime.context`:

```python
def resolve_context(runtime: Any) -> DeerFlowContext:
    ctx = getattr(runtime, "context", None)
    # 1. Typed path (Gateway, Client) — return as-is
    if isinstance(ctx, DeerFlowContext):
        return ctx
    # 2. Dict-legacy path (old tests, third-party invoke) — bridge it
    if isinstance(ctx, dict):
        thread_id = ctx.get("thread_id", "")
        if not thread_id:
            logger.warning("...empty thread_id...")
        return DeerFlowContext(
            app_config=AppConfig.current(),
            thread_id=thread_id,
            agent_name=ctx.get("agent_name"),
        )
    # 3. No context at all — fall back to LangGraph configurable
    cfg = get_config().get("configurable", {})
    return DeerFlowContext(
        app_config=AppConfig.current(),
        thread_id=cfg.get("thread_id", ""),
        agent_name=cfg.get("agent_name"),
    )
```

An empty `thread_id` warns rather than hard-crashing. Warning is the right call here: a missing `thread_id` only affects file paths (things land in an empty-string directory) and doesn't bring the whole agent down.
## 5. Gateway config hot-update flow

Historically the Gateway used `reload_*_config()` with mtime detection. Now the flow is:

```
write extensions_config.json → AppConfig.init(AppConfig.from_file()) → the next request sees the new values
```

There is **no** mtime detection, no auto-refresh, no `reload_*()` functions.

The philosophy is simple: **structural changes (models, tools, the middleware chain) require rebuilding the agent; runtime changes (flags like `memory.enabled`) take effect automatically because the next invocation reads from `AppConfig.current()`**. Config does not need "live object" semantics.
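The whole flow fits in two steps. The sketch below is hypothetical; the `Config` class and `apply_config_update` helper are stand-ins for the real `extensions_config.json` / `AppConfig` machinery:

```python
import json
import tempfile
from pathlib import Path

class Config:
    _global = None

    def __init__(self, data: dict) -> None:
        self.data = data

    @classmethod
    def from_file(cls, path: Path) -> "Config":
        # Pure load: read + parse, no global mutation.
        return cls(json.loads(path.read_text()))

    @classmethod
    def init(cls, config: "Config") -> None:
        cls._global = config  # the explicit pointer swap is the only side effect

    @classmethod
    def current(cls) -> "Config":
        return cls._global

def apply_config_update(path: Path, new_settings: dict) -> None:
    """Write the file, then swap the process global — no mtime watching."""
    path.write_text(json.dumps(new_settings))
    Config.init(Config.from_file(path))

with tempfile.TemporaryDirectory() as d:
    p = Path(d) / "extensions_config.json"
    apply_config_update(p, {"memory": {"enabled": False}})
    assert Config.current().data["memory"]["enabled"] is False
```

The next "request" simply calls `Config.current()`; no watcher thread, no reload hook.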
## 6. Divergences from the original plan

Three key divergences (details in [archived spec §7](./plans/2026-04-12-config-refactor-design.md#7-divergence-from-original-plan)):

| Divergence | Original plan | Shipped | Reason |
|------|--------|---------|------|
| Lifecycle storage | single ContextVar, hard crash via `ConfigNotInitializedError` | 3-tier fallback, auto-load + warning | a ContextVar cannot carry values across async boundaries |
| Module location | new `context.py` | lifecycle lives as classmethods on `AppConfig` itself | one less layer of module coupling |
| Middleware access | `resolve_context()` everywhere | typed middleware reads `runtime.context.xxx` directly | once types are tightened, the defensive wrapper is noise |
## 7. Observations from a Scala / Actor perspective

- **`AppConfig` is a case class / ADT.** `frozen=True` is the equivalent of a Scala final case class: once constructed, it never changes. Modification goes through `model_copy(update=…)`, corresponding to Scala's `copy(…)`.
- **`DeerFlowContext` is a typed reader.** Middleware receives `Runtime[DeerFlowContext]`, essentially a `Kleisli[DeerFlowContext, State, Result]`: typed dependency injection. Far stronger than `RunnableConfig.configurable: dict[str, Any]`.
- **`resolve_context()` is an adapter layer.** It exists because there are three differently shaped upstream inputs; through pure-FP eyes it is a total function `X => DeerFlowContext` that pattern-matches the three cases and funnels the world back onto the typed path.
- **Let-it-crash in practice**: commit `a934a822` killed the `try/except resolve_context(...)` in middleware and the defensive `TitleConfig | None` fallback. If config parsing fails, let it throw instead of swallowing it into a "degraded mode"; actor supervision will handle it, while swallowing errors hides bugs.
- **The process-global compromise**: `_global: ClassVar` is the only place this design departs from pure values. But in a Python async + HTTP server setting there is no other way to hand a "new config" to all tasks across requests. Acknowledge the compromise, confine its scope (one variable in the lifecycle layer), keep everything around it immutable: that is the engineering sense of a "reasonable compromise".
## 8. Cheat sheet

Need to access config? Pick the row for where your code lives:

| What I'm writing | What to use |
|------------|--------|
| Typed middleware (signature `Runtime[DeerFlowContext]`) | `runtime.context.app_config.xxx` |
| Typed tool (`ToolRuntime[DeerFlowContext]`) | `runtime.context.xxx` |
| A tool that legacy callers may invoke with a dict context | `resolve_context(runtime).xxx` |
| Gateway routers, CLI, factories, test helpers | `AppConfig.current().xxx` |
| Startup initialization | `AppConfig.init(AppConfig.from_file(path))` |
| Temporarily changing config in a test | `token = AppConfig.set_override(cfg)` / `AppConfig.reset_override(token)` |
| After the Gateway writes a new `extensions_config.json` | `AppConfig.init(AppConfig.from_file())`, then rebuild the agent (if the structure changed) |

Do not use:

- ~~legacy getters such as `get_memory_config()` / `get_title_config()`~~ (deleted)
- ~~`reload_app_config()` / `reset_app_config()`~~ (deleted)
- ~~module-level globals such as `_memory_config`~~ (deleted)
- ~~`runtime.context["sandbox_id"] = ...`~~ (use `runtime.state["sandbox"]`)
- ~~defensive `try/except resolve_context(...)`~~ (let it crash)
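For the "temporarily changing config in a test" row, the token protocol pairs naturally with a `with` block. The `config_override` helper below is a hypothetical convenience wrapper (not in the repo), shown against a minimal stand-in for the override API:

```python
from contextlib import contextmanager
from contextvars import ContextVar

# Minimal stand-in for AppConfig's override slot; illustrative only.
class Config:
    _override: ContextVar["Config"] = ContextVar("override")

    def __init__(self, label: str) -> None:
        self.label = label

    @classmethod
    def set_override(cls, cfg: "Config"):
        return cls._override.set(cfg)

    @classmethod
    def reset_override(cls, token) -> None:
        cls._override.reset(token)

    @classmethod
    def current(cls) -> "Config":
        try:
            return cls._override.get()
        except LookupError:
            return cls("process-default")

# Hypothetical wrapper: packages set_override/reset_override into a with-block
# so tests cannot forget the reset, even when an assertion fails mid-block.
@contextmanager
def config_override(cfg: "Config"):
    token = Config.set_override(cfg)
    try:
        yield cfg
    finally:
        Config.reset_override(token)

with config_override(Config("test")):
    assert Config.current().label == "test"
assert Config.current().label == "process-default"
```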