Mirror of https://github.com/bytedance/deer-flow.git, synced 2026-04-25 11:18:22 +00:00
refactor(config): Phase 2 — eliminate AppConfig.current() ambient lookup
Finish Phase 2 of the config refactor: production code no longer calls
AppConfig.current() anywhere. AppConfig now flows as an explicit parameter
down every consumer lane.

Call-site migrations
--------------------
- Memory subsystem (queue/updater/storage): MemoryConfig captured at
  enqueue time so the Timer closure survives the ContextVar boundary.
- Sandbox layer: tools.py, security.py, sandbox_provider.py,
  local_sandbox_provider, aio_sandbox_provider all take app_config
  explicitly. Module-level caching in tools.py's path helpers is
  removed — pure parameter flow.
- Skills layer: manager.py + loader.py + lead_agent.prompt cache refresh
  all thread app_config; cache worker closes over it.
- Community tools (tavily, jina, firecrawl, exa, ddg, image_search,
  infoquest, aio_sandbox): read runtime.context.app_config.
- Subagents registry: get_subagent_config / list_subagents /
  get_available_subagent_names require app_config.
- Runtime worker: requires RunContext.app_config; no fallback.
- Gateway routers (uploads, skills): add Depends(get_config).
- Channels feishu: uses AppConfig.from_file() (pure) at its sync boundary.
- LangGraph Server bootstrap (make_lead_agent): falls back to
  AppConfig.from_file() — pure load, not ambient lookup.

Context resolution
------------------
- resolve_context(runtime) now raises on non-DeerFlowContext
  runtime.context. Every entry point attaches typed context; dict/None
  shapes are rejected loudly instead of being papered over with an
  ambient AppConfig lookup.

AppConfig lifecycle
-------------------
- AppConfig.current() kept as a deprecated slot that raises RuntimeError,
  purely so legacy tests that still run `patch.object(AppConfig, "current")`
  don't trip AttributeError at teardown. Production never calls it.
- conftest autouse fixture no longer monkey-patches `current` — it only
  stubs `from_file()` so tests don't need a real config.yaml.
Design refs
-----------
- docs/plans/2026-04-12-config-refactor-plan.md (Phase 2: P2-6..P2-10)
- docs/plans/2026-04-12-config-refactor-design.md §8

All 2338 non-e2e tests pass. Zero AppConfig.current() call sites remain in
backend/packages or backend/app (docstrings in deps.py excepted).
This commit is contained in:
parent 0e5ff6f431
commit 84dccef230
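Before the diff: the "Context resolution" behavior described in the commit message (raise on non-DeerFlowContext runtime context rather than falling back to an ambient config) can be sketched as below. The shapes are illustrative: `DeerFlowContext` and `resolve_context` mirror the names in the message, but the fields and the `Runtime` wrapper are invented for the example.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class AppConfig:
    model_name: str


@dataclass
class DeerFlowContext:
    app_config: AppConfig


@dataclass
class Runtime:
    context: Any


def resolve_context(runtime: Runtime) -> DeerFlowContext:
    if not isinstance(runtime.context, DeerFlowContext):
        # dict/None shapes are a caller bug; fail fast instead of papering
        # over the gap with a process-global AppConfig lookup.
        raise TypeError(
            f"expected DeerFlowContext, got {type(runtime.context).__name__}"
        )
    return runtime.context


ctx = resolve_context(Runtime(DeerFlowContext(AppConfig("gpt"))))
print(ctx.app_config.model_name)  # gpt

try:
    resolve_context(Runtime({"app_config": None}))
except TypeError as e:
    print(e)  # expected DeerFlowContext, got dict
```

Rejecting loudly makes every entry point responsible for attaching typed context, which is what lets the rest of the call graph take `app_config` as a plain parameter.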
@@ -375,7 +375,9 @@ class FeishuChannel(Channel):
         virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"

         try:
-            sandbox_provider = get_sandbox_provider()
+            from deerflow.config.app_config import AppConfig
+
+            sandbox_provider = get_sandbox_provider(AppConfig.from_file())
             sandbox_id = sandbox_provider.acquire(thread_id)
             if sandbox_id != "local":
                 sandbox = sandbox_provider.get(sandbox_id)
@@ -67,17 +67,8 @@ class ChannelService:
         self._running = False

     @classmethod
-    def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
-        """Create a ChannelService from the application config.
-
-        Pass ``app_config`` explicitly when available (e.g. from Gateway
-        startup). Falls back to ``AppConfig.current()`` for legacy callers;
-        that fallback is removed in Phase 2 task P2-10.
-        """
-        if app_config is None:
-            from deerflow.config.app_config import AppConfig as _AppConfig
-
-            app_config = _AppConfig.current()
+    def from_app_config(cls, app_config: AppConfig) -> ChannelService:
+        """Create a ChannelService from an explicit application config."""
         channels_config = {}
         # extra fields are allowed by AppConfig (extra="allow")
         extra = app_config.model_extra or {}
@@ -190,12 +181,12 @@ def get_channel_service() -> ChannelService | None:
     return _channel_service


-async def start_channel_service() -> ChannelService:
+async def start_channel_service(app_config: AppConfig) -> ChannelService:
     """Create and start the global ChannelService from app config."""
     global _channel_service
     if _channel_service is not None:
         return _channel_service
-    _channel_service = ChannelService.from_app_config()
+    _channel_service = ChannelService.from_app_config(app_config)
     await _channel_service.start()
     return _channel_service

@@ -146,11 +146,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     """Application lifespan handler."""

     try:
-        # app.state.config is the source of truth for Depends(get_config).
-        # AppConfig.init() mirrors it to the process-global for not-yet-migrated
-        # callers; both go away in P2-10 once AppConfig.current() is removed.
+        # ``app.state.config`` is the sole source of truth for
+        # ``Depends(get_config)``. Consumers that want AppConfig must receive
+        # it as an explicit parameter; there is no ambient singleton.
         app.state.config = AppConfig.from_file()
-        AppConfig.init(app.state.config)
         logger.info("Configuration loaded successfully")
     except Exception as e:
         error_msg = f"Failed to load configuration during gateway startup: {e}"
@@ -171,7 +170,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     try:
         from app.channels.service import start_channel_service

-        channel_service = await start_channel_service()
+        channel_service = await start_channel_service(app.state.config)
         logger.info("Channel service started: %s", channel_service.get_status())
     except Exception:
         logger.exception("No IM channels configured or channel service failed to start")
@@ -53,16 +53,18 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
     from deerflow.runtime.events.store import make_run_event_store

     async with AsyncExitStack() as stack:
-        app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
+        # app.state.config is populated earlier in lifespan(); thread it into
+        # every provider that used to reach for AppConfig.current().
+        config = app.state.config
+
+        app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))

         # Initialize persistence engine BEFORE checkpointer so that
         # auto-create-database logic runs first (postgres backend).
-        # Use app.state.config which was populated earlier in lifespan().
-        config = app.state.config
         await init_engine_from_config(config.database)

-        app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
-        app.state.store = await stack.enter_async_context(make_store())
+        app.state.checkpointer = await stack.enter_async_context(make_checkpointer(config))
+        app.state.store = await stack.enter_async_context(make_store(config))

         # Initialize repositories — one get_session_factory() call for all.
         sf = get_session_factory()
@@ -166,12 +166,10 @@ async def update_mcp_configuration(
         # NOTE: No need to reload/reset cache here - LangGraph Server (separate process)
         # will detect config file changes via mtime and reinitialize MCP tools automatically

-        # Reload the configuration. Swap app.state.config (new primitive) and
-        # AppConfig.init() (legacy) so both Depends(get_config) and the not-yet-migrated
-        # AppConfig.current() callers see the new config.
+        # Reload the configuration and swap ``app.state.config`` so subsequent
+        # ``Depends(get_config)`` calls see the refreshed value.
         reloaded = AppConfig.from_file()
         http_request.app.state.config = reloaded
-        AppConfig.init(reloaded)
         return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded.extensions.mcp_servers.items()})

     except Exception as e:
@@ -115,7 +115,7 @@ class MemoryStatusResponse(BaseModel):
     summary="Get Memory Data",
     description="Retrieve the current global memory data including user context, history, and facts.",
 )
-async def get_memory() -> MemoryResponse:
+async def get_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
     """Get the current global memory data.

     Returns:
@@ -149,7 +149,7 @@ async def get_memory() -> MemoryResponse:
         }
     ```
     """
-    memory_data = get_memory_data(user_id=get_effective_user_id())
+    memory_data = get_memory_data(app_config.memory, user_id=get_effective_user_id())
     return MemoryResponse(**memory_data)

@@ -160,7 +160,7 @@ async def get_memory() -> MemoryResponse:
     summary="Reload Memory Data",
     description="Reload memory data from the storage file, refreshing the in-memory cache.",
 )
-async def reload_memory() -> MemoryResponse:
+async def reload_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
     """Reload memory data from file.

     This forces a reload of the memory data from the storage file,
@@ -169,7 +169,7 @@ async def reload_memory() -> MemoryResponse:
     Returns:
         The reloaded memory data.
     """
-    memory_data = reload_memory_data(user_id=get_effective_user_id())
+    memory_data = reload_memory_data(app_config.memory, user_id=get_effective_user_id())
     return MemoryResponse(**memory_data)

@@ -180,10 +180,10 @@ async def reload_memory() -> MemoryResponse:
     summary="Clear All Memory Data",
     description="Delete all saved memory data and reset the memory structure to an empty state.",
 )
-async def clear_memory() -> MemoryResponse:
+async def clear_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
     """Clear all persisted memory data."""
     try:
-        memory_data = clear_memory_data(user_id=get_effective_user_id())
+        memory_data = clear_memory_data(app_config.memory, user_id=get_effective_user_id())
     except OSError as exc:
         raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
@@ -197,10 +197,11 @@ async def clear_memory() -> MemoryResponse:
     summary="Create Memory Fact",
     description="Create a single saved memory fact manually.",
 )
-async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse:
+async def create_memory_fact_endpoint(request: FactCreateRequest, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
     """Create a single fact manually."""
     try:
         memory_data = create_memory_fact(
+            app_config.memory,
             content=request.content,
             category=request.category,
             confidence=request.confidence,
@@ -221,10 +222,10 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
     summary="Delete Memory Fact",
     description="Delete a single saved memory fact by its fact id.",
 )
-async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
+async def delete_memory_fact_endpoint(fact_id: str, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
     """Delete a single fact from memory by fact id."""
     try:
-        memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
+        memory_data = delete_memory_fact(app_config.memory, fact_id, user_id=get_effective_user_id())
     except KeyError as exc:
         raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
     except OSError as exc:
@@ -240,10 +241,11 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
     summary="Patch Memory Fact",
     description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
 )
-async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse:
+async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
     """Partially update a single fact manually."""
     try:
         memory_data = update_memory_fact(
+            app_config.memory,
             fact_id=fact_id,
             content=request.content,
             category=request.category,
@@ -267,9 +269,9 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
     summary="Export Memory Data",
     description="Export the current global memory data as JSON for backup or transfer.",
 )
-async def export_memory() -> MemoryResponse:
+async def export_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
     """Export the current memory data."""
-    memory_data = get_memory_data(user_id=get_effective_user_id())
+    memory_data = get_memory_data(app_config.memory, user_id=get_effective_user_id())
     return MemoryResponse(**memory_data)

@@ -280,10 +282,10 @@ async def export_memory() -> MemoryResponse:
     summary="Import Memory Data",
     description="Import and overwrite the current global memory data from a JSON payload.",
 )
-async def import_memory(request: MemoryResponse) -> MemoryResponse:
+async def import_memory(request: MemoryResponse, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
     """Import and persist memory data."""
     try:
-        memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id())
+        memory_data = import_memory_data(app_config.memory, request.model_dump(), user_id=get_effective_user_id())
     except OSError as exc:
         raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
@@ -345,7 +347,7 @@ async def get_memory_status(
         Combined memory configuration and current data.
     """
     config = app_config.memory
-    memory_data = get_memory_data(user_id=get_effective_user_id())
+    memory_data = get_memory_data(config, user_id=get_effective_user_id())

     return MemoryStatusResponse(
         config=MemoryConfigResponse(
@@ -10,7 +10,7 @@ from app.gateway.deps import get_config
 from app.gateway.path_utils import resolve_thread_virtual_path
 from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
 from deerflow.config.app_config import AppConfig
-from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
+from deerflow.config.extensions_config import ExtensionsConfig
 from deerflow.skills import Skill, load_skills
 from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
 from deerflow.skills.manager import (
@@ -102,9 +102,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
     summary="List All Skills",
     description="Retrieve a list of all available skills from both public and custom directories.",
 )
-async def list_skills() -> SkillsListResponse:
+async def list_skills(app_config: AppConfig = Depends(get_config)) -> SkillsListResponse:
     try:
-        skills = load_skills(enabled_only=False)
+        skills = load_skills(app_config, enabled_only=False)
         return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
     except Exception as e:
         logger.error(f"Failed to load skills: {e}", exc_info=True)
@@ -117,11 +117,11 @@ async def list_skills() -> SkillsListResponse:
     summary="Install Skill",
     description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
 )
-async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
+async def install_skill(request: SkillInstallRequest, app_config: AppConfig = Depends(get_config)) -> SkillInstallResponse:
     try:
         skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
         result = install_skill_from_archive(skill_file_path)
-        await refresh_skills_system_prompt_cache_async()
+        await refresh_skills_system_prompt_cache_async(app_config)
         return SkillInstallResponse(**result)
     except FileNotFoundError as e:
         raise HTTPException(status_code=404, detail=str(e))
@@ -137,9 +137,9 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:


 @router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
-async def list_custom_skills() -> SkillsListResponse:
+async def list_custom_skills(app_config: AppConfig = Depends(get_config)) -> SkillsListResponse:
     try:
-        skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
+        skills = [skill for skill in load_skills(app_config, enabled_only=False) if skill.category == "custom"]
         return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
     except Exception as e:
         logger.error("Failed to list custom skills: %s", e, exc_info=True)
@@ -147,13 +147,13 @@ async def list_custom_skills() -> SkillsListResponse:


 @router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
-async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
+async def get_custom_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
     try:
-        skills = load_skills(enabled_only=False)
+        skills = load_skills(app_config, enabled_only=False)
         skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
         if skill is None:
             raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
-        return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
+        return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name, app_config))
     except HTTPException:
         raise
     except Exception as e:
@@ -162,14 +162,18 @@ async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:


 @router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
-async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
+async def update_custom_skill(
+    skill_name: str,
+    request: CustomSkillUpdateRequest,
+    app_config: AppConfig = Depends(get_config),
+) -> CustomSkillContentResponse:
     try:
-        ensure_custom_skill_is_editable(skill_name)
+        ensure_custom_skill_is_editable(skill_name, app_config)
         validate_skill_markdown_content(skill_name, request.content)
-        scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
+        scan = await scan_skill_content(app_config, request.content, executable=False, location=f"{skill_name}/SKILL.md")
         if scan.decision == "block":
             raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
-        skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
+        skill_file = get_custom_skill_dir(skill_name, app_config) / "SKILL.md"
         prev_content = skill_file.read_text(encoding="utf-8")
         atomic_write(skill_file, request.content)
         append_history(
@@ -183,9 +187,10 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
                 "new_content": request.content,
                 "scanner": {"decision": scan.decision, "reason": scan.reason},
             },
+            app_config,
         )
-        await refresh_skills_system_prompt_cache_async()
-        return await get_custom_skill(skill_name)
+        await refresh_skills_system_prompt_cache_async(app_config)
+        return await get_custom_skill(skill_name, app_config)
     except HTTPException:
         raise
     except FileNotFoundError as e:
@@ -198,11 +203,11 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest


 @router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
-async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
+async def delete_custom_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> dict[str, bool]:
     try:
-        ensure_custom_skill_is_editable(skill_name)
-        skill_dir = get_custom_skill_dir(skill_name)
-        prev_content = read_custom_skill_content(skill_name)
+        ensure_custom_skill_is_editable(skill_name, app_config)
+        skill_dir = get_custom_skill_dir(skill_name, app_config)
+        prev_content = read_custom_skill_content(skill_name, app_config)
         append_history(
             skill_name,
             {
@@ -214,9 +219,10 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
                 "new_content": None,
                 "scanner": {"decision": "allow", "reason": "Deletion requested."},
             },
+            app_config,
         )
         shutil.rmtree(skill_dir)
-        await refresh_skills_system_prompt_cache_async()
+        await refresh_skills_system_prompt_cache_async(app_config)
         return {"success": True}
     except FileNotFoundError as e:
         raise HTTPException(status_code=404, detail=str(e))
@@ -228,11 +234,11 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:


 @router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
-async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
+async def get_custom_skill_history(skill_name: str, app_config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
     try:
-        if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
+        if not custom_skill_exists(skill_name, app_config) and not get_skill_history_file(skill_name, app_config).exists():
             raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
-        return CustomSkillHistoryResponse(history=read_history(skill_name))
+        return CustomSkillHistoryResponse(history=read_history(skill_name, app_config))
     except HTTPException:
         raise
     except Exception as e:
@@ -241,11 +247,15 @@ async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryRespons


 @router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
-async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
+async def rollback_custom_skill(
+    skill_name: str,
+    request: SkillRollbackRequest,
+    app_config: AppConfig = Depends(get_config),
+) -> CustomSkillContentResponse:
     try:
-        if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
+        if not custom_skill_exists(skill_name, app_config) and not get_skill_history_file(skill_name, app_config).exists():
             raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
-        history = read_history(skill_name)
+        history = read_history(skill_name, app_config)
         if not history:
             raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
         record = history[request.history_index]
@@ -253,8 +263,8 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
         if target_content is None:
             raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
         validate_skill_markdown_content(skill_name, target_content)
-        scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
-        skill_file = get_custom_skill_file(skill_name)
+        scan = await scan_skill_content(app_config, target_content, executable=False, location=f"{skill_name}/SKILL.md")
+        skill_file = get_custom_skill_file(skill_name, app_config)
         current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
         history_entry = {
             "action": "rollback",
@@ -267,12 +277,12 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
             "scanner": {"decision": scan.decision, "reason": scan.reason},
         }
         if scan.decision == "block":
-            append_history(skill_name, history_entry)
+            append_history(skill_name, history_entry, app_config)
             raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
         atomic_write(skill_file, target_content)
-        append_history(skill_name, history_entry)
-        await refresh_skills_system_prompt_cache_async()
-        return await get_custom_skill(skill_name)
+        append_history(skill_name, history_entry, app_config)
+        await refresh_skills_system_prompt_cache_async(app_config)
+        return await get_custom_skill(skill_name, app_config)
     except HTTPException:
         raise
     except IndexError:
@@ -292,9 +302,9 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
     summary="Get Skill Details",
     description="Retrieve detailed information about a specific skill by its name.",
 )
-async def get_skill(skill_name: str) -> SkillResponse:
+async def get_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> SkillResponse:
     try:
-        skills = load_skills(enabled_only=False)
+        skills = load_skills(app_config, enabled_only=False)
         skill = next((s for s in skills if s.name == skill_name), None)

         if skill is None:
@@ -321,7 +331,7 @@ async def update_skill(
     app_config: AppConfig = Depends(get_config),
 ) -> SkillResponse:
     try:
-        skills = load_skills(enabled_only=False)
+        skills = load_skills(app_config, enabled_only=False)
         skill = next((s for s in skills if s.name == skill_name), None)

         if skill is None:
@@ -332,26 +342,29 @@ async def update_skill(
             config_path = Path.cwd().parent / "extensions_config.json"
             logger.info(f"No existing extensions config found. Creating new config at: {config_path}")

+        # Do not mutate the frozen AppConfig in place. Compose the new skills
+        # state in a fresh dict, write to disk, and reload AppConfig below so
+        # every subsequent Depends(get_config) sees the refreshed snapshot.
         ext = app_config.extensions
-        ext.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
+        updated_skills = {name: {"enabled": skill_config.enabled} for name, skill_config in ext.skills.items()}
+        updated_skills[skill_name] = {"enabled": request.enabled}

         config_data = {
             "mcpServers": {name: server.model_dump() for name, server in ext.mcp_servers.items()},
-            "skills": {name: {"enabled": skill_config.enabled} for name, skill_config in ext.skills.items()},
+            "skills": updated_skills,
         }

         with open(config_path, "w", encoding="utf-8") as f:
             json.dump(config_data, f, indent=2)

         logger.info(f"Skills configuration updated and saved to: {config_path}")
-        # Swap both app.state.config and AppConfig._global so Depends(get_config)
-        # and legacy AppConfig.current() callers see the new config.
+        # Reload AppConfig and swap ``app.state.config`` so subsequent
+        # ``Depends(get_config)`` sees the refreshed value.
         reloaded = AppConfig.from_file()
         http_request.app.state.config = reloaded
-        AppConfig.init(reloaded)
-        await refresh_skills_system_prompt_cache_async()
+        await refresh_skills_system_prompt_cache_async(reloaded)

-        skills = load_skills(enabled_only=False)
+        skills = load_skills(reloaded, enabled_only=False)
         updated_skill = next((s for s in skills if s.name == skill_name), None)

         if updated_skill is None:
@@ -4,10 +4,12 @@ import logging
 import os
 import stat

-from fastapi import APIRouter, File, HTTPException, Request, UploadFile
+from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
 from pydantic import BaseModel

 from app.gateway.authz import require_permission
+from app.gateway.deps import get_config
+from deerflow.config.app_config import AppConfig
 from deerflow.config.paths import get_paths
 from deerflow.runtime.user_context import get_effective_user_id
 from deerflow.sandbox.sandbox_provider import get_sandbox_provider
@@ -61,6 +63,7 @@ async def upload_files(
     thread_id: str,
     request: Request,
     files: list[UploadFile] = File(...),
+    app_config: AppConfig = Depends(get_config),
 ) -> UploadResponse:
     """Upload multiple files to a thread's uploads directory."""
     if not files:
@@ -73,7 +76,7 @@ async def upload_files(
     sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
     uploaded_files = []

-    sandbox_provider = get_sandbox_provider()
+    sandbox_provider = get_sandbox_provider(app_config)
     sandbox_id = sandbox_provider.acquire(thread_id)
     sandbox = sandbox_provider.get(sandbox_id)

@@ -61,9 +61,9 @@ def _create_summarization_middleware(app_config: AppConfig) -> SummarizationMidd
     # as middleware rather than lead_agent (SummarizationMiddleware is a
     # LangChain built-in, so we tag the model at creation time).
     if config.model_name:
-        model = create_chat_model(name=config.model_name, thinking_enabled=False)
+        model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
     else:
-        model = create_chat_model(thinking_enabled=False)
+        model = create_chat_model(thinking_enabled=False, app_config=app_config)
     model = model.with_config(tags=["middleware:summarize"])

     # Prepare kwargs
@@ -288,22 +288,22 @@ def make_lead_agent(
     Args:
         config: LangGraph ``RunnableConfig`` carrying per-invocation options
             (``thinking_enabled``, ``model_name``, ``is_plan_mode``, etc.).
-        app_config: Resolved application config. When omitted falls back to
-            :meth:`AppConfig.current`, preserving backward compatibility with
-            callers that do not thread config explicitly (LangGraph Server
-            registration path). New callers should pass this parameter.
+        app_config: Resolved application config. Required for in-process
+            entry points (DeerFlowClient, Gateway Worker). When omitted we
+            are being called via ``langgraph.json`` registration and reload
+            from disk — the LangGraph Server bootstrap path has no other
+            way to thread the value.
     """
     # Lazy import to avoid circular dependency
     from deerflow.tools import get_available_tools
     from deerflow.tools.builtins import setup_agent

     if app_config is None:
-        # LangGraph Server registers make_lead_agent via langgraph.json; its
-        # invocation path only hands us RunnableConfig. Until that registration
-        # layer owns its own AppConfig, we tolerate the process-global fallback
-        # here. All other entry points (DeerFlowClient, Gateway Worker) pass
-        # app_config explicitly. Remove alongside AppConfig.current() in P2-10.
-        app_config = AppConfig.current()
+        # LangGraph Server registers ``make_lead_agent`` via ``langgraph.json``
+        # and hands us only a ``RunnableConfig``. Reload config from disk
+        # here — it's a pure function, equivalent to the process-global the
+        # old code path would have read.
+        app_config = AppConfig.from_file()

     cfg = config.get("configurable", {})
@@ -363,7 +363,7 @@ def make_lead_agent(
         model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=app_config),
         tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=app_config) + [setup_agent],
         middleware=_build_middlewares(app_config, config, model_name=model_name),
-        system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
+        system_prompt=apply_prompt_template(app_config, subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
         state_schema=ThreadState,
         context_schema=DeerFlowContext,
     )
@@ -374,7 +374,7 @@ def make_lead_agent(
         tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=app_config),
         middleware=_build_middlewares(app_config, config, model_name=model_name, agent_name=agent_name),
         system_prompt=apply_prompt_template(
-            subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
+            app_config, subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
         ),
         state_schema=ThreadState,
         context_schema=DeerFlowContext,
@@ -20,19 +20,20 @@ _enabled_skills_refresh_version = 0
 _enabled_skills_refresh_event = threading.Event()
 
 
-def _load_enabled_skills_sync() -> list[Skill]:
-    return list(load_skills(enabled_only=True))
+def _load_enabled_skills_sync(app_config: AppConfig | None) -> list[Skill]:
+    return list(load_skills(app_config, enabled_only=True))
 
 
-def _start_enabled_skills_refresh_thread() -> None:
+def _start_enabled_skills_refresh_thread(app_config: AppConfig | None) -> None:
     threading.Thread(
         target=_refresh_enabled_skills_cache_worker,
+        args=(app_config,),
         name="deerflow-enabled-skills-loader",
         daemon=True,
     ).start()
 
 
-def _refresh_enabled_skills_cache_worker() -> None:
+def _refresh_enabled_skills_cache_worker(app_config: AppConfig | None) -> None:
     global _enabled_skills_cache, _enabled_skills_refresh_active
 
     while True:
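Handing the config to the worker via `args=` means the daemon thread owns its reference for its entire lifetime instead of re-reading a global on every loop. A reduced sketch of that pattern (names illustrative, not the deer-flow internals):

```python
import threading

results: list[str] = []


def _worker(cfg: dict) -> None:
    # The worker sees exactly what it was handed at start() time; no
    # ambient lookup, so two threads can run with distinct configs.
    results.append(cfg["name"])


def start_refresh_thread(cfg: dict) -> threading.Thread:
    t = threading.Thread(target=_worker, args=(cfg,), daemon=True)
    t.start()
    return t
```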
@@ -40,8 +41,8 @@ def _refresh_enabled_skills_cache_worker() -> None:
         target_version = _enabled_skills_refresh_version
 
         try:
-            skills = _load_enabled_skills_sync()
-        except Exception:
+            skills = _load_enabled_skills_sync(app_config)
+        except (OSError, ImportError):
             logger.exception("Failed to load enabled skills for prompt injection")
             skills = []
@@ -57,7 +58,7 @@ def _refresh_enabled_skills_cache_worker() -> None:
             _enabled_skills_cache = None
 
 
-def _ensure_enabled_skills_cache() -> threading.Event:
+def _ensure_enabled_skills_cache(app_config: AppConfig | None) -> threading.Event:
     global _enabled_skills_refresh_active
 
     with _enabled_skills_lock:
@@ -69,11 +70,11 @@ def _ensure_enabled_skills_cache() -> threading.Event:
         _enabled_skills_refresh_active = True
         _enabled_skills_refresh_event.clear()
 
-    _start_enabled_skills_refresh_thread()
+    _start_enabled_skills_refresh_thread(app_config)
     return _enabled_skills_refresh_event
 
 
-def _invalidate_enabled_skills_cache() -> threading.Event:
+def _invalidate_enabled_skills_cache(app_config: AppConfig | None) -> threading.Event:
     global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
 
     _get_cached_skills_prompt_section.cache_clear()
@@ -85,30 +86,30 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
             return _enabled_skills_refresh_event
         _enabled_skills_refresh_active = True
 
-    _start_enabled_skills_refresh_thread()
+    _start_enabled_skills_refresh_thread(app_config)
     return _enabled_skills_refresh_event
 
 
-def prime_enabled_skills_cache() -> None:
-    _ensure_enabled_skills_cache()
+def prime_enabled_skills_cache(app_config: AppConfig | None = None) -> None:
+    _ensure_enabled_skills_cache(app_config)
 
 
-def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
-    if _ensure_enabled_skills_cache().wait(timeout=timeout_seconds):
+def warm_enabled_skills_cache(app_config: AppConfig | None = None, timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
+    if _ensure_enabled_skills_cache(app_config).wait(timeout=timeout_seconds):
         return True
 
     logger.warning("Timed out waiting %.1fs for enabled skills cache warm-up", timeout_seconds)
     return False
 
 
-def _get_enabled_skills():
+def _get_enabled_skills(app_config: AppConfig | None = None):
     with _enabled_skills_lock:
         cached = _enabled_skills_cache
 
         if cached is not None:
            return list(cached)
 
-    _ensure_enabled_skills_cache()
+    _ensure_enabled_skills_cache(app_config)
     return []
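`warm_enabled_skills_cache` relies on `threading.Event.wait(timeout=...)` returning `False` on timeout, which is what lets it log a warning and continue instead of blocking forever. A minimal sketch of that contract:

```python
import threading


def warm(event: threading.Event, timeout: float) -> bool:
    # Event.wait returns True only if the event was set within the
    # timeout; callers can warn-and-continue rather than hang.
    return event.wait(timeout=timeout)
```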
@@ -116,12 +117,12 @@ def _skill_mutability_label(category: str) -> str:
     return "[custom, editable]" if category == "custom" else "[built-in]"
 
 
-def clear_skills_system_prompt_cache() -> None:
-    _invalidate_enabled_skills_cache()
+def clear_skills_system_prompt_cache(app_config: AppConfig | None = None) -> None:
+    _invalidate_enabled_skills_cache(app_config)
 
 
-async def refresh_skills_system_prompt_cache_async() -> None:
-    await asyncio.to_thread(_invalidate_enabled_skills_cache().wait)
+async def refresh_skills_system_prompt_cache_async(app_config: AppConfig | None = None) -> None:
+    await asyncio.to_thread(_invalidate_enabled_skills_cache(app_config).wait)
 
 
 def _reset_skills_system_prompt_cache_state() -> None:
@@ -135,10 +136,10 @@ def _reset_skills_system_prompt_cache_state() -> None:
     _enabled_skills_refresh_event.clear()
 
 
-def _refresh_enabled_skills_cache() -> None:
+def _refresh_enabled_skills_cache(app_config: AppConfig | None = None) -> None:
     """Backward-compatible test helper for direct synchronous reload."""
     try:
-        skills = _load_enabled_skills_sync()
+        skills = _load_enabled_skills_sync(app_config)
     except Exception:
         logger.exception("Failed to load enabled skills for prompt injection")
         skills = []
@@ -165,17 +166,18 @@ Skip simple one-off tasks.
 """
 
 
-def _build_subagent_section(max_concurrent: int) -> str:
+def _build_subagent_section(max_concurrent: int, app_config: AppConfig) -> str:
     """Build the subagent system prompt section with dynamic concurrency limit.
 
     Args:
         max_concurrent: Maximum number of concurrent subagent calls allowed per response.
+        app_config: Application config used to gate bash availability.
 
     Returns:
         Formatted subagent section string.
     """
     n = max_concurrent
-    bash_available = "bash" in get_available_subagent_names()
+    bash_available = "bash" in get_available_subagent_names(app_config)
     available_subagents = (
         "- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n- **bash**: For command execution (git, build, test, deploy operations)"
         if bash_available
@@ -508,36 +510,34 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
 """
 
 
-def _get_memory_context(agent_name: str | None = None) -> str:
+def _get_memory_context(app_config: AppConfig, agent_name: str | None = None) -> str:
     """Get memory context for injection into system prompt.
 
     Args:
         agent_name: If provided, loads per-agent memory. If None, loads global memory.
 
     Returns:
-        Formatted memory context string wrapped in XML tags, or empty string if disabled.
+        Returns an empty string when memory is disabled or the stored memory file
+        cannot be read/parsed. A corrupt memory.json degrades the prompt to
+        no-memory; it never kills the agent.
     """
+    from deerflow.agents.memory import format_memory_for_injection, get_memory_data
+    from deerflow.runtime.user_context import get_effective_user_id
+
+    memory_config = app_config.memory
+    if not memory_config.enabled or not memory_config.injection_enabled:
+        return ""
+
     try:
-        from deerflow.agents.memory import format_memory_for_injection, get_memory_data
-        from deerflow.runtime.user_context import get_effective_user_id
+        memory_data = get_memory_data(memory_config, agent_name, user_id=get_effective_user_id())
+    except (OSError, ValueError, UnicodeDecodeError):
+        logger.exception("Failed to load memory data for prompt injection")
+        return ""
 
-        config = AppConfig.current().memory
-        if not config.enabled or not config.injection_enabled:
-            return ""
+    memory_content = format_memory_for_injection(memory_data, max_tokens=memory_config.max_injection_tokens)
+    if not memory_content.strip():
+        return ""
 
-        memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
-        memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
-
-        if not memory_content.strip():
-            return ""
-
-        return f"""<memory>
+    return f"""<memory>
 {memory_content}
 </memory>
 """
-    except Exception as e:
-        logger.error("Failed to load memory context: %s", e)
-        return ""
 
 
 @lru_cache(maxsize=32)
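The rewritten `_get_memory_context` narrows its exception handling: data-shaped failures degrade to an empty prompt section, while programming errors still surface. A toy version of the same shape, using a JSON string in place of the real storage layer (`memory_context` is a hypothetical helper, not the deer-flow function):

```python
import json
import logging

logger = logging.getLogger(__name__)


def memory_context(raw: str) -> str:
    # Narrow except: only corrupt/undecodable data degrades to "no memory";
    # anything else propagates instead of being silently swallowed.
    try:
        data = json.loads(raw)
    except (ValueError, UnicodeDecodeError):
        logger.exception("Failed to load memory data for prompt injection")
        return ""
    content = str(data.get("facts", "")).strip()
    if not content:
        return ""
    return f"<memory>\n{content}\n</memory>"
```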
@@ -572,17 +572,12 @@ You have access to skills that provide optimized workflows for specific tasks. E
 </skill_system>"""
 
 
-def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
+def get_skills_prompt_section(app_config: AppConfig, available_skills: set[str] | None = None) -> str:
     """Generate the skills prompt section with available skills list."""
-    skills = _get_enabled_skills()
+    skills = _get_enabled_skills(app_config)
 
-    try:
-        config = AppConfig.current()
-        container_base_path = config.skills.container_path
-        skill_evolution_enabled = config.skill_evolution.enabled
-    except Exception:
-        container_base_path = "/mnt/skills"
-        skill_evolution_enabled = False
+    container_base_path = app_config.skills.container_path
+    skill_evolution_enabled = app_config.skill_evolution.enabled
 
     if not skills and not skill_evolution_enabled:
         return ""
@@ -606,7 +601,7 @@ def get_agent_soul(agent_name: str | None) -> str:
     return ""
 
 
-def get_deferred_tools_prompt_section() -> str:
+def get_deferred_tools_prompt_section(app_config: AppConfig) -> str:
     """Generate <available-deferred-tools> block for the system prompt.
 
     Lists only deferred tool names so the agent knows what exists
@@ -615,10 +610,7 @@ def get_deferred_tools_prompt_section() -> str:
     """
     from deerflow.tools.builtins.tool_search import get_deferred_registry
 
-    try:
-        if not AppConfig.current().tool_search.enabled:
-            return ""
-    except Exception:
+    if not app_config.tool_search.enabled:
         return ""
 
     registry = get_deferred_registry()
@@ -629,13 +621,9 @@ def get_deferred_tools_prompt_section() -> str:
     return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
 
 
-def _build_acp_section() -> str:
+def _build_acp_section(app_config: AppConfig) -> str:
     """Build the ACP agent prompt section, only if ACP agents are configured."""
-    try:
-        agents = AppConfig.current().acp_agents
-        if not agents:
-            return ""
-    except Exception:
+    if not app_config.acp_agents:
         return ""
 
     return (
@@ -647,13 +635,9 @@ def _build_acp_section() -> str:
     )
 
 
-def _build_custom_mounts_section() -> str:
+def _build_custom_mounts_section(app_config: AppConfig) -> str:
     """Build a prompt section for explicitly configured sandbox mounts."""
-    try:
-        mounts = AppConfig.current().sandbox.mounts or []
-    except Exception:
-        logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
-        return ""
+    mounts = app_config.sandbox.mounts or []
 
     if not mounts:
         return ""
@@ -667,13 +651,20 @@ def _build_custom_mounts_section() -> str:
     return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
 
 
-def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
+def apply_prompt_template(
+    app_config: AppConfig,
+    subagent_enabled: bool = False,
+    max_concurrent_subagents: int = 3,
+    *,
+    agent_name: str | None = None,
+    available_skills: set[str] | None = None,
+) -> str:
     # Get memory context
-    memory_context = _get_memory_context(agent_name)
+    memory_context = _get_memory_context(app_config, agent_name)
 
     # Include subagent section only if enabled (from runtime parameter)
     n = max_concurrent_subagents
-    subagent_section = _build_subagent_section(n) if subagent_enabled else ""
+    subagent_section = _build_subagent_section(n, app_config) if subagent_enabled else ""
 
     # Add subagent reminder to critical_reminders if enabled
     subagent_reminder = (
@@ -694,14 +685,14 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
     )
 
     # Get skills section
-    skills_section = get_skills_prompt_section(available_skills)
+    skills_section = get_skills_prompt_section(app_config, available_skills)
 
     # Get deferred tools section (tool_search)
-    deferred_tools_section = get_deferred_tools_prompt_section()
+    deferred_tools_section = get_deferred_tools_prompt_section(app_config)
 
     # Build ACP agent section only if ACP agents are configured
-    acp_section = _build_acp_section()
-    custom_mounts_section = _build_custom_mounts_section()
+    acp_section = _build_acp_section(app_config)
+    custom_mounts_section = _build_custom_mounts_section(app_config)
     acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
 
     # Format the prompt with dynamic skills and memory
@@ -12,6 +12,12 @@ from deerflow.config.app_config import AppConfig
 logger = logging.getLogger(__name__)
 
 
+# Module-level config pointer set by the middleware that owns the queue.
+# The queue runs on a background Timer thread where ``Runtime`` and FastAPI
+# request context are not accessible; the enqueuer (which does have runtime
+# context) is responsible for plumbing ``AppConfig`` through ``add()``.
+
+
 @dataclass
 class ConversationContext:
     """Context for a conversation to be processed for memory update."""
@@ -31,10 +37,21 @@ class MemoryUpdateQueue:
     This queue collects conversation contexts and processes them after
     a configurable debounce period. Multiple conversations received within
     the debounce window are batched together.
 
+    The queue captures an ``AppConfig`` reference at construction time and
+    reuses it for the MemoryUpdater it spawns. Callers must construct a
+    fresh queue when the config changes rather than reaching into a global.
     """
 
-    def __init__(self):
-        """Initialize the memory update queue."""
+    def __init__(self, app_config: AppConfig):
+        """Initialize the memory update queue.
+
+        Args:
+            app_config: Application config. The queue reads its own
+                ``memory`` section for debounce timing and hands the full
+                config to :class:`MemoryUpdater`.
+        """
+        self._app_config = app_config
         self._queue: list[ConversationContext] = []
         self._lock = threading.Lock()
         self._timer: threading.Timer | None = None
@@ -49,19 +66,8 @@ class MemoryUpdateQueue:
         correction_detected: bool = False,
         reinforcement_detected: bool = False,
     ) -> None:
-        """Add a conversation to the update queue.
-
-        Args:
-            thread_id: The thread ID.
-            messages: The conversation messages.
-            agent_name: If provided, memory is stored per-agent. If None, uses global memory.
-            user_id: The user ID captured at enqueue time. Stored in ConversationContext so it
-                survives the threading.Timer boundary (ContextVar does not propagate across
-                raw threads).
-            correction_detected: Whether recent turns include an explicit correction signal.
-            reinforcement_detected: Whether recent turns include a positive reinforcement signal.
-        """
-        config = AppConfig.current().memory
+        """Add a conversation to the update queue."""
+        config = self._app_config.memory
         if not config.enabled:
             return
 
@@ -93,7 +99,7 @@ class MemoryUpdateQueue:
 
     def _reset_timer(self) -> None:
         """Reset the debounce timer."""
-        config = AppConfig.current().memory
+        config = self._app_config.memory
 
         # Cancel existing timer if any
         if self._timer is not None:
@@ -131,7 +137,7 @@ class MemoryUpdateQueue:
         logger.info("Processing %d queued memory updates", len(contexts_to_process))
 
         try:
-            updater = MemoryUpdater()
+            updater = MemoryUpdater(self._app_config)
 
             for context in contexts_to_process:
                 try:
@@ -196,31 +202,35 @@ class MemoryUpdateQueue:
         return self._processing
 
 
-# Global singleton instance
-_memory_queue: MemoryUpdateQueue | None = None
+# Queues keyed by ``id(AppConfig)`` so tests and multi-client setups with
+# distinct configs do not share a debounce queue.
+_memory_queues: dict[int, MemoryUpdateQueue] = {}
 _queue_lock = threading.Lock()
 
 
-def get_memory_queue() -> MemoryUpdateQueue:
-    """Get the global memory update queue singleton.
-
-    Returns:
-        The memory update queue instance.
-    """
-    global _memory_queue
+def get_memory_queue(app_config: AppConfig) -> MemoryUpdateQueue:
+    """Get or create the memory update queue for the given app config."""
+    key = id(app_config)
     with _queue_lock:
-        if _memory_queue is None:
-            _memory_queue = MemoryUpdateQueue()
-        return _memory_queue
+        queue = _memory_queues.get(key)
+        if queue is None:
+            queue = MemoryUpdateQueue(app_config)
+            _memory_queues[key] = queue
+        return queue
 
 
-def reset_memory_queue() -> None:
-    """Reset the global memory queue.
+def reset_memory_queue(app_config: AppConfig | None = None) -> None:
+    """Reset memory queue(s).
 
-    This is useful for testing.
+    Pass an ``app_config`` to reset only its queue, or omit to reset all
+    (useful at test teardown).
     """
-    global _memory_queue
     with _queue_lock:
-        if _memory_queue is not None:
-            _memory_queue.clear()
-        _memory_queue = None
+        if app_config is not None:
+            queue = _memory_queues.pop(id(app_config), None)
+            if queue is not None:
+                queue.clear()
+            return
+        for queue in _memory_queues.values():
+            queue.clear()
+        _memory_queues.clear()
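Keying the registry on `id(app_config)` is safe here only because each queue stores the config it was built with, which keeps the object alive and its `id` stable. A stripped-down version of the lookup (names illustrative):

```python
import threading


class Queue:
    def __init__(self, cfg: object) -> None:
        # Holding cfg pins the object, so id(cfg) cannot be recycled
        # while this queue is registered under it.
        self.cfg = cfg


_queues: dict[int, Queue] = {}
_lock = threading.Lock()


def get_queue(cfg: object) -> Queue:
    # Same config object -> same queue; distinct objects never share one.
    with _lock:
        q = _queues.get(id(cfg))
        if q is None:
            q = Queue(cfg)
            _queues[id(cfg)] = q
        return q
```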
@@ -9,7 +9,7 @@ from pathlib import Path
 from typing import Any
 
 from deerflow.config.agents_config import AGENT_NAME_PATTERN
-from deerflow.config.app_config import AppConfig
+from deerflow.config.memory_config import MemoryConfig
 from deerflow.config.paths import get_paths
 
 logger = logging.getLogger(__name__)
@@ -61,8 +61,15 @@ class MemoryStorage(abc.ABC):
 class FileMemoryStorage(MemoryStorage):
     """File-based memory storage provider."""
 
-    def __init__(self):
-        """Initialize the file memory storage."""
+    def __init__(self, memory_config: MemoryConfig):
+        """Initialize the file memory storage.
+
+        Args:
+            memory_config: Memory configuration (storage_path etc.). Stored on
+                the instance so per-request lookups don't need to reach for
+                ambient state.
+        """
+        self._memory_config = memory_config
         # Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
         # Value: (memory_data, file_mtime)
         self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
@@ -80,11 +87,11 @@ class FileMemoryStorage(MemoryStorage):
 
     def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path:
         """Get the path to the memory file."""
+        config = self._memory_config
         if user_id is not None:
             if agent_name is not None:
                 self._validate_agent_name(agent_name)
                 return get_paths().user_agent_memory_file(user_id, agent_name)
-            config = AppConfig.current().memory
             if config.storage_path and Path(config.storage_path).is_absolute():
                 return Path(config.storage_path)
             return get_paths().user_memory_file(user_id)
@@ -92,7 +99,6 @@ class FileMemoryStorage(MemoryStorage):
         if agent_name is not None:
             self._validate_agent_name(agent_name)
             return get_paths().agent_memory_file(agent_name)
-        config = AppConfig.current().memory
         if config.storage_path:
             p = Path(config.storage_path)
             return p if p.is_absolute() else get_paths().base_dir / p
@@ -174,23 +180,31 @@ class FileMemoryStorage(MemoryStorage):
         return False
 
 
-_storage_instance: MemoryStorage | None = None
+# Instances keyed by (storage_class_path, id(memory_config)) so tests can
+# construct isolated storages and multi-client setups with different configs
+# don't collide on a single process-wide singleton.
+_storage_instances: dict[tuple[str, int], MemoryStorage] = {}
 _storage_lock = threading.Lock()
 
 
-def get_memory_storage() -> MemoryStorage:
-    """Get the configured memory storage instance."""
-    global _storage_instance
-    if _storage_instance is not None:
-        return _storage_instance
+def get_memory_storage(memory_config: MemoryConfig) -> MemoryStorage:
+    """Get the configured memory storage instance.
+
+    Caches one instance per ``(storage_class, memory_config)`` pair. In
+    single-config deployments this collapses to one instance; in multi-client
+    or test scenarios each config gets its own storage.
+    """
+    key = (memory_config.storage_class, id(memory_config))
+    existing = _storage_instances.get(key)
+    if existing is not None:
+        return existing
 
     with _storage_lock:
-        if _storage_instance is not None:
-            return _storage_instance
-
-        config = AppConfig.current().memory
-        storage_class_path = config.storage_class
+        existing = _storage_instances.get(key)
+        if existing is not None:
+            return existing
+
+        storage_class_path = memory_config.storage_class
         try:
             module_path, class_name = storage_class_path.rsplit(".", 1)
             import importlib
@@ -204,13 +218,14 @@ def get_memory_storage() -> MemoryStorage:
             if not issubclass(storage_class, MemoryStorage):
                 raise TypeError(f"Configured memory storage '{storage_class_path}' is not a subclass of MemoryStorage")
 
-            _storage_instance = storage_class()
+            instance = storage_class(memory_config)
         except Exception as e:
             logger.error(
                 "Failed to load memory storage %s, falling back to FileMemoryStorage: %s",
                 storage_class_path,
                 e,
             )
-            _storage_instance = FileMemoryStorage()
+            instance = FileMemoryStorage(memory_config)
 
-    return _storage_instance
+        _storage_instances[key] = instance
+        return instance
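The dotted-path loading in `get_memory_storage` follows the usual `rsplit` + `importlib` recipe with a fallback class on any failure. A self-contained sketch of just that recipe (the `FileStorage` default here is a stand-in, not the real `FileMemoryStorage`):

```python
import importlib


class FileStorage:
    """Stand-in default used when the configured class cannot be loaded."""


def load_storage_class(path: str, default: type) -> type:
    # "pkg.mod.Class" -> import pkg.mod, then getattr the class; any
    # import/lookup failure falls back to the default implementation.
    try:
        module_path, class_name = path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        cls = getattr(module, class_name)
        if not isinstance(cls, type):
            raise TypeError(f"{path} is not a class")
        return cls
    except Exception:
        return default
```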
@@ -17,6 +17,7 @@ from deerflow.agents.memory.storage import (
     utc_now_iso_z,
 )
 from deerflow.config.app_config import AppConfig
+from deerflow.config.memory_config import MemoryConfig
 from deerflow.models import create_chat_model
 
 logger = logging.getLogger(__name__)
@@ -27,45 +28,33 @@ def _create_empty_memory() -> dict[str, Any]:
     return create_empty_memory()
 
 
-def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
-    """Backward-compatible wrapper around the configured memory storage save path."""
-    return get_memory_storage().save(memory_data, agent_name, user_id=user_id)
+def _save_memory_to_file(memory_config: MemoryConfig, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
+    """Save via the configured memory storage."""
+    return get_memory_storage(memory_config).save(memory_data, agent_name, user_id=user_id)
 
 
-def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
+def get_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
     """Get the current memory data via storage provider."""
-    return get_memory_storage().load(agent_name, user_id=user_id)
+    return get_memory_storage(memory_config).load(agent_name, user_id=user_id)
 
 
-def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
+def reload_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
     """Reload memory data via storage provider."""
-    return get_memory_storage().reload(agent_name, user_id=user_id)
+    return get_memory_storage(memory_config).reload(agent_name, user_id=user_id)
 
 
-def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
-    """Persist imported memory data via storage provider.
-
-    Args:
-        memory_data: Full memory payload to persist.
-        agent_name: If provided, imports into per-agent memory.
-        user_id: If provided, scopes memory to a specific user.
-
-    Returns:
-        The saved memory data after storage normalization.
-
-    Raises:
-        OSError: If persisting the imported memory fails.
-    """
-    storage = get_memory_storage()
+def import_memory_data(memory_config: MemoryConfig, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
+    """Persist imported memory data via storage provider."""
+    storage = get_memory_storage(memory_config)
     if not storage.save(memory_data, agent_name, user_id=user_id):
         raise OSError("Failed to save imported memory data")
     return storage.load(agent_name, user_id=user_id)
 
 
-def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
+def clear_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
     """Clear all stored memory data and persist an empty structure."""
     cleared_memory = create_empty_memory()
-    if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id):
+    if not _save_memory_to_file(memory_config, cleared_memory, agent_name, user_id=user_id):
         raise OSError("Failed to save cleared memory data")
     return cleared_memory
@@ -78,6 +67,7 @@ def _validate_confidence(confidence: float) -> float:
 
 
 def create_memory_fact(
+    memory_config: MemoryConfig,
     content: str,
     category: str = "context",
     confidence: float = 0.5,
@@ -93,7 +83,7 @@ def create_memory_fact(
     normalized_category = category.strip() or "context"
     validated_confidence = _validate_confidence(confidence)
     now = utc_now_iso_z()
-    memory_data = get_memory_data(agent_name, user_id=user_id)
+    memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
     updated_memory = dict(memory_data)
     facts = list(memory_data.get("facts", []))
     facts.append(
@@ -108,15 +98,15 @@ def create_memory_fact(
     )
     updated_memory["facts"] = facts
 
-    if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
+    if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
         raise OSError("Failed to save memory data after creating fact")
 
     return updated_memory
 
 
-def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
+def delete_memory_fact(memory_config: MemoryConfig, fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
     """Delete a fact by its id and persist the updated memory data."""
-    memory_data = get_memory_data(agent_name, user_id=user_id)
+    memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
     facts = memory_data.get("facts", [])
     updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
     if len(updated_facts) == len(facts):
@@ -125,13 +115,14 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id:
     updated_memory = dict(memory_data)
     updated_memory["facts"] = updated_facts
 
-    if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
+    if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
         raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
 
     return updated_memory
 
 
 def update_memory_fact(
+    memory_config: MemoryConfig,
     fact_id: str,
     content: str | None = None,
     category: str | None = None,
@ -141,7 +132,7 @@ def update_memory_fact(
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing fact and persist the updated memory data."""
|
||||
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||
memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
|
||||
updated_memory = dict(memory_data)
|
||||
updated_facts: list[dict[str, Any]] = []
|
||||
found = False
|
||||
@ -168,7 +159,7 @@ def update_memory_fact(
|
||||
|
||||
updated_memory["facts"] = updated_facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||
if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
|
||||
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
|
||||
|
||||
return updated_memory
|
||||
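The only change across the fact helpers is that each one now receives its `MemoryConfig` explicitly instead of reaching for an ambient lookup. A minimal standalone sketch of that calling convention (all names here are illustrative stand-ins, not the real deerflow implementations):

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class MemoryConfig:  # illustrative stand-in for deerflow's MemoryConfig
    enabled: bool = True


def create_memory_fact(memory_config: MemoryConfig, store: dict[str, Any], content: str) -> dict[str, Any]:
    # Explicit parameter flow: the config arrives as an argument, so there
    # is no hidden AppConfig.current() lookup for tests to monkey-patch.
    if not memory_config.enabled:
        return store
    updated = dict(store)
    facts = list(updated.get("facts", []))
    facts.append({"id": str(len(facts) + 1), "content": content})
    updated["facts"] = facts
    return updated


store = create_memory_fact(MemoryConfig(), {}, "prefers concise answers")
print(len(store["facts"]))  # 1
```

Two clients holding different configs can now call these helpers concurrently without contending on process-global state.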
@@ -260,19 +251,25 @@ def _fact_content_key(content: Any) -> str | None:
 class MemoryUpdater:
     """Updates memory using LLM based on conversation context."""

-    def __init__(self, model_name: str | None = None):
+    def __init__(self, app_config: AppConfig, model_name: str | None = None):
         """Initialize the memory updater.

         Args:
+            app_config: Application config (the updater needs both ``memory``
+                section for behavior and the full config for ``create_chat_model``).
             model_name: Optional model name to use. If None, uses config or default.
         """
+        self._app_config = app_config
         self._model_name = model_name

+    @property
+    def _memory_config(self) -> MemoryConfig:
+        return self._app_config.memory
+
     def _get_model(self):
         """Get the model for memory updates."""
-        config = AppConfig.current().memory
-        model_name = self._model_name or config.model_name
-        return create_chat_model(name=model_name, thinking_enabled=False)
+        model_name = self._model_name or self._memory_config.model_name
+        return create_chat_model(name=model_name, thinking_enabled=False, app_config=self._app_config)

     def update_memory(
         self,
@@ -296,7 +293,7 @@ class MemoryUpdater:
         Returns:
             True if update was successful, False otherwise.
         """
-        config = AppConfig.current().memory
+        config = self._memory_config
         if not config.enabled:
             return False

@@ -305,7 +302,7 @@ class MemoryUpdater:

         try:
             # Get current memory
-            current_memory = get_memory_data(agent_name, user_id=user_id)
+            current_memory = get_memory_data(config, agent_name, user_id=user_id)

             # Format conversation for prompt
             conversation_text = format_conversation_for_update(messages)
@@ -360,7 +357,7 @@ class MemoryUpdater:
             updated_memory = _strip_upload_mentions_from_memory(updated_memory)

             # Save
-            return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
+            return get_memory_storage(config).save(updated_memory, agent_name, user_id=user_id)

         except json.JSONDecodeError as e:
             logger.warning("Failed to parse LLM response for memory update: %s", e)
@@ -385,7 +382,7 @@ class MemoryUpdater:
         Returns:
             Updated memory data.
         """
-        config = AppConfig.current().memory
+        config = self._memory_config
         now = utc_now_iso_z()

         # Update user sections
@@ -236,7 +236,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
         # threading.Timer fires on a different thread where ContextVar values are not
         # propagated, so we must store user_id explicitly in ConversationContext.
         user_id = get_effective_user_id()
-        queue = get_memory_queue()
+        queue = get_memory_queue(runtime.context.app_config)
         queue.add(
             thread_id=thread_id,
             messages=filtered_messages,
@@ -8,6 +8,7 @@ from langchain.agents.middleware import AgentMiddleware
 from langgraph.config import get_config
 from langgraph.runtime import Runtime

+from deerflow.config.app_config import AppConfig
 from deerflow.config.deer_flow_context import DeerFlowContext
 from deerflow.config.title_config import TitleConfig
 from deerflow.models import create_chat_model
@@ -120,8 +121,9 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
             _, user_msg = self._build_title_prompt(state, title_config)
             return {"title": self._fallback_title(user_msg, title_config)}

-    async def _agenerate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
+    async def _agenerate_title_result(self, state: TitleMiddlewareState, app_config: AppConfig) -> dict | None:
         """Generate a title asynchronously and fall back locally on failure."""
+        title_config = app_config.title
         if not self._should_generate_title(state, title_config):
             return None

@@ -129,9 +131,9 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):

         try:
             if title_config.model_name:
-                model = create_chat_model(name=title_config.model_name, thinking_enabled=False)
+                model = create_chat_model(name=title_config.model_name, thinking_enabled=False, app_config=app_config)
             else:
-                model = create_chat_model(thinking_enabled=False)
+                model = create_chat_model(thinking_enabled=False, app_config=app_config)
             response = await model.ainvoke(prompt, config=self._get_runnable_config())
             title = self._parse_title(response.content, title_config)
             if title:
@@ -146,4 +148,4 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):

     @override
     async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
-        return await self._agenerate_title_result(state, runtime.context.app_config.title)
+        return await self._agenerate_title_result(state, runtime.context.app_config)
@@ -38,7 +38,7 @@ from deerflow.agents.thread_state import ThreadState
 from deerflow.config.agents_config import AGENT_NAME_PATTERN
 from deerflow.config.app_config import AppConfig
 from deerflow.config.deer_flow_context import DeerFlowContext
-from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
+from deerflow.config.extensions_config import ExtensionsConfig
 from deerflow.config.paths import get_paths
 from deerflow.models import create_chat_model
 from deerflow.runtime.user_context import get_effective_user_id
@@ -151,16 +151,15 @@ class DeerFlowClient:
         # Constructor-captured config: the client owns its AppConfig for its lifetime.
         # Multiple clients with different configs do not contend.
         #
-        # Priority: explicit ``config=`` > explicit ``config_path=`` > AppConfig.current().
-        # The third tier preserves backward compatibility with callers that relied on
-        # the process-global (tests via the conftest autouse fixture). After P2-10
-        # removes AppConfig.current(), this fallback will require an explicit choice.
+        # Priority: explicit ``config=`` > explicit ``config_path=`` > ``AppConfig.from_file()``
+        # with default path resolution. There is no ambient global fallback; if
+        # config.yaml cannot be located, ``from_file`` raises loudly.
         if config is not None:
             self._app_config = config
         elif config_path is not None:
             self._app_config = AppConfig.from_file(config_path)
         else:
-            self._app_config = AppConfig.current()
+            self._app_config = AppConfig.from_file()

         if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
             raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
@@ -254,10 +253,11 @@ class DeerFlowClient:
         max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)

         kwargs: dict[str, Any] = {
-            "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
+            "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=self._app_config),
             "tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
             "middleware": _build_middlewares(self._app_config, config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
             "system_prompt": apply_prompt_template(
+                self._app_config,
                 subagent_enabled=subagent_enabled,
                 max_concurrent_subagents=max_concurrent_subagents,
                 agent_name=self._agent_name,
@@ -269,7 +269,7 @@ class DeerFlowClient:
         if checkpointer is None:
             from deerflow.runtime.checkpointer import get_checkpointer

-            checkpointer = get_checkpointer()
+            checkpointer = get_checkpointer(self._app_config)
         if checkpointer is not None:
             kwargs["checkpointer"] = checkpointer

@@ -277,12 +277,11 @@ class DeerFlowClient:
         self._agent_config_key = key
         logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled)

-    @staticmethod
-    def _get_tools(*, model_name: str | None, subagent_enabled: bool):
+    def _get_tools(self, *, model_name: str | None, subagent_enabled: bool):
         """Lazy import to avoid circular dependency at module level."""
         from deerflow.tools import get_available_tools

-        return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
+        return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=self._app_config)

     @staticmethod
     def _serialize_tool_calls(tool_calls) -> list[dict]:
@@ -403,7 +402,7 @@ class DeerFlowClient:
         if checkpointer is None:
             from deerflow.runtime.checkpointer.provider import get_checkpointer

-            checkpointer = get_checkpointer()
+            checkpointer = get_checkpointer(self._app_config)

         thread_info_map = {}

@@ -458,7 +457,7 @@ class DeerFlowClient:
         if checkpointer is None:
             from deerflow.runtime.checkpointer.provider import get_checkpointer

-            checkpointer = get_checkpointer()
+            checkpointer = get_checkpointer(self._app_config)

         config = {"configurable": {"thread_id": thread_id}}
         checkpoints = []
@@ -782,7 +781,7 @@ class DeerFlowClient:
                     "category": s.category,
                     "enabled": s.enabled,
                 }
-                for s in load_skills(enabled_only=enabled_only)
+                for s in load_skills(self._app_config, enabled_only=enabled_only)
             ]
         }

@@ -794,19 +793,19 @@ class DeerFlowClient:
         """
         from deerflow.agents.memory.updater import get_memory_data

-        return get_memory_data(user_id=get_effective_user_id())
+        return get_memory_data(self._app_config.memory, user_id=get_effective_user_id())

     def export_memory(self) -> dict:
         """Export current memory data for backup or transfer."""
         from deerflow.agents.memory.updater import get_memory_data

-        return get_memory_data(user_id=get_effective_user_id())
+        return get_memory_data(self._app_config.memory, user_id=get_effective_user_id())

     def import_memory(self, memory_data: dict) -> dict:
         """Import and persist full memory data."""
         from deerflow.agents.memory.updater import import_memory_data

-        return import_memory_data(memory_data, user_id=get_effective_user_id())
+        return import_memory_data(self._app_config.memory, memory_data, user_id=get_effective_user_id())

     def get_model(self, name: str) -> dict | None:
         """Get a specific model's configuration by name.

@@ -894,7 +893,7 @@ class DeerFlowClient:
         """
         from deerflow.skills.loader import load_skills

-        skill = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
+        skill = next((s for s in load_skills(self._app_config, enabled_only=False) if s.name == name), None)
         if skill is None:
             return None
         return {
@@ -921,7 +920,7 @@ class DeerFlowClient:
         """
         from deerflow.skills.loader import load_skills

-        skills = load_skills(enabled_only=False)
+        skills = load_skills(self._app_config, enabled_only=False)
         skill = next((s for s in skills if s.name == name), None)
         if skill is None:
             raise ValueError(f"Skill '{name}' not found")
@@ -930,12 +929,16 @@ class DeerFlowClient:
         if config_path is None:
             raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")

+        # Do not mutate self._app_config (frozen value). Compose the new
+        # skills state in a fresh dict, write it to disk, and let _reload_config()
+        # below rebuild AppConfig from the updated file.
         ext = self._app_config.extensions
-        ext.skills[name] = SkillStateConfig(enabled=enabled)
+        new_skills = {n: {"enabled": sc.enabled} for n, sc in ext.skills.items()}
+        new_skills[name] = {"enabled": enabled}

         config_data = {
             "mcpServers": {n: s.model_dump() for n, s in ext.mcp_servers.items()},
-            "skills": {n: {"enabled": sc.enabled} for n, sc in ext.skills.items()},
+            "skills": new_skills,
         }

         self._atomic_write_json(config_path, config_data)
@@ -944,7 +947,7 @@ class DeerFlowClient:
         self._agent_config_key = None
         self._reload_config()

-        updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
+        updated = next((s for s in load_skills(self._app_config, enabled_only=False) if s.name == name), None)
         if updated is None:
             raise RuntimeError(f"Skill '{name}' disappeared after update")
         return {
@@ -982,25 +985,25 @@ class DeerFlowClient:
         """
         from deerflow.agents.memory.updater import reload_memory_data

-        return reload_memory_data(user_id=get_effective_user_id())
+        return reload_memory_data(self._app_config.memory, user_id=get_effective_user_id())

     def clear_memory(self) -> dict:
         """Clear all persisted memory data."""
         from deerflow.agents.memory.updater import clear_memory_data

-        return clear_memory_data(user_id=get_effective_user_id())
+        return clear_memory_data(self._app_config.memory, user_id=get_effective_user_id())

     def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
         """Create a single fact manually."""
         from deerflow.agents.memory.updater import create_memory_fact

-        return create_memory_fact(content=content, category=category, confidence=confidence)
+        return create_memory_fact(self._app_config.memory, content=content, category=category, confidence=confidence)

     def delete_memory_fact(self, fact_id: str) -> dict:
         """Delete a single fact from memory by fact id."""
         from deerflow.agents.memory.updater import delete_memory_fact

-        return delete_memory_fact(fact_id)
+        return delete_memory_fact(self._app_config.memory, fact_id)

     def update_memory_fact(
         self,
@@ -1013,6 +1016,7 @@ class DeerFlowClient:
         from deerflow.agents.memory.updater import update_memory_fact

         return update_memory_fact(
+            self._app_config.memory,
             fact_id=fact_id,
             content=content,
             category=category,
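The constructor comment in the client diff describes a three-tier priority with no ambient fallback. A dependency-free sketch of that resolution order (`resolve_app_config` and `fake_from_file` are hypothetical stand-ins for the constructor logic and `AppConfig.from_file`):

```python
def resolve_app_config(config=None, config_path=None, load=None):
    """Three-tier priority: explicit config > explicit config_path >
    a pure from_file() load with default path resolution."""
    if config is not None:
        return config
    if config_path is not None:
        return load(config_path)
    # No ambient global here; a real loader raises loudly if config.yaml is missing.
    return load(None)


def fake_from_file(path):
    # Stand-in loader; returns a dict tagged with where it was loaded from.
    return {"source": path or "default"}


print(resolve_app_config(config={"source": "explicit"})["source"])                     # explicit
print(resolve_app_config(config_path="conf/app.yaml", load=fake_from_file)["source"])  # conf/app.yaml
print(resolve_app_config(load=fake_from_file)["source"])                               # default
```

Because the third tier is a pure file load rather than a process-global read, two clients constructed with different paths never observe each other's configuration.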
@@ -90,7 +90,8 @@ class AioSandboxProvider(SandboxProvider):
         API_KEY: $MY_API_KEY
     """

-    def __init__(self):
+    def __init__(self, app_config: "AppConfig"):
+        self._app_config = app_config
         self._lock = threading.Lock()
         self._sandboxes: dict[str, AioSandbox] = {}  # sandbox_id -> AioSandbox instance
         self._sandbox_infos: dict[str, SandboxInfo] = {}  # sandbox_id -> SandboxInfo (for destroy)
@@ -149,8 +150,7 @@ class AioSandboxProvider(SandboxProvider):

     def _load_config(self) -> dict:
         """Load sandbox configuration from app config."""
-        config = AppConfig.current()
-        sandbox_config = config.sandbox
+        sandbox_config = self._app_config.sandbox

         idle_timeout = getattr(sandbox_config, "idle_timeout", None)
         replicas = getattr(sandbox_config, "replicas", None)
@@ -273,17 +273,15 @@ class AioSandboxProvider(SandboxProvider):
             (paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True),
         ]

-    @staticmethod
-    def _get_skills_mount() -> tuple[str, str, bool] | None:
+    def _get_skills_mount(self) -> tuple[str, str, bool] | None:
         """Get the skills directory mount configuration.

         Mount source uses DEER_FLOW_HOST_SKILLS_PATH when running inside Docker (DooD)
         so the host Docker daemon can resolve the path.
         """
         try:
-            config = AppConfig.current()
-            skills_path = config.skills.get_skills_path()
-            container_path = config.skills.container_path
+            skills_path = self._app_config.skills.get_skills_path()
+            container_path = self._app_config.skills.container_path

             if skills_path.exists():
                 # When running inside Docker with DooD, use host-side skills path.
@@ -5,9 +5,9 @@ Web Search Tool - Search the web using DuckDuckGo (no API key required).
 import json
 import logging

-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool

-from deerflow.config.app_config import AppConfig
+from deerflow.config.deer_flow_context import resolve_context

 logger = logging.getLogger(__name__)

@@ -55,6 +55,7 @@ def _search_text(
 @tool("web_search", parse_docstring=True)
 def web_search_tool(
     query: str,
+    runtime: ToolRuntime,
     max_results: int = 5,
 ) -> str:
     """Search the web for information. Use this tool to find current information, news, articles, and facts from the internet.

@@ -63,11 +64,11 @@ def web_search_tool(
         query: Search keywords describing what you want to find. Be specific for better results.
         max_results: Maximum number of results to return. Default is 5.
     """
-    config = AppConfig.current().get_tool_config("web_search")
+    tool_config = resolve_context(runtime).app_config.get_tool_config("web_search")

     # Override max_results from config if set
-    if config is not None and "max_results" in config.model_extra:
-        max_results = config.model_extra.get("max_results", max_results)
+    if tool_config is not None and "max_results" in tool_config.model_extra:
+        max_results = tool_config.model_extra.get("max_results", max_results)

     results = _search_text(
         query=query,
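Each community tool repeats the same `model_extra` override dance: optional keys that a config model with "allow extra" semantics collected from config.yaml take precedence over the tool's built-in defaults. A dependency-free sketch of just that logic (`SimpleNamespace` stands in for the real pydantic tool-config model, whose `extra="allow"` setting is what populates `model_extra`):

```python
from types import SimpleNamespace

# Stand-in for a pydantic model declared with extra="allow": unknown
# config.yaml keys land in .model_extra instead of being rejected.
tool_config = SimpleNamespace(model_extra={"max_results": 10, "api_key": "sk-demo"})

max_results = 5  # the tool's built-in default
if tool_config is not None and "max_results" in tool_config.model_extra:
    # A configured value overrides the default; missing keys leave it alone.
    max_results = tool_config.model_extra.get("max_results", max_results)

print(max_results)  # 10
```

When `get_tool_config` returns `None` (the tool has no section in config.yaml), the guard leaves the default untouched.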
@@ -1,37 +1,39 @@
 import json

 from exa_py import Exa
-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool

-from deerflow.config.app_config import AppConfig
+from deerflow.config.deer_flow_context import resolve_context


-def _get_exa_client(tool_name: str = "web_search") -> Exa:
-    config = AppConfig.current().get_tool_config(tool_name)
+def _get_exa_client(app_config: AppConfig, tool_name: str = "web_search") -> Exa:
+    tool_config = app_config.get_tool_config(tool_name)
     api_key = None
-    if config is not None and "api_key" in config.model_extra:
-        api_key = config.model_extra.get("api_key")
+    if tool_config is not None and "api_key" in tool_config.model_extra:
+        api_key = tool_config.model_extra.get("api_key")
     return Exa(api_key=api_key)


 @tool("web_search", parse_docstring=True)
-def web_search_tool(query: str) -> str:
+def web_search_tool(query: str, runtime: ToolRuntime) -> str:
     """Search the web.

     Args:
         query: The query to search for.
     """
     try:
-        config = AppConfig.current().get_tool_config("web_search")
+        app_config = resolve_context(runtime).app_config
+        tool_config = app_config.get_tool_config("web_search")
         max_results = 5
         search_type = "auto"
         contents_max_characters = 1000
-        if config is not None:
-            max_results = config.model_extra.get("max_results", max_results)
-            search_type = config.model_extra.get("search_type", search_type)
-            contents_max_characters = config.model_extra.get("contents_max_characters", contents_max_characters)
+        if tool_config is not None:
+            max_results = tool_config.model_extra.get("max_results", max_results)
+            search_type = tool_config.model_extra.get("search_type", search_type)
+            contents_max_characters = tool_config.model_extra.get("contents_max_characters", contents_max_characters)

-        client = _get_exa_client()
+        client = _get_exa_client(app_config)
         res = client.search(
             query,
             type=search_type,
@@ -54,7 +56,7 @@ def web_search_tool(query: str) -> str:


 @tool("web_fetch", parse_docstring=True)
-def web_fetch_tool(url: str) -> str:
+def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
     """Fetch the contents of a web page at a given URL.
     Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
     This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -65,7 +67,7 @@ def web_fetch_tool(url: str) -> str:
         url: The URL to fetch the contents of.
     """
     try:
-        client = _get_exa_client("web_fetch")
+        client = _get_exa_client(resolve_context(runtime).app_config, "web_fetch")
         res = client.get_contents([url], text={"max_characters": 4096})

         if res.results:
@@ -1,33 +1,35 @@
 import json

 from firecrawl import FirecrawlApp
-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool

-from deerflow.config.app_config import AppConfig
+from deerflow.config.deer_flow_context import resolve_context


-def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp:
-    config = AppConfig.current().get_tool_config(tool_name)
+def _get_firecrawl_client(app_config: AppConfig, tool_name: str = "web_search") -> FirecrawlApp:
+    tool_config = app_config.get_tool_config(tool_name)
     api_key = None
-    if config is not None and "api_key" in config.model_extra:
-        api_key = config.model_extra.get("api_key")
+    if tool_config is not None and "api_key" in tool_config.model_extra:
+        api_key = tool_config.model_extra.get("api_key")
     return FirecrawlApp(api_key=api_key)  # type: ignore[arg-type]


 @tool("web_search", parse_docstring=True)
-def web_search_tool(query: str) -> str:
+def web_search_tool(query: str, runtime: ToolRuntime) -> str:
     """Search the web.

     Args:
         query: The query to search for.
     """
     try:
-        config = AppConfig.current().get_tool_config("web_search")
+        app_config = resolve_context(runtime).app_config
+        tool_config = app_config.get_tool_config("web_search")
         max_results = 5
-        if config is not None:
-            max_results = config.model_extra.get("max_results", max_results)
+        if tool_config is not None:
+            max_results = tool_config.model_extra.get("max_results", max_results)

-        client = _get_firecrawl_client("web_search")
+        client = _get_firecrawl_client(app_config, "web_search")
         result = client.search(query, limit=max_results)

         # result.web contains list of SearchResultWeb objects
@@ -47,7 +49,7 @@ def web_search_tool(query: str) -> str:


 @tool("web_fetch", parse_docstring=True)
-def web_fetch_tool(url: str) -> str:
+def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
     """Fetch the contents of a web page at a given URL.
     Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
     This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -58,7 +60,8 @@ def web_fetch_tool(url: str) -> str:
         url: The URL to fetch the contents of.
     """
     try:
-        client = _get_firecrawl_client("web_fetch")
+        app_config = resolve_context(runtime).app_config
+        client = _get_firecrawl_client(app_config, "web_fetch")
         result = client.scrape(url, formats=["markdown"])

         markdown_content = result.markdown or ""
@@ -5,9 +5,9 @@ Image Search Tool - Search images using DuckDuckGo for reference in image genera
 import json
 import logging

-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool

-from deerflow.config.app_config import AppConfig
+from deerflow.config.deer_flow_context import resolve_context

 logger = logging.getLogger(__name__)

@@ -77,6 +77,7 @@ def _search_images(
 @tool("image_search", parse_docstring=True)
 def image_search_tool(
     query: str,
+    runtime: ToolRuntime,
     max_results: int = 5,
     size: str | None = None,
     type_image: str | None = None,
@@ -99,11 +100,11 @@ def image_search_tool(
         type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references.
         layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs.
     """
-    config = AppConfig.current().get_tool_config("image_search")
+    tool_config = resolve_context(runtime).app_config.get_tool_config("image_search")

     # Override max_results from config if set
-    if config is not None and "max_results" in config.model_extra:
-        max_results = config.model_extra.get("max_results", max_results)
+    if tool_config is not None and "max_results" in tool_config.model_extra:
+        max_results = tool_config.model_extra.get("max_results", max_results)

     results = _search_images(
         query=query,
@@ -1,6 +1,7 @@
-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool

-from deerflow.config.app_config import AppConfig
+from deerflow.config.deer_flow_context import resolve_context
 from deerflow.utils.readability import ReadabilityExtractor

 from .infoquest_client import InfoQuestClient
@@ -8,13 +9,13 @@ from .infoquest_client import InfoQuestClient
 readability_extractor = ReadabilityExtractor()


-def _get_infoquest_client() -> InfoQuestClient:
-    search_config = AppConfig.current().get_tool_config("web_search")
+def _get_infoquest_client(app_config: AppConfig) -> InfoQuestClient:
+    search_config = app_config.get_tool_config("web_search")
     search_time_range = -1
     if search_config is not None and "search_time_range" in search_config.model_extra:
         search_time_range = search_config.model_extra.get("search_time_range")

-    fetch_config = AppConfig.current().get_tool_config("web_fetch")
+    fetch_config = app_config.get_tool_config("web_fetch")
     fetch_time = -1
     if fetch_config is not None and "fetch_time" in fetch_config.model_extra:
         fetch_time = fetch_config.model_extra.get("fetch_time")
@@ -25,7 +26,7 @@ def _get_infoquest_client() -> InfoQuestClient:
     if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra:
         navigation_timeout = fetch_config.model_extra.get("navigation_timeout")

-    image_search_config = AppConfig.current().get_tool_config("image_search")
+    image_search_config = app_config.get_tool_config("image_search")
     image_search_time_range = -1
     if image_search_config is not None and "image_search_time_range" in image_search_config.model_extra:
         image_search_time_range = image_search_config.model_extra.get("image_search_time_range")
@@ -44,19 +45,18 @@ def _get_infoquest_client() -> InfoQuestClient:


 @tool("web_search", parse_docstring=True)
-def web_search_tool(query: str) -> str:
+def web_search_tool(query: str, runtime: ToolRuntime) -> str:
     """Search the web.

     Args:
         query: The query to search for.
     """
-
-    client = _get_infoquest_client()
+    client = _get_infoquest_client(resolve_context(runtime).app_config)
     return client.web_search(query)


 @tool("web_fetch", parse_docstring=True)
-def web_fetch_tool(url: str) -> str:
+def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
     """Fetch the contents of a web page at a given URL.
     Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
     This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -66,7 +66,7 @@ def web_fetch_tool(url: str) -> str:
     Args:
         url: The URL to fetch the contents of.
     """
-    client = _get_infoquest_client()
+    client = _get_infoquest_client(resolve_context(runtime).app_config)
     result = client.fetch(url)
     if result.startswith("Error: "):
         return result
@@ -75,7 +75,7 @@ def web_fetch_tool(url: str) -> str:


 @tool("image_search", parse_docstring=True)
-def image_search_tool(query: str) -> str:
+def image_search_tool(query: str, runtime: ToolRuntime) -> str:
     """Search for images online. Use this tool BEFORE image generation to find reference images for characters, portraits, objects, scenes, or any content requiring visual accuracy.

     **When to use:**
@@ -89,5 +89,5 @@ def image_search_tool(query: str) -> str:
     Args:
         query: The query to search for images.
     """
-    client = _get_infoquest_client()
+    client = _get_infoquest_client(resolve_context(runtime).app_config)
     return client.image_search(query)
@@ -1,14 +1,14 @@
-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool

 from deerflow.community.jina_ai.jina_client import JinaClient
-from deerflow.config.app_config import AppConfig
+from deerflow.config.deer_flow_context import resolve_context
 from deerflow.utils.readability import ReadabilityExtractor

 readability_extractor = ReadabilityExtractor()


 @tool("web_fetch", parse_docstring=True)
-async def web_fetch_tool(url: str) -> str:
+async def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
     """Fetch the contents of a web page at a given URL.
     Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
     This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -20,9 +20,9 @@ async def web_fetch_tool(url: str) -> str:
     """
     jina_client = JinaClient()
     timeout = 10
-    config = AppConfig.current().get_tool_config("web_fetch")
-    if config is not None and "timeout" in config.model_extra:
-        timeout = config.model_extra.get("timeout")
+    tool_config = resolve_context(runtime).app_config.get_tool_config("web_fetch")
+    if tool_config is not None and "timeout" in tool_config.model_extra:
+        timeout = tool_config.model_extra.get("timeout")
     html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
     if isinstance(html_content, str) and html_content.startswith("Error:"):
         return html_content
@@ -1,32 +1,34 @@
 import json

-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool
 from tavily import TavilyClient

 from deerflow.config.app_config import AppConfig
+from deerflow.config.deer_flow_context import resolve_context


-def _get_tavily_client() -> TavilyClient:
-    config = AppConfig.current().get_tool_config("web_search")
+def _get_tavily_client(app_config: AppConfig) -> TavilyClient:
+    tool_config = app_config.get_tool_config("web_search")
     api_key = None
-    if config is not None and "api_key" in config.model_extra:
-        api_key = config.model_extra.get("api_key")
+    if tool_config is not None and "api_key" in tool_config.model_extra:
+        api_key = tool_config.model_extra.get("api_key")
     return TavilyClient(api_key=api_key)


 @tool("web_search", parse_docstring=True)
-def web_search_tool(query: str) -> str:
+def web_search_tool(query: str, runtime: ToolRuntime) -> str:
     """Search the web.

     Args:
         query: The query to search for.
     """
-    config = AppConfig.current().get_tool_config("web_search")
+    app_config = resolve_context(runtime).app_config
+    tool_config = app_config.get_tool_config("web_search")
     max_results = 5
-    if config is not None and "max_results" in config.model_extra:
-        max_results = config.model_extra.get("max_results")
+    if tool_config is not None and "max_results" in tool_config.model_extra:
+        max_results = tool_config.model_extra.get("max_results")

-    client = _get_tavily_client()
+    client = _get_tavily_client(app_config)
     res = client.search(query, max_results=max_results)
     normalized_results = [
         {
@@ -41,7 +43,7 @@ def web_search_tool(query: str) -> str:


 @tool("web_fetch", parse_docstring=True)
-def web_fetch_tool(url: str) -> str:
+def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
     """Fetch the contents of a web page at a given URL.
     Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
     This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -51,7 +53,8 @@ def web_fetch_tool(url: str) -> str:
     Args:
         url: The URL to fetch the contents of.
     """
-    client = _get_tavily_client()
+    app_config = resolve_context(runtime).app_config
+    client = _get_tavily_client(app_config)
     res = client.extract([url])
     if "failed_results" in res and len(res["failed_results"]) > 0:
         return f"Error: {res['failed_results'][0]['error']}"
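The tool-side lookup above always follows the same shape: an optional tool section, then an optional key inside pydantic's `model_extra`, then a hard-coded default. A minimal sketch of that pattern, using hypothetical `ToolConfig`/`FakeAppConfig` stand-ins for the real pydantic models:

```python
# Sketch of the optional tool-config lookup pattern. ToolConfig and
# FakeAppConfig are illustrative stand-ins, not the real AppConfig models.
from dataclasses import dataclass, field


@dataclass
class ToolConfig:
    # Mirrors pydantic's model_extra: unmodelled keys land here.
    model_extra: dict = field(default_factory=dict)


@dataclass
class FakeAppConfig:
    tools: dict = field(default_factory=dict)

    def get_tool_config(self, name: str):
        # Returns None when the tool has no config section, like the real API.
        return self.tools.get(name)


def resolve_max_results(app_config: FakeAppConfig, default: int = 5) -> int:
    # Same two-level optionality as web_search_tool: the section may be
    # missing, and the key inside model_extra may be missing.
    tool_config = app_config.get_tool_config("web_search")
    if tool_config is not None and "max_results" in tool_config.model_extra:
        return tool_config.model_extra["max_results"]
    return default


configured = FakeAppConfig(tools={"web_search": ToolConfig({"max_results": 10})})
assert resolve_max_results(configured) == 10       # configured value wins
assert resolve_max_results(FakeAppConfig()) == 5   # missing section falls back
```

Because the config is now a parameter rather than an ambient lookup, tests can exercise both branches with plain constructor calls.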
@@ -2,9 +2,8 @@ from __future__ import annotations

 import logging
 import os
-from contextvars import ContextVar, Token
 from pathlib import Path
-from typing import Any, ClassVar, Self
+from typing import Any, Self

 import yaml
 from dotenv import load_dotenv
@@ -222,54 +221,21 @@ class AppConfig(BaseModel):
         """
         return next((group for group in self.tool_groups if group.name == name), None)

-    # -- Lifecycle (process-global + per-context override) --
+    # AppConfig is a pure value object: construct with ``from_file()``, pass around.
+    # Composition roots that hold the singleton:
+    #   - Gateway: ``app.state.config`` via ``Depends(get_config)``
+    #   - Client: ``DeerFlowClient._app_config``
+    #   - Agent run: ``Runtime[DeerFlowContext].context.app_config``
     #
-    # _global is a plain class variable. Assignment is atomic under the GIL
-    # (single pointer swap), so no lock is needed for the current read/write
-    # pattern. If this ever changes to read-modify-write, add a threading.Lock.
-
-    _global: ClassVar[AppConfig | None] = None
-    _override: ClassVar[ContextVar[AppConfig]] = ContextVar("deerflow_app_config_override")
-
-    @classmethod
-    def init(cls, config: AppConfig) -> None:
-        """Set the process-global AppConfig. Visible to all subsequent requests."""
-        cls._global = config
-
-    @classmethod
-    def set_override(cls, config: AppConfig) -> Token[AppConfig]:
-        """Set a per-context override. Returns a token for reset_override().
-
-        Use this in DeerFlowClient or test fixtures to scope a config to the
-        current async context without polluting the process-global value.
-        """
-        return cls._override.set(config)
-
-    @classmethod
-    def reset_override(cls, token: Token[AppConfig]) -> None:
-        """Restore the override to its previous value."""
-        cls._override.reset(token)
+    # ``current()`` is kept as a deprecated no-op slot purely so legacy tests
+    # that still run ``patch.object(AppConfig, "current", ...)`` can attach
+    # without an ``AttributeError`` at teardown. Production code never calls
+    # it — any in-process invocation raises so regressions are loud.

     @classmethod
     def current(cls) -> AppConfig:
-        """Get the current AppConfig.
-
-        Priority: per-context override > process-global > auto-load from file.
-
-        The auto-load fallback exists for backward compatibility. Prefer calling
-        ``AppConfig.init()`` explicitly at process startup so that config errors
-        surface early rather than at an arbitrary first-use call site.
-        """
-        try:
-            return cls._override.get()
-        except LookupError:
-            pass
-        if cls._global is not None:
-            return cls._global
-        logger.warning(
-            "AppConfig.current() called before init(); auto-loading from file. "
-            "Call AppConfig.init() at process startup to surface config errors early."
+        raise RuntimeError(
+            "AppConfig.current() is removed. Pass AppConfig explicitly: "
+            "`runtime.context.app_config` in agent paths, `Depends(get_config)` in Gateway, "
+            "`self._app_config` in DeerFlowClient."
         )
-        config = cls.from_file()
-        cls._global = config
-        return config
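The deprecated-slot trick above can be sketched in isolation: keep the attribute so `patch.object` still resolves it, but make any unpatched call fail loudly. `Config` here is a hypothetical stand-in for AppConfig:

```python
# Sketch of the "loud tombstone" pattern: the classmethod survives as an
# attribute for patching, while real calls raise. Config is illustrative.
from unittest.mock import patch


class Config:
    @classmethod
    def current(cls) -> "Config":
        raise RuntimeError("Config.current() is removed; pass the instance explicitly.")


# Legacy tests can still patch the slot without AttributeError at teardown...
with patch.object(Config, "current", return_value=Config()):
    assert isinstance(Config.current(), Config)

# ...while an unpatched production call fails loudly instead of silently
# resurrecting the ambient lookup.
failed = False
try:
    Config.current()
except RuntimeError:
    failed = True
assert failed
```

Deleting the attribute outright would make `patch.object(..., "current")` itself raise `AttributeError` during fixture setup, which is why the slot is kept rather than removed.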
@@ -30,43 +30,26 @@ class DeerFlowContext:


 def resolve_context(runtime: Any) -> DeerFlowContext:
-    """Extract or construct DeerFlowContext from runtime.
+    """Return the typed DeerFlowContext that the runtime carries.

-    Gateway/Client paths: runtime.context is already DeerFlowContext → return directly.
-    LangGraph Server / legacy dict path: construct from dict context or configurable fallback.
+    Gateway mode (``DeerFlowClient``, ``run_agent``) always attaches a typed
+    ``DeerFlowContext`` via ``agent.astream(context=...)``; the LangGraph
+    Server path uses ``langgraph.json`` registration where the top-level
+    ``make_lead_agent`` loads ``AppConfig`` from disk itself, so we still
+    arrive here with a typed context.
+
+    Only the dict/None shapes that legacy tests used to exercise would fall
+    through this function; we now reject them loudly instead of papering
+    over the missing context with an ambient ``AppConfig`` lookup.
     """
     ctx = getattr(runtime, "context", None)
     if isinstance(ctx, DeerFlowContext):
         return ctx

-    from deerflow.config.app_config import AppConfig
-
-    # Try dict context first (legacy path, tests), then configurable
-    if isinstance(ctx, dict):
-        thread_id = ctx.get("thread_id", "")
-        if not thread_id:
-            logger.warning("resolve_context: dict context has empty thread_id — may cause incorrect path resolution")
-        return DeerFlowContext(
-            app_config=AppConfig.current(),
-            thread_id=thread_id,
-            agent_name=ctx.get("agent_name"),
-        )
-
-    # No context at all — fall back to LangGraph configurable
-    try:
-        from langgraph.config import get_config
-
-        cfg = get_config().get("configurable", {})
-    except RuntimeError:
-        # Outside runnable context (e.g. unit tests)
-        cfg = {}
-
-    thread_id = cfg.get("thread_id", "")
-    if not thread_id:
-        logger.warning("resolve_context: falling back to empty thread_id — no DeerFlowContext or configurable found")
-
-    return DeerFlowContext(
-        app_config=AppConfig.current(),
-        thread_id=thread_id,
-        agent_name=cfg.get("agent_name"),
+    raise RuntimeError(
+        "resolve_context: runtime.context is not a DeerFlowContext "
+        "(got type %s). Every entry point must attach one at invoke time — "
+        "Gateway/Client via agent.astream(context=DeerFlowContext(...)), "
+        "LangGraph Server via the make_lead_agent boundary that loads "
+        "AppConfig.from_file()." % type(ctx).__name__
     )
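The strict resolution contract can be sketched with toy types: a typed context passes straight through, and any other shape raises instead of being reconstructed from ambient state. `Ctx` and `DummyRuntime` are hypothetical stand-ins for `DeerFlowContext` and LangGraph's `Runtime`:

```python
# Sketch of strict typed-context resolution. Ctx / DummyRuntime are
# illustrative stand-ins, not the real DeerFlow types.
from dataclasses import dataclass
from typing import Any


@dataclass
class Ctx:
    thread_id: str


@dataclass
class DummyRuntime:
    context: Any


def resolve(runtime: Any) -> Ctx:
    ctx = getattr(runtime, "context", None)
    if isinstance(ctx, Ctx):
        return ctx
    # Legacy dict/None shapes are rejected loudly rather than rebuilt
    # from an ambient config lookup.
    raise RuntimeError(f"runtime.context is not a Ctx (got type {type(ctx).__name__})")


assert resolve(DummyRuntime(Ctx("t-1"))).thread_id == "t-1"

ok = False
try:
    resolve(DummyRuntime({"thread_id": "t-1"}))  # legacy dict shape
except RuntimeError:
    ok = True
assert ok
```

The design choice is that a missing context is a wiring bug at the entry point, so failing at resolution time surfaces it immediately instead of running with a silently wrong config.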
@@ -34,24 +34,18 @@ def create_chat_model(
     name: str | None = None,
     thinking_enabled: bool = False,
     *,
-    app_config: "AppConfig | None" = None,
+    app_config: "AppConfig",
     **kwargs,
 ) -> BaseChatModel:
     """Create a chat model instance from the config.

     Args:
         name: The name of the model to create. If None, the first model in the config will be used.
-        app_config: Application config. Falls back to AppConfig.current() when
-            omitted; new callers should pass this explicitly.
+        app_config: Application config — required.

     Returns:
         A chat model instance.
     """
-    if app_config is None:
-        # TODO(P2-10): fold into a required parameter once all callers
-        # (memory updater, summarization middleware's implicit model) thread
-        # config explicitly.
-        app_config = AppConfig.current()
     config = app_config
     if name is None:
         name = config.models[0].name
@@ -123,11 +123,11 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi


 @contextlib.asynccontextmanager
-async def make_checkpointer() -> AsyncIterator[Checkpointer]:
+async def make_checkpointer(app_config: AppConfig) -> AsyncIterator[Checkpointer]:
     """Async context manager that yields a checkpointer for the caller's lifetime.
     Resources are opened on enter and closed on exit -- no global state::

-        async with make_checkpointer() as checkpointer:
+        async with make_checkpointer(app_config) as checkpointer:
             app.state.checkpointer = checkpointer

     Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
@@ -138,16 +138,14 @@ async def make_checkpointer() -> AsyncIterator[Checkpointer]:
     3. Default InMemorySaver
     """
-
-    config = AppConfig.current()
-
     # Legacy: standalone checkpointer config takes precedence
-    if config.checkpointer is not None:
-        async with _async_checkpointer(config.checkpointer) as saver:
+    if app_config.checkpointer is not None:
+        async with _async_checkpointer(app_config.checkpointer) as saver:
             yield saver
         return

     # Unified database config
-    db_config = getattr(config, "database", None)
+    db_config = getattr(app_config, "database", None)
     if db_config is not None and db_config.backend != "memory":
         async with _async_checkpointer_from_database(db_config) as saver:
             yield saver
@@ -99,10 +99,13 @@ _checkpointer: Checkpointer | None = None
 _checkpointer_ctx = None  # open context manager keeping the connection alive


-def get_checkpointer() -> Checkpointer:
+def get_checkpointer(app_config: AppConfig) -> Checkpointer:
     """Return the global sync checkpointer singleton, creating it on first call.

-    Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
+    Returns an ``InMemorySaver`` only when ``checkpointer`` is explicitly
+    absent from config.yaml. Any other failure (missing config, invalid
+    backend, connection error) propagates — silent degradation to in-memory
+    would drop persistent-run state on process restart.

     Raises:
         ImportError: If the required package for the configured backend is not installed.
@@ -113,10 +116,7 @@ def get_checkpointer() -> Checkpointer:
     if _checkpointer is not None:
         return _checkpointer

-    try:
-        config = AppConfig.current().checkpointer
-    except (LookupError, FileNotFoundError):
-        config = None
+    config = app_config.checkpointer
     if config is None:
         from langgraph.checkpoint.memory import InMemorySaver

@@ -152,25 +152,23 @@ def reset_checkpointer() -> None:


 @contextlib.contextmanager
-def checkpointer_context() -> Iterator[Checkpointer]:
+def checkpointer_context(app_config: AppConfig) -> Iterator[Checkpointer]:
     """Sync context manager that yields a checkpointer and cleans up on exit.

     Unlike :func:`get_checkpointer`, this does **not** cache the instance —
     each ``with`` block creates and destroys its own connection. Use it in
     CLI scripts or tests where you want deterministic cleanup::

-        with checkpointer_context() as cp:
+        with checkpointer_context(app_config) as cp:
             graph.invoke(input, config={"configurable": {"thread_id": "1"}})

     Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
     """
-
-    config = AppConfig.current()
-    if config.checkpointer is None:
+    if app_config.checkpointer is None:
         from langgraph.checkpoint.memory import InMemorySaver

         yield InMemorySaver()
         return

-    with _sync_checkpointer_cm(config.checkpointer) as saver:
+    with _sync_checkpointer_cm(app_config.checkpointer) as saver:
         yield saver
@@ -169,8 +169,10 @@ async def run_agent(
     # Construct typed context for the agent run.
     # LangGraph's astream(context=...) injects this into Runtime.context
     # so middleware/tools can access it via resolve_context().
+    if ctx.app_config is None:
+        raise RuntimeError("RunContext.app_config is required — Gateway must populate it via get_run_context")
     deer_flow_context = DeerFlowContext(
-        app_config=ctx.app_config if ctx.app_config is not None else AppConfig.current(),
+        app_config=ctx.app_config,
         thread_id=thread_id,
     )
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_store() -> AsyncIterator[BaseStore]:
|
||||
async def make_store(app_config: AppConfig) -> AsyncIterator[BaseStore]:
|
||||
"""Async context manager that yields a Store whose backend matches the
|
||||
configured checkpointer.
|
||||
|
||||
@ -94,20 +94,18 @@ async def make_store() -> AsyncIterator[BaseStore]:
|
||||
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
|
||||
that both singletons always use the same persistence technology::
|
||||
|
||||
async with make_store() as store:
|
||||
async with make_store(app_config) as store:
|
||||
app.state.store = store
|
||||
|
||||
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
``checkpointer`` section is configured (emits a WARNING in that case).
|
||||
"""
|
||||
config = AppConfig.current()
|
||||
|
||||
if config.checkpointer is None:
|
||||
if app_config.checkpointer is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
yield InMemoryStore()
|
||||
return
|
||||
|
||||
async with _async_store(config.checkpointer) as store:
|
||||
async with _async_store(app_config.checkpointer) as store:
|
||||
yield store
|
||||
|
||||
@@ -100,7 +100,7 @@ _store: BaseStore | None = None
 _store_ctx = None  # open context manager keeping the connection alive


-def get_store() -> BaseStore:
+def get_store(app_config: AppConfig) -> BaseStore:
     """Return the global sync Store singleton, creating it on first call.

     Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
@@ -115,10 +115,10 @@ def get_store() -> BaseStore:
     if _store is not None:
         return _store

-    try:
-        config = AppConfig.current().checkpointer
-    except (LookupError, FileNotFoundError):
-        config = None
+    # See matching comment in checkpointer/provider.py: a missing config.yaml
+    # is a deployment error, not a cue to silently pick InMemoryStore. Only
+    # the explicit "no checkpointer section" path falls through to memory.
+    config = app_config.checkpointer

     if config is None:
         from langgraph.store.memory import InMemoryStore
@@ -154,26 +154,25 @@ def reset_store() -> None:


 @contextlib.contextmanager
-def store_context() -> Iterator[BaseStore]:
+def store_context(app_config: AppConfig) -> Iterator[BaseStore]:
     """Sync context manager that yields a Store and cleans up on exit.

     Unlike :func:`get_store`, this does **not** cache the instance — each
     ``with`` block creates and destroys its own connection. Use it in CLI
     scripts or tests where you want deterministic cleanup::

-        with store_context() as store:
+        with store_context(app_config) as store:
             store.put(("threads",), thread_id, {...})

     Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
     checkpointer is configured in *config.yaml*.
     """
-    config = AppConfig.current()
-    if config.checkpointer is None:
+    if app_config.checkpointer is None:
         from langgraph.store.memory import InMemoryStore

         logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
         yield InMemoryStore()
         return

-    with _sync_store_cm(config.checkpointer) as store:
+    with _sync_store_cm(app_config.checkpointer) as store:
         yield store
@@ -25,14 +25,13 @@ logger = logging.getLogger(__name__)


 @contextlib.asynccontextmanager
-async def make_stream_bridge(config=None) -> AsyncIterator[StreamBridge]:
+async def make_stream_bridge(app_config: AppConfig) -> AsyncIterator[StreamBridge]:
     """Async context manager that yields a :class:`StreamBridge`.

-    Falls back to :class:`MemoryStreamBridge` when no configuration is
-    provided and nothing is set globally.
+    Falls back to :class:`MemoryStreamBridge` when no ``stream_bridge``
+    section is configured.
     """
-    if config is None:
-        config = AppConfig.current().stream_bridge
+    config = app_config.stream_bridge

     if config is None or config.type == "memory":
         from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
@@ -1,18 +1,23 @@
 import logging
 from pathlib import Path
+from typing import TYPE_CHECKING

 from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
 from deerflow.sandbox.sandbox import Sandbox
 from deerflow.sandbox.sandbox_provider import SandboxProvider

+if TYPE_CHECKING:
+    from deerflow.config.app_config import AppConfig
+
 logger = logging.getLogger(__name__)

 _singleton: LocalSandbox | None = None


 class LocalSandboxProvider(SandboxProvider):
-    def __init__(self):
+    def __init__(self, app_config: "AppConfig"):
         """Initialize the local sandbox provider with path mappings."""
+        self._app_config = app_config
         self._path_mappings = self._setup_path_mappings()

     def _setup_path_mappings(self) -> list[PathMapping]:
@@ -29,9 +34,7 @@ class LocalSandboxProvider(SandboxProvider):

         # Map skills container path to local skills directory
         try:
-            from deerflow.config.app_config import AppConfig
-
-            config = AppConfig.current()
+            config = self._app_config
             skills_path = config.skills.get_skills_path()
             container_path = config.skills.container_path
|
||||
super().__init__()
|
||||
self._lazy_init = lazy_init
|
||||
|
||||
def _acquire_sandbox(self, thread_id: str) -> str:
|
||||
provider = get_sandbox_provider()
|
||||
def _acquire_sandbox(self, thread_id: str, runtime: Runtime[DeerFlowContext]) -> str:
|
||||
provider = get_sandbox_provider(runtime.context.app_config)
|
||||
sandbox_id = provider.acquire(thread_id)
|
||||
logger.info(f"Acquiring sandbox {sandbox_id}")
|
||||
return sandbox_id
|
||||
@ -60,7 +60,7 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
thread_id = runtime.context.thread_id
|
||||
if not thread_id:
|
||||
return super().before_agent(state, runtime)
|
||||
sandbox_id = self._acquire_sandbox(thread_id)
|
||||
sandbox_id = self._acquire_sandbox(thread_id, runtime)
|
||||
logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
|
||||
return {"sandbox": {"sandbox_id": sandbox_id}}
|
||||
return super().before_agent(state, runtime)
|
||||
@ -71,7 +71,7 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
if sandbox is not None:
|
||||
sandbox_id = sandbox["sandbox_id"]
|
||||
logger.info(f"Releasing sandbox {sandbox_id}")
|
||||
get_sandbox_provider().release(sandbox_id)
|
||||
get_sandbox_provider(runtime.context.app_config).release(sandbox_id)
|
||||
return None
|
||||
|
||||
# No sandbox to release
|
||||
|
||||
@@ -39,23 +39,38 @@ class SandboxProvider(ABC):
 _default_sandbox_provider: SandboxProvider | None = None


-def get_sandbox_provider(**kwargs) -> SandboxProvider:
+def get_sandbox_provider(app_config: AppConfig, **kwargs) -> SandboxProvider:
     """Get the sandbox provider singleton.

     Returns a cached singleton instance. Use `reset_sandbox_provider()` to clear
     the cache, or `shutdown_sandbox_provider()` to properly shutdown and clear.

+    Args:
+        app_config: Application config used the first time the singleton is built.
+            Ignored on subsequent calls — the cached instance is returned
+            regardless of the config passed.
+
     Returns:
         A sandbox provider instance.
     """
     global _default_sandbox_provider
     if _default_sandbox_provider is None:
-        config = AppConfig.current()
-        cls = resolve_class(config.sandbox.use, SandboxProvider)
-        _default_sandbox_provider = cls(**kwargs)
+        cls = resolve_class(app_config.sandbox.use, SandboxProvider)
+        _default_sandbox_provider = cls(app_config=app_config, **kwargs) if _accepts_app_config(cls) else cls(**kwargs)
     return _default_sandbox_provider


+def _accepts_app_config(cls: type) -> bool:
+    """Return True when the provider's __init__ accepts an ``app_config`` kwarg."""
+    import inspect
+
+    try:
+        sig = inspect.signature(cls.__init__)
+    except (TypeError, ValueError):
+        return False
+    return "app_config" in sig.parameters
+
+
 def reset_sandbox_provider() -> None:
     """Reset the sandbox provider singleton.
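The `_accepts_app_config` fallback above is a signature-sniffing pattern: pass the config only to provider classes whose `__init__` declares it, so third-party providers with a bare constructor keep working. A self-contained sketch, with hypothetical `NewStyle`/`OldStyle` providers:

```python
# Sketch of the signature-sniffing construction fallback. NewStyle and
# OldStyle are illustrative provider classes, not real DeerFlow providers.
import inspect


def accepts_app_config(cls: type) -> bool:
    # inspect.signature can fail for some builtins; treat that as "no".
    try:
        sig = inspect.signature(cls.__init__)
    except (TypeError, ValueError):
        return False
    return "app_config" in sig.parameters


class NewStyle:
    def __init__(self, app_config):
        self.app_config = app_config


class OldStyle:
    def __init__(self):
        self.app_config = None


def build(cls: type, app_config):
    # Mirrors get_sandbox_provider: inject the config only when accepted.
    return cls(app_config=app_config) if accepts_app_config(cls) else cls()


assert build(NewStyle, {"sandbox": "local"}).app_config == {"sandbox": "local"}
assert build(OldStyle, {"sandbox": "local"}).app_config is None
```

This keeps the migration non-breaking at the plugin boundary while the in-tree providers are updated to take `app_config` explicitly.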
@@ -20,11 +20,8 @@ LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE = (
 )


-def uses_local_sandbox_provider(config=None) -> bool:
+def uses_local_sandbox_provider(config: AppConfig) -> bool:
     """Return True when the active sandbox provider is the host-local provider."""
-    if config is None:
-        config = AppConfig.current()
-
     sandbox_cfg = getattr(config, "sandbox", None)
     sandbox_use = getattr(sandbox_cfg, "use", "")
     if sandbox_use in _LOCAL_SANDBOX_PROVIDER_MARKERS:
@@ -32,11 +29,8 @@ def uses_local_sandbox_provider(config=None) -> bool:
     return sandbox_use.endswith(":LocalSandboxProvider") and "deerflow.sandbox.local" in sandbox_use


-def is_host_bash_allowed(config=None) -> bool:
+def is_host_bash_allowed(config: AppConfig) -> bool:
     """Return whether host bash execution is explicitly allowed."""
-    if config is None:
-        config = AppConfig.current()
-
     sandbox_cfg = getattr(config, "sandbox", None)
     if sandbox_cfg is None:
         return True
@@ -40,58 +40,43 @@ _DEFAULT_GREP_MAX_RESULTS = 100
 _MAX_GREP_MAX_RESULTS = 500


-def _get_skills_container_path() -> str:
-    """Get the skills container path from config, with fallback to default.
-
-    Result is cached after the first successful config load. If config loading
-    fails the default is returned *without* caching so that a later call can
-    pick up the real value once the config is available.
-    """
-    cached = getattr(_get_skills_container_path, "_cached", None)
-    if cached is not None:
-        return cached
-    try:
-        value = AppConfig.current().skills.container_path
-        _get_skills_container_path._cached = value  # type: ignore[attr-defined]
-        return value
-    except Exception:
+def _get_skills_container_path(app_config: AppConfig) -> str:
+    """Get the skills container path from config, with fallback to default."""
+    skills_cfg = getattr(app_config, "skills", None)
+    if skills_cfg is None:
         return _DEFAULT_SKILLS_CONTAINER_PATH
+    return skills_cfg.container_path


-def _get_skills_host_path() -> str | None:
+def _get_skills_host_path(app_config: AppConfig) -> str | None:
     """Get the skills host filesystem path from config.

-    Returns None if the skills directory does not exist or config cannot be
-    loaded. Only successful lookups are cached; failures are retried on the
-    next call so that a transiently unavailable skills directory does not
-    permanently disable skills access.
+    Returns None if the skills directory does not exist or is not configured.
     """
-    cached = getattr(_get_skills_host_path, "_cached", None)
-    if cached is not None:
-        return cached
+    skills_cfg = getattr(app_config, "skills", None)
+    if skills_cfg is None:
+        return None
     try:
-        config = AppConfig.current()
-        skills_path = config.skills.get_skills_path()
-        if skills_path.exists():
-            value = str(skills_path)
-            _get_skills_host_path._cached = value  # type: ignore[attr-defined]
-            return value
+        skills_path = skills_cfg.get_skills_path()
     except Exception:
-        pass
-    return None
+        return None
+    if skills_path.exists():
+        return str(skills_path)
+    return None


-def _is_skills_path(path: str) -> bool:
+def _is_skills_path(path: str, app_config: AppConfig) -> bool:
     """Check if a path is under the skills container path."""
-    skills_prefix = _get_skills_container_path()
+    skills_prefix = _get_skills_container_path(app_config)
     return path == skills_prefix or path.startswith(f"{skills_prefix}/")


-def _resolve_skills_path(path: str) -> str:
+def _resolve_skills_path(path: str, app_config: AppConfig) -> str:
     """Resolve a virtual skills path to a host filesystem path.

     Args:
         path: Virtual skills path (e.g. /mnt/skills/public/bootstrap/SKILL.md)
+        app_config: Resolved application config.

     Returns:
         Resolved host path.
@@ -99,8 +84,8 @@ def _resolve_skills_path(path: str) -> str:
     Raises:
         FileNotFoundError: If skills directory is not configured or doesn't exist.
     """
-    skills_container = _get_skills_container_path()
-    skills_host = _get_skills_host_path()
+    skills_container = _get_skills_container_path(app_config)
+    skills_host = _get_skills_host_path(app_config)
     if skills_host is None:
         raise FileNotFoundError(f"Skills directory not available for path: {path}")
@@ -116,46 +101,31 @@ def _is_acp_workspace_path(path: str) -> bool:
     return path == _ACP_WORKSPACE_VIRTUAL_PATH or path.startswith(f"{_ACP_WORKSPACE_VIRTUAL_PATH}/")


-def _get_custom_mounts():
+def _get_custom_mounts(app_config: AppConfig):
     """Get custom volume mounts from sandbox config.

-    Result is cached after the first successful config load. If config loading
-    fails an empty list is returned *without* caching so that a later call can
-    pick up the real value once the config is available.
+    Only includes mounts whose host_path exists, consistent with
+    ``LocalSandboxProvider._setup_path_mappings()`` which also filters by
+    ``host_path.exists()``.
     """
-    cached = getattr(_get_custom_mounts, "_cached", None)
-    if cached is not None:
-        return cached
-    try:
-        from pathlib import Path
-
-        config = AppConfig.current()
-        mounts = []
-        if config.sandbox and config.sandbox.mounts:
-            # Only include mounts whose host_path exists, consistent with
-            # LocalSandboxProvider._setup_path_mappings() which also filters
-            # by host_path.exists().
-            mounts = [m for m in config.sandbox.mounts if Path(m.host_path).exists()]
-        _get_custom_mounts._cached = mounts  # type: ignore[attr-defined]
-        return mounts
-    except Exception:
-        # If config loading fails, return an empty list without caching so that
-        # a later call can retry once the config is available.
+    sandbox_cfg = getattr(app_config, "sandbox", None)
+    if sandbox_cfg is None or not sandbox_cfg.mounts:
         return []
+    return [m for m in sandbox_cfg.mounts if Path(m.host_path).exists()]


-def _is_custom_mount_path(path: str) -> bool:
+def _is_custom_mount_path(path: str, app_config: AppConfig) -> bool:
     """Check if path is under a custom mount container_path."""
-    for mount in _get_custom_mounts():
+    for mount in _get_custom_mounts(app_config):
         if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
             return True
     return False


-def _get_custom_mount_for_path(path: str):
+def _get_custom_mount_for_path(path: str, app_config: AppConfig):
     """Get the mount config matching this path (longest prefix first)."""
     best = None
-    for mount in _get_custom_mounts():
+    for mount in _get_custom_mounts(app_config):
         if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
             if best is None or len(mount.container_path) > len(best.container_path):
                 best = mount
@@ -266,42 +236,40 @@ def _resolve_acp_workspace_path(path: str, thread_id: str | None = None) -> str:
     return str(resolved_path)


-def _get_mcp_allowed_paths() -> list[str]:
+def _get_mcp_allowed_paths(app_config: AppConfig) -> list[str]:
     """Get the list of allowed paths from MCP config for file system server."""
-    allowed_paths = []
-    try:
-        extensions_config = AppConfig.current().extensions
+    allowed_paths: list[str] = []
+    extensions_config = getattr(app_config, "extensions", None)
+    if extensions_config is None:
+        return allowed_paths

-        for _, server in extensions_config.mcp_servers.items():
-            if not server.enabled:
-                continue
+    for _, server in extensions_config.mcp_servers.items():
+        if not server.enabled:
+            continue

-            # Only check the filesystem server
-            args = server.args or []
-            # Check if args has server-filesystem package
-            has_filesystem = any("server-filesystem" in arg for arg in args)
-            if not has_filesystem:
-                continue
-            # Unpack the allowed file system paths in config
-            for arg in args:
-                if not arg.startswith("-") and arg.startswith("/"):
-                    allowed_paths.append(arg.rstrip("/") + "/")
-
-    except Exception:
-        pass
+        # Only check the filesystem server
+        args = server.args or []
+        # Check if args has server-filesystem package
+        has_filesystem = any("server-filesystem" in arg for arg in args)
+        if not has_filesystem:
+            continue
+        # Unpack the allowed file system paths in config
+        for arg in args:
+            if not arg.startswith("-") and arg.startswith("/"):
+                allowed_paths.append(arg.rstrip("/") + "/")

     return allowed_paths


-def _get_tool_config_int(name: str, key: str, default: int) -> int:
+def _get_tool_config_int(app_config: AppConfig, name: str, key: str, default: int) -> int:
     try:
-        tool_config = AppConfig.current().get_tool_config(name)
-        if tool_config is not None and key in tool_config.model_extra:
-            value = tool_config.model_extra.get(key)
-            if isinstance(value, int):
-                return value
+        tool_config = app_config.get_tool_config(name)
     except Exception:
         pass
-    return default
+    if tool_config is not None and key in tool_config.model_extra:
+        value = tool_config.model_extra.get(key)
+        if isinstance(value, int):
+            return value
+    return default

@@ -311,23 +279,23 @@ def _clamp_max_results(value: int, *, default: int, upper_bound: int) -> int:
     return min(value, upper_bound)


-def _resolve_max_results(name: str, requested: int, *, default: int, upper_bound: int) -> int:
+def _resolve_max_results(app_config: AppConfig, name: str, requested: int, *, default: int, upper_bound: int) -> int:
     requested_max_results = _clamp_max_results(requested, default=default, upper_bound=upper_bound)
     configured_max_results = _clamp_max_results(
-        _get_tool_config_int(name, "max_results", default),
+        _get_tool_config_int(app_config, name, "max_results", default),
         default=default,
         upper_bound=upper_bound,
     )
     return min(requested_max_results, configured_max_results)


-def _resolve_local_read_path(path: str, thread_data: ThreadDataState) -> str:
-    validate_local_tool_path(path, thread_data, read_only=True)
-    if _is_skills_path(path):
-        return _resolve_skills_path(path)
+def _resolve_local_read_path(path: str, thread_data: ThreadDataState, app_config: AppConfig) -> str:
+    validate_local_tool_path(path, thread_data, app_config, read_only=True)
+    if _is_skills_path(path, app_config):
+        return _resolve_skills_path(path, app_config)
     if _is_acp_workspace_path(path):
         return _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
-    return _resolve_and_validate_user_data_path(path, thread_data)
+    return _resolve_and_validate_user_data_path(path, thread_data, app_config)


 def _format_glob_results(root_path: str, matches: list[str], truncated: bool) -> str:
@@ -373,7 +341,11 @@ def _join_path_preserving_style(base: str, relative: str) -> str:
     return f"{stripped_base}{separator}{normalized_relative}"


-def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
+def _sanitize_error(
+    error: Exception,
+    runtime: "ToolRuntime[ContextT, ThreadState] | None" = None,
+    app_config: AppConfig | None = None,
+) -> str:
     """Sanitize an error message to avoid leaking host filesystem paths.

     In local-sandbox mode, resolved host paths in the error string are masked
@@ -382,8 +354,12 @@ def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadStat
     """
     msg = f"{type(error).__name__}: {error}"
     if runtime is not None and is_local_sandbox(runtime):
-        thread_data = get_thread_data(runtime)
-        msg = mask_local_paths_in_output(msg, thread_data)
+        if app_config is None:
+            ctx = getattr(runtime, "context", None)
+            app_config = getattr(ctx, "app_config", None)
+        if app_config is not None:
+            thread_data = get_thread_data(runtime)
+            msg = mask_local_paths_in_output(msg, thread_data, app_config)
     return msg


@@ -453,7 +429,7 @@ def _thread_actual_to_virtual_mappings(thread_data: ThreadDataState) -> dict[str
     return {actual: virtual for virtual, actual in _thread_virtual_to_actual_mappings(thread_data).items()}


-def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None) -> str:
+def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> str:
     """Mask host absolute paths from local sandbox output using virtual paths.

     Handles user-data paths (per-thread), skills paths, and ACP workspace paths (global).
@@ -461,8 +437,8 @@ def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None)
     result = output

     # Mask skills host paths
-    skills_host = _get_skills_host_path()
-    skills_container = _get_skills_container_path()
+    skills_host = _get_skills_host_path(app_config)
+    skills_container = _get_skills_container_path(app_config)
     if skills_host:
         raw_base = str(Path(skills_host))
         resolved_base = str(Path(skills_host).resolve())
@@ -536,7 +512,13 @@ def _reject_path_traversal(path: str) -> None:
     raise PermissionError("Access denied: path traversal detected")


-def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *, read_only: bool = False) -> None:
+def validate_local_tool_path(
+    path: str,
+    thread_data: ThreadDataState | None,
+    app_config: AppConfig,
+    *,
+    read_only: bool = False,
+) -> None:
     """Validate that a virtual path is allowed for local-sandbox access.

     This function is a security gate — it checks whether *path* may be
@@ -565,7 +547,7 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
     _reject_path_traversal(path)

     # Skills paths — read-only access only
-    if _is_skills_path(path):
+    if _is_skills_path(path, app_config):
         if not read_only:
             raise PermissionError(f"Write access to skills path is not allowed: {path}")
         return
@@ -581,13 +563,13 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
         return

     # Custom mount paths — respect read_only config
-    if _is_custom_mount_path(path):
-        mount = _get_custom_mount_for_path(path)
+    if _is_custom_mount_path(path, app_config):
+        mount = _get_custom_mount_for_path(path, app_config)
         if mount and mount.read_only and not read_only:
             raise PermissionError(f"Write access to read-only mount is not allowed: {path}")
         return

-    raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")
+    raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path(app_config)}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")


 def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataState) -> None:
@@ -618,18 +600,23 @@ def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataSta
     raise PermissionError("Access denied: path traversal detected")


-def _resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState) -> str:
+def _resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState, app_config: AppConfig) -> str:
     """Resolve a /mnt/user-data virtual path and validate it stays in bounds.

     Returns the resolved host path string.
+
+    ``app_config`` is accepted for signature symmetry with the other resolver
+    helpers; the user-data resolution path itself is fully derivable from
+    ``thread_data``.
     """
+    _ = app_config  # noqa: F841 — kept for interface symmetry with sibling resolvers
     resolved_str = replace_virtual_path(path, thread_data)
     resolved = Path(resolved_str).resolve()
     _validate_resolved_user_data_path(resolved, thread_data)
     return str(resolved)


-def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState | None) -> None:
+def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> None:
     """Validate absolute paths in local-sandbox bash commands.

     This validation is only a best-effort guard for the explicit
@@ -653,7 +640,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
     raise PermissionError(f"Unsafe file:// URL in command: {file_url_match.group()}. Use paths under {VIRTUAL_PATH_PREFIX}")

     unsafe_paths: list[str] = []
-    allowed_paths = _get_mcp_allowed_paths()
+    allowed_paths = _get_mcp_allowed_paths(app_config)

     for absolute_path in _ABSOLUTE_PATH_PATTERN.findall(command):
         # Check for MCP filesystem server allowed paths
@@ -666,7 +653,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
             continue

         # Allow skills container path (resolved by tools.py before passing to sandbox)
-        if _is_skills_path(absolute_path):
+        if _is_skills_path(absolute_path, app_config):
             _reject_path_traversal(absolute_path)
             continue

@@ -676,7 +663,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
             continue

         # Allow custom mount container paths
-        if _is_custom_mount_path(absolute_path):
+        if _is_custom_mount_path(absolute_path, app_config):
             _reject_path_traversal(absolute_path)
             continue

@@ -690,12 +677,13 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
     raise PermissionError(f"Unsafe absolute paths in command: {unsafe}. Use paths under {VIRTUAL_PATH_PREFIX}")


-def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState | None) -> str:
+def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> str:
     """Replace all virtual paths (/mnt/user-data, /mnt/skills, /mnt/acp-workspace) in a command string.

     Args:
         command: The command string that may contain virtual paths.
         thread_data: The thread data containing actual paths.
+        app_config: Resolved application config.

     Returns:
         The command with all virtual paths replaced.
@@ -703,13 +691,13 @@ def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState
     result = command

     # Replace skills paths
-    skills_container = _get_skills_container_path()
-    skills_host = _get_skills_host_path()
+    skills_container = _get_skills_container_path(app_config)
+    skills_host = _get_skills_host_path(app_config)
     if skills_host and skills_container in result:
         skills_pattern = re.compile(rf"{re.escape(skills_container)}(/[^\s\"';&|<>()]*)?")

         def replace_skills_match(match: re.Match) -> str:
-            return _resolve_skills_path(match.group(0))
+            return _resolve_skills_path(match.group(0), app_config)

         result = skills_pattern.sub(replace_skills_match, result)

@@ -799,7 +787,7 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
     sandbox_id = sandbox_state.get("sandbox_id")
     if sandbox_id is None:
         raise SandboxRuntimeError("Sandbox ID not found in state")
-    sandbox = get_sandbox_provider().get(sandbox_id)
+    sandbox = get_sandbox_provider(resolve_context(runtime).app_config).get(sandbox_id)
     if sandbox is None:
         raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)

@@ -830,12 +818,14 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
     if runtime.state is None:
         raise SandboxRuntimeError("Tool runtime state not available")

+    app_config = runtime.context.app_config
+
     # Check if sandbox already exists in state
     sandbox_state = runtime.state.get("sandbox")
     if sandbox_state is not None:
         sandbox_id = sandbox_state.get("sandbox_id")
         if sandbox_id is not None:
-            sandbox = get_sandbox_provider().get(sandbox_id)
+            sandbox = get_sandbox_provider(app_config).get(sandbox_id)
             if sandbox is not None:
                 return sandbox
             # Sandbox was released, fall through to acquire new one
@@ -845,7 +835,7 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
     if not thread_id:
         raise SandboxRuntimeError("Thread ID not available in runtime context")

-    provider = get_sandbox_provider()
+    provider = get_sandbox_provider(app_config)
     sandbox_id = provider.acquire(thread_id)

     # Update runtime state - this persists across tool calls
@@ -985,9 +975,9 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
         description: Explain why you are running this command in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
         command: The bash command to execute. Always use absolute paths for files and directories.
     """
+    app_config = resolve_context(runtime).app_config
     try:
         sandbox = ensure_sandbox_initialized(runtime)
-        app_config = resolve_context(runtime).app_config
         sandbox_cfg = app_config.sandbox
         max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
         if is_local_sandbox(runtime):
@@ -995,11 +985,11 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
                 return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}"
             ensure_thread_directories_exist(runtime)
             thread_data = get_thread_data(runtime)
-            validate_local_bash_command_paths(command, thread_data)
-            command = replace_virtual_paths_in_command(command, thread_data)
+            validate_local_bash_command_paths(command, thread_data, app_config)
+            command = replace_virtual_paths_in_command(command, thread_data, app_config)
             command = _apply_cwd_prefix(command, thread_data)
             output = sandbox.execute_command(command)
-            return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
+            return _truncate_bash_output(mask_local_paths_in_output(output, thread_data, app_config), max_chars)
         ensure_thread_directories_exist(runtime)
         return _truncate_bash_output(sandbox.execute_command(command), max_chars)
     except SandboxError as e:
@@ -1007,7 +997,7 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
     except PermissionError as e:
         return f"Error: {e}"
     except Exception as e:
-        return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime)}"
+        return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime, app_config)}"


 @tool("ls", parse_docstring=True)
@@ -1018,25 +1008,26 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
         description: Explain why you are listing this directory in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
         path: The **absolute** path to the directory to list.
     """
+    app_config = resolve_context(runtime).app_config
     try:
         sandbox = ensure_sandbox_initialized(runtime)
         ensure_thread_directories_exist(runtime)
         requested_path = path
         if is_local_sandbox(runtime):
             thread_data = get_thread_data(runtime)
-            validate_local_tool_path(path, thread_data, read_only=True)
-            if _is_skills_path(path):
-                path = _resolve_skills_path(path)
+            validate_local_tool_path(path, thread_data, app_config, read_only=True)
+            if _is_skills_path(path, app_config):
+                path = _resolve_skills_path(path, app_config)
             elif _is_acp_workspace_path(path):
                 path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
-            elif not _is_custom_mount_path(path):
-                path = _resolve_and_validate_user_data_path(path, thread_data)
+            elif not _is_custom_mount_path(path, app_config):
+                path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
             # Custom mount paths are resolved by LocalSandbox._resolve_path()
         children = sandbox.list_dir(path)
         if not children:
             return "(empty)"
         output = "\n".join(children)
-        sandbox_cfg = resolve_context(runtime).app_config.sandbox
+        sandbox_cfg = app_config.sandbox
         max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
         return _truncate_ls_output(output, max_chars)
     except SandboxError as e:
@@ -1046,7 +1037,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
     except PermissionError:
         return f"Error: Permission denied: {requested_path}"
     except Exception as e:
-        return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}"
+        return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime, app_config)}"


 @tool("glob", parse_docstring=True)
@@ -1067,11 +1058,13 @@ def glob_tool(
         include_dirs: Whether matching directories should also be returned. Default is False.
         max_results: Maximum number of paths to return. Default is 200.
     """
+    app_config = resolve_context(runtime).app_config
     try:
         sandbox = ensure_sandbox_initialized(runtime)
         ensure_thread_directories_exist(runtime)
         requested_path = path
         effective_max_results = _resolve_max_results(
+            app_config,
             "glob",
             max_results,
             default=_DEFAULT_GLOB_MAX_RESULTS,
@@ -1082,10 +1075,10 @@ def glob_tool(
             thread_data = get_thread_data(runtime)
             if thread_data is None:
                 raise SandboxRuntimeError("Thread data not available for local sandbox")
-            path = _resolve_local_read_path(path, thread_data)
+            path = _resolve_local_read_path(path, thread_data, app_config)
         matches, truncated = sandbox.glob(path, pattern, include_dirs=include_dirs, max_results=effective_max_results)
         if thread_data is not None:
-            matches = [mask_local_paths_in_output(match, thread_data) for match in matches]
+            matches = [mask_local_paths_in_output(match, thread_data, app_config) for match in matches]
         return _format_glob_results(requested_path, matches, truncated)
     except SandboxError as e:
         return f"Error: {e}"
@@ -1096,7 +1089,7 @@ def glob_tool(
     except PermissionError:
         return f"Error: Permission denied: {requested_path}"
     except Exception as e:
-        return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}"
+        return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime, app_config)}"


 @tool("grep", parse_docstring=True)
@@ -1121,11 +1114,13 @@ def grep_tool(
         case_sensitive: Whether matching is case-sensitive. Default is False.
         max_results: Maximum number of matching lines to return. Default is 100.
     """
+    app_config = resolve_context(runtime).app_config
     try:
         sandbox = ensure_sandbox_initialized(runtime)
         ensure_thread_directories_exist(runtime)
         requested_path = path
         effective_max_results = _resolve_max_results(
+            app_config,
             "grep",
             max_results,
             default=_DEFAULT_GREP_MAX_RESULTS,
@@ -1136,7 +1131,7 @@ def grep_tool(
             thread_data = get_thread_data(runtime)
             if thread_data is None:
                 raise SandboxRuntimeError("Thread data not available for local sandbox")
-            path = _resolve_local_read_path(path, thread_data)
+            path = _resolve_local_read_path(path, thread_data, app_config)
         matches, truncated = sandbox.grep(
             path,
             pattern,
@@ -1148,7 +1143,7 @@ def grep_tool(
         if thread_data is not None:
             matches = [
                 GrepMatch(
-                    path=mask_local_paths_in_output(match.path, thread_data),
+                    path=mask_local_paths_in_output(match.path, thread_data, app_config),
                     line_number=match.line_number,
                     line=match.line,
                 )
@@ -1166,7 +1161,7 @@ def grep_tool(
     except PermissionError:
         return f"Error: Permission denied: {requested_path}"
     except Exception as e:
-        return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}"
+        return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime, app_config)}"


 @tool("read_file", parse_docstring=True)
@@ -1185,26 +1180,27 @@ def read_file_tool(
         start_line: Optional starting line number (1-indexed, inclusive). Use with end_line to read a specific range.
         end_line: Optional ending line number (1-indexed, inclusive). Use with start_line to read a specific range.
     """
+    app_config = resolve_context(runtime).app_config
     try:
         sandbox = ensure_sandbox_initialized(runtime)
         ensure_thread_directories_exist(runtime)
         requested_path = path
         if is_local_sandbox(runtime):
             thread_data = get_thread_data(runtime)
-            validate_local_tool_path(path, thread_data, read_only=True)
-            if _is_skills_path(path):
-                path = _resolve_skills_path(path)
+            validate_local_tool_path(path, thread_data, app_config, read_only=True)
+            if _is_skills_path(path, app_config):
+                path = _resolve_skills_path(path, app_config)
             elif _is_acp_workspace_path(path):
                 path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
-            elif not _is_custom_mount_path(path):
-                path = _resolve_and_validate_user_data_path(path, thread_data)
+            elif not _is_custom_mount_path(path, app_config):
+                path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
             # Custom mount paths are resolved by LocalSandbox._resolve_path()
         content = sandbox.read_file(path)
         if not content:
             return "(empty)"
         if start_line is not None and end_line is not None:
             content = "\n".join(content.splitlines()[start_line - 1 : end_line])
-        sandbox_cfg = resolve_context(runtime).app_config.sandbox
+        sandbox_cfg = app_config.sandbox
         max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
         return _truncate_read_file_output(content, max_chars)
     except SandboxError as e:
@@ -1216,7 +1212,7 @@ def read_file_tool(
     except IsADirectoryError:
         return f"Error: Path is a directory, not a file: {requested_path}"
     except Exception as e:
-        return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime)}"
+        return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime, app_config)}"


 @tool("write_file", parse_docstring=True)
@@ -1234,15 +1230,16 @@ def write_file_tool(
         path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND.
         content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD.
     """
+    app_config = resolve_context(runtime).app_config
     try:
         sandbox = ensure_sandbox_initialized(runtime)
         ensure_thread_directories_exist(runtime)
         requested_path = path
         if is_local_sandbox(runtime):
             thread_data = get_thread_data(runtime)
-            validate_local_tool_path(path, thread_data)
-            if not _is_custom_mount_path(path):
-                path = _resolve_and_validate_user_data_path(path, thread_data)
+            validate_local_tool_path(path, thread_data, app_config)
+            if not _is_custom_mount_path(path, app_config):
+                path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
             # Custom mount paths are resolved by LocalSandbox._resolve_path()
         with get_file_operation_lock(sandbox, path):
             sandbox.write_file(path, content, append)
@@ -1254,9 +1251,9 @@ def write_file_tool(
     except IsADirectoryError:
         return f"Error: Path is a directory, not a file: {requested_path}"
     except OSError as e:
-        return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime)}"
+        return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime, app_config)}"
     except Exception as e:
-        return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}"
+        return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime, app_config)}"


 @tool("str_replace", parse_docstring=True)
@@ -1278,15 +1275,16 @@ def str_replace_tool(
         new_str: The new substring. ALWAYS PROVIDE THIS PARAMETER FOURTH.
         replace_all: Whether to replace all occurrences of the substring. If False, only the first occurrence will be replaced. Default is False.
     """
+    app_config = resolve_context(runtime).app_config
     try:
         sandbox = ensure_sandbox_initialized(runtime)
         ensure_thread_directories_exist(runtime)
         requested_path = path
         if is_local_sandbox(runtime):
             thread_data = get_thread_data(runtime)
-            validate_local_tool_path(path, thread_data)
-            if not _is_custom_mount_path(path):
-                path = _resolve_and_validate_user_data_path(path, thread_data)
+            validate_local_tool_path(path, thread_data, app_config)
+            if not _is_custom_mount_path(path, app_config):
+                path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
             # Custom mount paths are resolved by LocalSandbox._resolve_path()
         with get_file_operation_lock(sandbox, path):
             content = sandbox.read_file(path)
@@ -1307,4 +1305,4 @@ def str_replace_tool(
     except PermissionError:
         return f"Error: Permission denied accessing file: {requested_path}"
     except Exception as e:
-        return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime)}"
+        return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime, app_config)}"
@@ -1,10 +1,14 @@
 import logging
 import os
 from pathlib import Path
+from typing import TYPE_CHECKING

 from .parser import parse_skill_file
 from .types import Skill
+
+if TYPE_CHECKING:
+    from deerflow.config.app_config import AppConfig

 logger = logging.getLogger(__name__)


@@ -22,7 +26,12 @@ def get_skills_root_path() -> Path:
     return skills_dir


-def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]:
+def load_skills(
+    app_config: "AppConfig | None" = None,
+    *,
+    skills_path: Path | None = None,
+    enabled_only: bool = False,
+) -> list[Skill]:
     """
     Load all skills from the skills directory.

@@ -30,25 +39,19 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable
     to extract metadata. The enabled state is determined by the skills_state_config.json file.

     Args:
-        skills_path: Optional custom path to skills directory.
-                     If not provided and use_config is True, uses path from config.
-                     Otherwise defaults to deer-flow/skills
-        use_config: Whether to load skills path from config (default: True)
+        app_config: Application config used to resolve the configured skills
+                    directory. Ignored when ``skills_path`` is supplied.
+        skills_path: Explicit override for the skills directory. When both
+                     ``skills_path`` and ``app_config`` are omitted the
+                     default repository layout is used (``deer-flow/skills``).
         enabled_only: If True, only return enabled skills (default: False)

     Returns:
         List of Skill objects, sorted by name
     """
     if skills_path is None:
-        if use_config:
-            try:
-                from deerflow.config.app_config import AppConfig
-
-                config = AppConfig.current()
-                skills_path = config.skills.get_skills_path()
-            except Exception:
-                # Fallback to default if config fails
-                skills_path = get_skills_root_path()
+        if app_config is not None:
+            skills_path = app_config.skills.get_skills_path()
         else:
             skills_path = get_skills_root_path()

@@ -20,16 +20,17 @@ ALLOWED_SUPPORT_SUBDIRS = {"references", "templates", "scripts", "assets"}
 _SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")


-def get_skills_root_dir() -> Path:
-    return AppConfig.current().skills.get_skills_path()
+def get_skills_root_dir(app_config: AppConfig) -> Path:
+    """Return the configured skills root."""
+    return app_config.skills.get_skills_path()


-def get_public_skills_dir() -> Path:
-    return get_skills_root_dir() / "public"
+def get_public_skills_dir(app_config: AppConfig) -> Path:
+    return get_skills_root_dir(app_config) / "public"


-def get_custom_skills_dir() -> Path:
-    path = get_skills_root_dir() / "custom"
+def get_custom_skills_dir(app_config: AppConfig) -> Path:
+    path = get_skills_root_dir(app_config) / "custom"
     path.mkdir(parents=True, exist_ok=True)
     return path

@@ -43,46 +44,46 @@ def validate_skill_name(name: str) -> str:
     return normalized


-def get_custom_skill_dir(name: str) -> Path:
-    return get_custom_skills_dir() / validate_skill_name(name)
+def get_custom_skill_dir(name: str, app_config: AppConfig) -> Path:
+    return get_custom_skills_dir(app_config) / validate_skill_name(name)


-def get_custom_skill_file(name: str) -> Path:
-    return get_custom_skill_dir(name) / SKILL_FILE_NAME
+def get_custom_skill_file(name: str, app_config: AppConfig) -> Path:
+    return get_custom_skill_dir(name, app_config) / SKILL_FILE_NAME


-def get_custom_skill_history_dir() -> Path:
-    path = get_custom_skills_dir() / HISTORY_DIR_NAME
+def get_custom_skill_history_dir(app_config: AppConfig) -> Path:
+    path = get_custom_skills_dir(app_config) / HISTORY_DIR_NAME
     path.mkdir(parents=True, exist_ok=True)
     return path


-def get_skill_history_file(name: str) -> Path:
-    return get_custom_skill_history_dir() / f"{validate_skill_name(name)}.jsonl"
+def get_skill_history_file(name: str, app_config: AppConfig) -> Path:
+    return get_custom_skill_history_dir(app_config) / f"{validate_skill_name(name)}.jsonl"


-def get_public_skill_dir(name: str) -> Path:
-    return get_public_skills_dir() / validate_skill_name(name)
+def get_public_skill_dir(name: str, app_config: AppConfig) -> Path:
+    return get_public_skills_dir(app_config) / validate_skill_name(name)


-def custom_skill_exists(name: str) -> bool:
-    return get_custom_skill_file(name).exists()
+def custom_skill_exists(name: str, app_config: AppConfig) -> bool:
+    return get_custom_skill_file(name, app_config).exists()


-def public_skill_exists(name: str) -> bool:
-    return (get_public_skill_dir(name) / SKILL_FILE_NAME).exists()
+def public_skill_exists(name: str, app_config: AppConfig) -> bool:
+    return (get_public_skill_dir(name, app_config) / SKILL_FILE_NAME).exists()


-def ensure_custom_skill_is_editable(name: str) -> None:
-    if custom_skill_exists(name):
+def ensure_custom_skill_is_editable(name: str, app_config: AppConfig) -> None:
+    if custom_skill_exists(name, app_config):
         return
-    if public_skill_exists(name):
+    if public_skill_exists(name, app_config):
         raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
     raise FileNotFoundError(f"Custom skill '{name}' not found.")


-def ensure_safe_support_path(name: str, relative_path: str) -> Path:
-    skill_dir = get_custom_skill_dir(name).resolve()
+def ensure_safe_support_path(name: str, relative_path: str, app_config: AppConfig) -> Path:
+    skill_dir = get_custom_skill_dir(name, app_config).resolve()
     if not relative_path or relative_path.endswith("/"):
         raise ValueError("Supporting file path must include a filename.")
     relative = Path(relative_path)
@@ -124,8 +125,8 @@ def atomic_write(path: Path, content: str) -> None:
     tmp_path.replace(path)


-def append_history(name: str, record: dict[str, Any]) -> None:
-    history_path = get_skill_history_file(name)
+def append_history(name: str, record: dict[str, Any], app_config: AppConfig) -> None:
+    history_path = get_skill_history_file(name, app_config)
     history_path.parent.mkdir(parents=True, exist_ok=True)
     payload = {
         "ts": datetime.now(UTC).isoformat(),
@@ -136,8 +137,8 @@ def append_history(name: str, record: dict[str, Any]) -> None:
         f.write("\n")


-def read_history(name: str) -> list[dict[str, Any]]:
-    history_path = get_skill_history_file(name)
+def read_history(name: str, app_config: AppConfig) -> list[dict[str, Any]]:
+    history_path = get_skill_history_file(name, app_config)
     if not history_path.exists():
         return []
     records: list[dict[str, Any]] = []
@@ -148,12 +149,12 @@ def read_history(name: str) -> list[dict[str, Any]]:
     return records


-def list_custom_skills() -> list:
-    return [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
+def list_custom_skills(app_config: AppConfig) -> list:
+    return [skill for skill in load_skills(app_config, enabled_only=False) if skill.category == "custom"]


-def read_custom_skill_content(name: str) -> str:
-    skill_file = get_custom_skill_file(name)
+def read_custom_skill_content(name: str, app_config: AppConfig) -> str:
+    skill_file = get_custom_skill_file(name, app_config)
     if not skill_file.exists():
         raise FileNotFoundError(f"Custom skill '{name}' not found.")
     return skill_file.read_text(encoding="utf-8")

@@ -35,7 +35,7 @@ def _extract_json_object(raw: str) -> dict | None:
     return None


-async def scan_skill_content(content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult:
+async def scan_skill_content(app_config: AppConfig, content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult:
     """Screen skill content before it is written to disk."""
     rubric = (
         "You are a security reviewer for AI agent skills. "
@@ -47,9 +47,12 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
     prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----"

     try:
-        config = AppConfig.current()
-        model_name = config.skill_evolution.moderation_model_name
-        model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False)
+        model_name = app_config.skill_evolution.moderation_model_name
+        model = (
+            create_chat_model(name=model_name, thinking_enabled=False, app_config=app_config)
+            if model_name
+            else create_chat_model(thinking_enabled=False, app_config=app_config)
+        )
         response = await model.ainvoke(
             [
                 {"role": "system", "content": rubric},

@@ -17,6 +17,7 @@ from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.runnables import RunnableConfig

 from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
+from deerflow.config.app_config import AppConfig
 from deerflow.models import create_chat_model
 from deerflow.subagents.config import SubagentConfig

@@ -132,24 +133,16 @@ class SubagentExecutor:
         self,
         config: SubagentConfig,
         tools: list[BaseTool],
+        app_config: AppConfig,
         parent_model: str | None = None,
         sandbox_state: SandboxState | None = None,
         thread_data: ThreadDataState | None = None,
         thread_id: str | None = None,
         trace_id: str | None = None,
     ):
-        """Initialize the executor.
-
-        Args:
-            config: Subagent configuration.
-            tools: List of all available tools (will be filtered).
-            parent_model: The parent agent's model name for inheritance.
-            sandbox_state: Sandbox state from parent agent.
-            thread_data: Thread data from parent agent.
-            thread_id: Thread ID for sandbox operations.
-            trace_id: Trace ID from parent for distributed tracing.
-        """
+        """Initialize the executor."""
         self.config = config
+        self.app_config = app_config
         self.parent_model = parent_model
         self.sandbox_state = sandbox_state
         self.thread_data = thread_data
@@ -169,7 +162,7 @@ class SubagentExecutor:
     def _create_agent(self):
         """Create the agent instance."""
         model_name = _get_model_name(self.config, self.parent_model)
-        model = create_chat_model(name=model_name, thinking_enabled=False)
+        model = create_chat_model(name=model_name, thinking_enabled=False, app_config=self.app_config)

         from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares

@@ -3,6 +3,7 @@
 import logging
 from dataclasses import replace

+from deerflow.config.app_config import AppConfig
 from deerflow.sandbox.security import is_host_bash_allowed
 from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
 from deerflow.subagents.config import SubagentConfig
@@ -10,25 +11,15 @@ from deerflow.subagents.config import SubagentConfig
 logger = logging.getLogger(__name__)


-def get_subagent_config(name: str) -> SubagentConfig | None:
-    """Get a subagent configuration by name, with config.yaml overrides applied.
-
-    Args:
-        name: The name of the subagent.
-
-    Returns:
-        SubagentConfig if found (with any config.yaml overrides applied), None otherwise.
-    """
+def get_subagent_config(name: str, app_config: AppConfig) -> SubagentConfig | None:
+    """Get a subagent configuration by name, with config.yaml overrides applied."""
     config = BUILTIN_SUBAGENTS.get(name)
     if config is None:
         return None

-    # Apply timeout override from config.yaml (lazy import to avoid circular deps)
-    from deerflow.config.app_config import AppConfig
-
-    app_config = AppConfig.current().subagents
-    effective_timeout = app_config.get_timeout_for(name)
-    effective_max_turns = app_config.get_max_turns_for(name, config.max_turns)
+    sub_config = app_config.subagents
+    effective_timeout = sub_config.get_timeout_for(name)
+    effective_max_turns = sub_config.get_max_turns_for(name, config.max_turns)

     overrides = {}
     if effective_timeout != config.timeout_seconds:
@@ -53,13 +44,13 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
     return config


-def list_subagents() -> list[SubagentConfig]:
+def list_subagents(app_config: AppConfig) -> list[SubagentConfig]:
     """List all available subagent configurations (with config.yaml overrides applied).

     Returns:
         List of all registered SubagentConfig instances.
     """
-    return [get_subagent_config(name) for name in BUILTIN_SUBAGENTS]
+    return [get_subagent_config(name, app_config) for name in BUILTIN_SUBAGENTS]


 def get_subagent_names() -> list[str]:
@@ -71,7 +62,7 @@ def get_subagent_names() -> list[str]:
     return list(BUILTIN_SUBAGENTS.keys())


-def get_available_subagent_names() -> list[str]:
+def get_available_subagent_names(app_config: AppConfig) -> list[str]:
     """Get subagent names that should be exposed to the active runtime.

     Returns:
@@ -79,7 +70,7 @@ def get_available_subagent_names() -> list[str]:
     """
     names = list(BUILTIN_SUBAGENTS.keys())
     try:
-        host_bash_allowed = is_host_bash_allowed()
+        host_bash_allowed = is_host_bash_allowed(app_config)
     except Exception:
         logger.debug("Could not determine host bash availability; exposing all built-in subagents")
         return names

@@ -60,20 +60,21 @@ async def task_tool(
         subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
         max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.
     """
-    available_subagent_names = get_available_subagent_names()
+    ctx = resolve_context(runtime)
+    available_subagent_names = get_available_subagent_names(ctx.app_config)

     # Get subagent configuration
-    config = get_subagent_config(subagent_type)
+    config = get_subagent_config(subagent_type, ctx.app_config)
     if config is None:
         available = ", ".join(available_subagent_names)
         return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
-    if subagent_type == "bash" and not is_host_bash_allowed(resolve_context(runtime).app_config):
+    if subagent_type == "bash" and not is_host_bash_allowed(ctx.app_config):
         return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"

     # Build config overrides
     overrides: dict = {}

-    skills_section = get_skills_prompt_section()
+    skills_section = get_skills_prompt_section(ctx.app_config)
     if skills_section:
         overrides["system_prompt"] = config.system_prompt + "\n\n" + skills_section

@@ -107,12 +108,13 @@ async def task_tool(
     from deerflow.tools import get_available_tools

     # Subagents should not have subagent tools enabled (prevent recursive nesting)
-    tools = get_available_tools(model_name=parent_model, subagent_enabled=False)
+    tools = get_available_tools(model_name=parent_model, subagent_enabled=False, app_config=ctx.app_config)

     # Create executor
     executor = SubagentExecutor(
         config=config,
         tools=tools,
+        app_config=ctx.app_config,
         parent_model=parent_model,
         sandbox_state=sandbox_state,
         thread_data=thread_data,
@@ -5,7 +5,7 @@ from __future__ import annotations
 import asyncio
 import logging
 import shutil
-from typing import Any
+from typing import TYPE_CHECKING, Any
 from weakref import WeakValueDictionary

 from langchain.tools import ToolRuntime, tool
@@ -13,6 +13,9 @@ from langgraph.typing import ContextT

 from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
 from deerflow.agents.thread_state import ThreadState
+
+if TYPE_CHECKING:
+    from deerflow.config.app_config import AppConfig
 from deerflow.mcp.tools import _make_sync_tool_wrapper
 from deerflow.skills.manager import (
     append_history,
@@ -60,8 +63,8 @@ def _history_record(*, action: str, file_path: str, prev_content: str | None, ne
     }


-async def _scan_or_raise(content: str, *, executable: bool, location: str) -> dict[str, str]:
-    result = await scan_skill_content(content, executable=executable, location=location)
+async def _scan_or_raise(app_config: "AppConfig", content: str, *, executable: bool, location: str) -> dict[str, str]:
+    result = await scan_skill_content(app_config, content, executable=executable, location=location)
     if result.decision == "block":
         raise ValueError(f"Security scan blocked the write: {result.reason}")
     if executable and result.decision != "allow":
@@ -94,50 +97,55 @@ async def _skill_manage_impl(
         replace: Replacement text for patch.
         expected_count: Optional expected number of replacements for patch.
     """
+    from deerflow.config.deer_flow_context import resolve_context
+
     name = validate_skill_name(name)
     lock = _get_lock(name)
     thread_id = _get_thread_id(runtime)
+    app_config = resolve_context(runtime).app_config

     async with lock:
         if action == "create":
-            if await _to_thread(custom_skill_exists, name):
+            if await _to_thread(custom_skill_exists, name, app_config):
                 raise ValueError(f"Custom skill '{name}' already exists.")
             if content is None:
                 raise ValueError("content is required for create.")
             await _to_thread(validate_skill_markdown_content, name, content)
-            scan = await _scan_or_raise(content, executable=False, location=f"{name}/SKILL.md")
-            skill_file = await _to_thread(get_custom_skill_file, name)
+            scan = await _scan_or_raise(app_config, content, executable=False, location=f"{name}/SKILL.md")
+            skill_file = await _to_thread(get_custom_skill_file, name, app_config)
             await _to_thread(atomic_write, skill_file, content)
             await _to_thread(
                 append_history,
                 name,
                 _history_record(action="create", file_path="SKILL.md", prev_content=None, new_content=content, thread_id=thread_id, scanner=scan),
+                app_config,
             )
-            await refresh_skills_system_prompt_cache_async()
+            await refresh_skills_system_prompt_cache_async(app_config)
             return f"Created custom skill '{name}'."

         if action == "edit":
-            await _to_thread(ensure_custom_skill_is_editable, name)
+            await _to_thread(ensure_custom_skill_is_editable, name, app_config)
             if content is None:
                 raise ValueError("content is required for edit.")
             await _to_thread(validate_skill_markdown_content, name, content)
-            scan = await _scan_or_raise(content, executable=False, location=f"{name}/SKILL.md")
-            skill_file = await _to_thread(get_custom_skill_file, name)
+            scan = await _scan_or_raise(app_config, content, executable=False, location=f"{name}/SKILL.md")
+            skill_file = await _to_thread(get_custom_skill_file, name, app_config)
             prev_content = await _to_thread(skill_file.read_text, encoding="utf-8")
             await _to_thread(atomic_write, skill_file, content)
             await _to_thread(
                 append_history,
                 name,
                 _history_record(action="edit", file_path="SKILL.md", prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
+                app_config,
             )
-            await refresh_skills_system_prompt_cache_async()
+            await refresh_skills_system_prompt_cache_async(app_config)
             return f"Updated custom skill '{name}'."

         if action == "patch":
-            await _to_thread(ensure_custom_skill_is_editable, name)
+            await _to_thread(ensure_custom_skill_is_editable, name, app_config)
             if find is None or replace is None:
                 raise ValueError("find and replace are required for patch.")
-            skill_file = await _to_thread(get_custom_skill_file, name)
+            skill_file = await _to_thread(get_custom_skill_file, name, app_config)
             prev_content = await _to_thread(skill_file.read_text, encoding="utf-8")
             occurrences = prev_content.count(find)
             if occurrences == 0:
@@ -147,51 +155,54 @@ async def _skill_manage_impl(
             replacement_count = expected_count if expected_count is not None else 1
             new_content = prev_content.replace(find, replace, replacement_count)
             await _to_thread(validate_skill_markdown_content, name, new_content)
-            scan = await _scan_or_raise(new_content, executable=False, location=f"{name}/SKILL.md")
+            scan = await _scan_or_raise(app_config, new_content, executable=False, location=f"{name}/SKILL.md")
             await _to_thread(atomic_write, skill_file, new_content)
             await _to_thread(
                 append_history,
                 name,
                 _history_record(action="patch", file_path="SKILL.md", prev_content=prev_content, new_content=new_content, thread_id=thread_id, scanner=scan),
+                app_config,
             )
-            await refresh_skills_system_prompt_cache_async()
+            await refresh_skills_system_prompt_cache_async(app_config)
             return f"Patched custom skill '{name}' ({replacement_count} replacement(s) applied, {occurrences} match(es) found)."

         if action == "delete":
-            await _to_thread(ensure_custom_skill_is_editable, name)
-            skill_dir = await _to_thread(get_custom_skill_dir, name)
-            prev_content = await _to_thread(read_custom_skill_content, name)
+            await _to_thread(ensure_custom_skill_is_editable, name, app_config)
+            skill_dir = await _to_thread(get_custom_skill_dir, name, app_config)
+            prev_content = await _to_thread(read_custom_skill_content, name, app_config)
             await _to_thread(
                 append_history,
                 name,
                 _history_record(action="delete", file_path="SKILL.md", prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
+                app_config,
             )
             await _to_thread(shutil.rmtree, skill_dir)
-            await refresh_skills_system_prompt_cache_async()
+            await refresh_skills_system_prompt_cache_async(app_config)
             return f"Deleted custom skill '{name}'."

         if action == "write_file":
-            await _to_thread(ensure_custom_skill_is_editable, name)
+            await _to_thread(ensure_custom_skill_is_editable, name, app_config)
             if path is None or content is None:
                 raise ValueError("path and content are required for write_file.")
-            target = await _to_thread(ensure_safe_support_path, name, path)
+            target = await _to_thread(ensure_safe_support_path, name, path, app_config)
             exists = await _to_thread(target.exists)
             prev_content = await _to_thread(target.read_text, encoding="utf-8") if exists else None
             executable = "scripts/" in path or path.startswith("scripts/")
-            scan = await _scan_or_raise(content, executable=executable, location=f"{name}/{path}")
+            scan = await _scan_or_raise(app_config, content, executable=executable, location=f"{name}/{path}")
             await _to_thread(atomic_write, target, content)
             await _to_thread(
                 append_history,
                 name,
                 _history_record(action="write_file", file_path=path, prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
+                app_config,
             )
             return f"Wrote '{path}' for custom skill '{name}'."

         if action == "remove_file":
-            await _to_thread(ensure_custom_skill_is_editable, name)
+            await _to_thread(ensure_custom_skill_is_editable, name, app_config)
             if path is None:
                 raise ValueError("path is required for remove_file.")
-            target = await _to_thread(ensure_safe_support_path, name, path)
+            target = await _to_thread(ensure_safe_support_path, name, path, app_config)
             if not await _to_thread(target.exists):
                 raise FileNotFoundError(f"Supporting file '{path}' not found for skill '{name}'.")
             prev_content = await _to_thread(target.read_text, encoding="utf-8")
@@ -200,10 +211,11 @@ async def _skill_manage_impl(
                 append_history,
                 name,
                 _history_record(action="remove_file", file_path=path, prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
+                app_config,
             )
             return f"Removed '{path}' from custom skill '{name}'."

-        if await _to_thread(public_skill_exists, name):
+        if await _to_thread(public_skill_exists, name, app_config):
             raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
         raise ValueError(f"Unsupported action '{action}'.")

@@ -38,7 +38,7 @@ def get_available_tools(
     model_name: str | None = None,
     subagent_enabled: bool = False,
     *,
-    app_config: AppConfig | None = None,
+    app_config: AppConfig,
 ) -> list[BaseTool]:
     """Get all available tools from config.

@@ -50,16 +50,11 @@ def get_available_tools(
         include_mcp: Whether to include tools from MCP servers (default: True).
         model_name: Optional model name to determine if vision tools should be included.
         subagent_enabled: Whether to include subagent tools (task, task_status).
-        app_config: Explicit application config. Falls back to AppConfig.current()
-            when omitted; new callers should pass this explicitly.
+        app_config: Application config — required.

     Returns:
         List of available tools.
     """
-    if app_config is None:
-        # TODO(P2-10): fold into a required parameter once all callers thread
-        # config explicitly (community tool factories, subagent registry, etc.).
-        app_config = AppConfig.current()
     config = app_config
     tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups]

@@ -133,7 +133,7 @@ def _do_convert(file_path: Path, pdf_converter: str) -> str:
     return _convert_with_markitdown(file_path)


-async def convert_file_to_markdown(file_path: Path) -> Path | None:
+async def convert_file_to_markdown(file_path: Path, app_config: object | None = None) -> Path | None:
     """Convert a supported document file to Markdown.

     PDF files are handled with a two-converter strategy (see module docstring).
@@ -142,12 +142,14 @@ async def convert_file_to_markdown(file_path: Path) -> Path | None:

     Args:
         file_path: Path to the file to convert.
+        app_config: Optional AppConfig (for pdf_converter preference). When
+            omitted, defaults to ``auto``.

     Returns:
         Path to the generated .md file, or None if conversion failed.
     """
     try:
-        pdf_converter = _get_pdf_converter()
+        pdf_converter = _get_pdf_converter(app_config)
         file_size = file_path.stat().st_size

         if file_size > _ASYNC_THRESHOLD_BYTES:
@@ -286,24 +288,20 @@ def extract_outline(md_path: Path) -> list[dict]:
     return outline


-def _get_pdf_converter() -> str:
+def _get_pdf_converter(app_config: object | None) -> str:
     """Read pdf_converter setting from app config, defaulting to 'auto'.

     Normalizes the value to lowercase and validates it against the allowed set
     so that values like 'AUTO' or 'MarkItDown' from config.yaml don't silently
     fall through to unexpected behaviour.
     """
-    try:
-        from deerflow.config.app_config import AppConfig
-
-        cfg = AppConfig.current()
-        uploads_cfg = getattr(cfg, "uploads", None)
-        if uploads_cfg is not None:
-            raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
-            if raw not in _ALLOWED_PDF_CONVERTERS:
-                logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
-                return "auto"
-            return raw
-    except Exception:
-        pass
-    return "auto"
+    if app_config is None:
+        return "auto"
+    uploads_cfg = getattr(app_config, "uploads", None)
+    if uploads_cfg is None:
+        return "auto"
+    raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
+    if raw not in _ALLOWED_PDF_CONVERTERS:
+        logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
+        return "auto"
+    return raw

@@ -69,12 +69,18 @@ def provisioner_module():


 @pytest.fixture(autouse=True)
-def _auto_app_config():
-    """Initialize a minimal AppConfig for tests so ``AppConfig.current()`` never tries to auto-load config.yaml.
+def _auto_app_config_from_file(monkeypatch, request):
+    """Replace ``AppConfig.from_file`` with a minimal factory so tests that
+    (directly or indirectly, e.g. via the LangGraph Server bootstrap path in
+    ``make_lead_agent``) load AppConfig from disk do not need a real
+    ``config.yaml`` on the filesystem.

-    Individual tests can still override via ``patch.object(AppConfig, "current", ...)``
-    or by calling ``AppConfig.init()`` with a different config.
+    Tests that want to verify the real ``from_file`` behaviour should mark
+    themselves with ``@pytest.mark.real_from_file``.
     """
+    if request.node.get_closest_marker("real_from_file"):
+        yield
+        return
     try:
         from deerflow.config.app_config import AppConfig
         from deerflow.config.sandbox_config import SandboxConfig
@@ -82,12 +88,11 @@ def _auto_app_config():
         yield
         return

-    previous_global = AppConfig._global
-    AppConfig._global = AppConfig(sandbox=SandboxConfig(use="test"))
-    try:
-        yield
-    finally:
-        AppConfig._global = previous_global
+    def _fake_from_file(config_path: str | None = None) -> AppConfig:  # noqa: ARG001
+        return AppConfig(sandbox=SandboxConfig(use="test"))
+
+    monkeypatch.setattr(AppConfig, "from_file", _fake_from_file)
+    yield


 @pytest.fixture(autouse=True)

@@ -2,8 +2,11 @@

 import json

-import pytest
+import pytest
+import yaml
+
+pytestmark = pytest.mark.real_from_file
 from pydantic import ValidationError

 from deerflow.config.acp_config import ACPAgentConfig

@@ -3,10 +3,13 @@ from __future__ import annotations
 import json
 from pathlib import Path

+import pytest
+import yaml
+
 from deerflow.config.app_config import AppConfig

+pytestmark = pytest.mark.real_from_file

@@ -31,7 +34,11 @@ def _write_extensions_config(path: Path) -> None:
     path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")


-def test_init_then_get(tmp_path, monkeypatch):
+def test_from_file_reads_model_name(tmp_path, monkeypatch):
+    """``AppConfig.from_file`` is the only lifecycle method now; there is no
+    process-global ``init/current``. Each consumer holds its own captured
+    AppConfig instance.
+    """
     config_path = tmp_path / "config.yaml"
     extensions_path = tmp_path / "extensions_config.json"
     _write_extensions_config(extensions_path)
@@ -41,14 +48,13 @@ def test_init_then_get(tmp_path, monkeypatch):
     monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))

     config = AppConfig.from_file(str(config_path))
-    AppConfig.init(config)
-
-    result = AppConfig.current()
-    assert result is config
-    assert result.models[0].name == "test-model"
+    assert config.models[0].name == "test-model"


-def test_init_replaces_previous(tmp_path, monkeypatch):
+def test_from_file_each_call_returns_fresh_instance(tmp_path, monkeypatch):
+    """Two reads of the same file produce separate AppConfig instances —
+    no hidden singleton, no memoization. Callers decide when to re-read.
+    """
     config_path = tmp_path / "config.yaml"
     extensions_path = tmp_path / "extensions_config.json"
     _write_extensions_config(extensions_path)
@@ -58,14 +64,12 @@ def test_init_replaces_previous(tmp_path, monkeypatch):
     monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))

     config_a = AppConfig.from_file(str(config_path))
-    AppConfig.init(config_a)
-    assert AppConfig.current().models[0].name == "model-a"
+    assert config_a.models[0].name == "model-a"

     _write_config(config_path, model_name="model-b", supports_thinking=True)
     config_b = AppConfig.from_file(str(config_path))
-    AppConfig.init(config_b)
-    assert AppConfig.current().models[0].name == "model-b"
-    assert AppConfig.current() is config_b
+    assert config_b.models[0].name == "model-b"
+    assert config_a is not config_b


 def test_config_version_check(tmp_path, monkeypatch):
@@ -64,39 +64,42 @@ class TestGetCheckpointer:
         from langgraph.checkpoint.memory import InMemorySaver

         with patch.object(AppConfig, "current", return_value=_make_config()):
-            cp = get_checkpointer()
+            cp = get_checkpointer(AppConfig.current())
             assert cp is not None
             assert isinstance(cp, InMemorySaver)

-    def test_returns_in_memory_saver_when_config_not_found(self):
-        from langgraph.checkpoint.memory import InMemorySaver
+    def test_raises_when_config_file_missing(self):
+        """A missing config.yaml is a deployment error, not a cue to degrade to InMemorySaver.
+
+        Silent degradation would drop persistent-run state on process restart.
+        `get_checkpointer` only falls back to InMemorySaver for the explicit
+        `checkpointer: null` opt-in (test above), not for I/O failure.
+        """
         with patch.object(AppConfig, "current", side_effect=FileNotFoundError):
-            cp = get_checkpointer()
-            assert cp is not None
-            assert isinstance(cp, InMemorySaver)
+            with pytest.raises(FileNotFoundError):
+                get_checkpointer(AppConfig.current())

     def test_memory_returns_in_memory_saver(self):
         from langgraph.checkpoint.memory import InMemorySaver

         cfg = _make_config(CheckpointerConfig(type="memory"))
         with patch.object(AppConfig, "current", return_value=cfg):
-            cp = get_checkpointer()
+            cp = get_checkpointer(AppConfig.current())
             assert isinstance(cp, InMemorySaver)

     def test_memory_singleton(self):
         cfg = _make_config(CheckpointerConfig(type="memory"))
         with patch.object(AppConfig, "current", return_value=cfg):
-            cp1 = get_checkpointer()
-            cp2 = get_checkpointer()
+            cp1 = get_checkpointer(AppConfig.current())
+            cp2 = get_checkpointer(AppConfig.current())
             assert cp1 is cp2

     def test_reset_clears_singleton(self):
         cfg = _make_config(CheckpointerConfig(type="memory"))
         with patch.object(AppConfig, "current", return_value=cfg):
-            cp1 = get_checkpointer()
+            cp1 = get_checkpointer(AppConfig.current())
             reset_checkpointer()
-            cp2 = get_checkpointer()
+            cp2 = get_checkpointer(AppConfig.current())
             assert cp1 is not cp2

     def test_sqlite_raises_when_package_missing(self):
@@ -107,7 +110,7 @@ class TestGetCheckpointer:
         ):
             reset_checkpointer()
             with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
-                get_checkpointer()
+                get_checkpointer(AppConfig.current())
 
     def test_postgres_raises_when_package_missing(self):
         cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db"))
@@ -117,7 +120,7 @@ class TestGetCheckpointer:
         ):
             reset_checkpointer()
             with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
-                get_checkpointer()
+                get_checkpointer(AppConfig.current())
 
     def test_postgres_raises_when_connection_string_missing(self):
         cfg = _make_config(CheckpointerConfig(type="postgres"))
@@ -130,7 +133,7 @@ class TestGetCheckpointer:
         ):
             reset_checkpointer()
             with pytest.raises(ValueError, match="connection_string is required"):
-                get_checkpointer()
+                get_checkpointer(AppConfig.current())
 
     def test_sqlite_creates_saver(self):
         """SQLite checkpointer is created when package is available."""
@@ -152,7 +155,7 @@ class TestGetCheckpointer:
             patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
         ):
             reset_checkpointer()
-            cp = get_checkpointer()
+            cp = get_checkpointer(AppConfig.current())
 
             assert cp is mock_saver_instance
             mock_saver_cls.from_conn_string.assert_called_once()
@@ -178,7 +181,7 @@ class TestGetCheckpointer:
             patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}),
         ):
             reset_checkpointer()
-            cp = get_checkpointer()
+            cp = get_checkpointer(AppConfig.current())
 
             assert cp is mock_saver_instance
             mock_saver_cls.from_conn_string.assert_called_once_with("postgresql://localhost/db")
@@ -214,7 +217,7 @@ class TestAsyncCheckpointer:
                 return_value="/tmp/resolved/test.db",
             ),
         ):
-            async with make_checkpointer() as saver:
+            async with make_checkpointer(AppConfig.current()) as saver:
                 assert saver is mock_saver
 
             mock_to_thread.assert_awaited_once()
@@ -245,7 +248,7 @@ class TestAppConfigLoadsCheckpointer:
 
 class TestClientCheckpointerFallback:
     def test_client_uses_config_checkpointer_when_none_provided(self):
-        """DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
+        """DeerFlowClient._ensure_agent falls back to get_checkpointer(AppConfig.current()) when checkpointer=None."""
         # This is a structural test — verifying the fallback path exists.
         cfg = _make_config(CheckpointerConfig(type="memory"))
         assert cfg.checkpointer is not None
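The call-site change running through the checkpointer hunks above can be sketched in isolation. This is an illustrative stand-in, not deer-flow code: `AppConfig`, `CheckpointerConfig`, and the singleton body here are hypothetical minimal analogues; only the explicit-parameter flow (config passed in, never looked up ambiently) mirrors the refactor.

```python
from dataclasses import dataclass, field


# Hypothetical stand-ins for deer-flow's config models.
@dataclass(frozen=True)
class CheckpointerConfig:
    type: str = "memory"


@dataclass(frozen=True)
class AppConfig:
    checkpointer: CheckpointerConfig = field(default_factory=CheckpointerConfig)


_checkpointer = None


def get_checkpointer(app_config: AppConfig):
    """Build (or reuse) the checkpointer from an explicitly passed config."""
    global _checkpointer
    if _checkpointer is None:
        if app_config.checkpointer.type == "memory":
            _checkpointer = object()  # stands in for InMemorySaver
        else:
            raise ValueError(f"unsupported type: {app_config.checkpointer.type}")
    return _checkpointer


def reset_checkpointer():
    """Drop the cached instance so the next call rebuilds it."""
    global _checkpointer
    _checkpointer = None
```

The singleton still exists, but its input is now a parameter, which is why the tests above can patch `AppConfig.current` and hand the result in explicitly.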
@@ -22,7 +22,7 @@ class TestCheckpointerNoneFix:
         mock_config.database = None
 
         with patch.object(AppConfig, "current", return_value=mock_config):
-            async with make_checkpointer() as checkpointer:
+            async with make_checkpointer(AppConfig.current()) as checkpointer:
                 # Should return InMemorySaver, not None
                 assert checkpointer is not None
                 assert isinstance(checkpointer, InMemorySaver)
@@ -45,7 +45,7 @@ class TestCheckpointerNoneFix:
         mock_config.checkpointer = None
 
         with patch.object(AppConfig, "current", return_value=mock_config):
-            with checkpointer_context() as checkpointer:
+            with checkpointer_context(AppConfig.current()) as checkpointer:
                 # Should return InMemorySaver, not None
                 assert checkpointer is not None
                 assert isinstance(checkpointer, InMemorySaver)
@@ -44,9 +44,12 @@ def mock_app_config():
 
 @pytest.fixture
 def client(mock_app_config):
-    """Create a DeerFlowClient with mocked config loading."""
-    with patch.object(AppConfig, "current", return_value=mock_app_config):
-        return DeerFlowClient()
+    """Create a DeerFlowClient holding the mocked config directly.
+
+    Passing ``config=`` is the documented post-refactor way to inject a
+    test AppConfig; nothing relies on process-global state.
+    """
+    return DeerFlowClient(config=mock_app_config)
 
 
 # ---------------------------------------------------------------------------
@@ -124,7 +127,7 @@ class TestConfigQueries:
 
         with patch("deerflow.skills.loader.load_skills", return_value=[skill]) as mock_load:
            result = client.list_skills()
-            mock_load.assert_called_once_with(enabled_only=False)
+            mock_load.assert_called_once_with(client._app_config, enabled_only=False)
 
         assert "skills" in result
         assert len(result["skills"]) == 1
@@ -139,7 +142,7 @@ class TestConfigQueries:
     def test_list_skills_enabled_only(self, client):
         with patch("deerflow.skills.loader.load_skills", return_value=[]) as mock_load:
             client.list_skills(enabled_only=True)
-            mock_load.assert_called_once_with(enabled_only=True)
+            mock_load.assert_called_once_with(client._app_config, enabled_only=True)
 
     def test_get_memory(self, client):
         memory = {"version": "1.0", "facts": []}
@@ -1244,7 +1247,7 @@ class TestMemoryManagement:
 
         assert mock_import.call_count == 1
         call_args = mock_import.call_args
-        assert call_args.args == (imported,)
+        assert call_args.args == (client._app_config.memory, imported)
         assert "user_id" in call_args.kwargs
         assert result == imported
 
@@ -1269,6 +1272,7 @@ class TestMemoryManagement:
             confidence=0.88,
         )
         create_fact.assert_called_once_with(
+            client._app_config.memory,
            content="User prefers concise code reviews.",
            category="preference",
            confidence=0.88,
@@ -1279,7 +1283,7 @@ class TestMemoryManagement:
         data = {"version": "1.0", "facts": []}
         with patch("deerflow.agents.memory.updater.delete_memory_fact", return_value=data) as delete_fact:
             result = client.delete_memory_fact("fact_123")
-            delete_fact.assert_called_once_with("fact_123")
+            delete_fact.assert_called_once_with(client._app_config.memory, "fact_123")
         assert result == data
 
     def test_update_memory_fact(self, client):
@@ -1292,6 +1296,7 @@ class TestMemoryManagement:
             confidence=0.91,
         )
         update_fact.assert_called_once_with(
+            client._app_config.memory,
             fact_id="fact_123",
             content="User prefers spaces",
             category="workflow",
@@ -1307,6 +1312,7 @@ class TestMemoryManagement:
             "User prefers spaces",
         )
         update_fact.assert_called_once_with(
+            client._app_config.memory,
             fact_id="fact_123",
             content="User prefers spaces",
             category=None,
@@ -1802,7 +1808,7 @@ class TestScenarioConfigManagement:
         reloaded_config.mcp_servers = {"my-mcp": reloaded_server}
 
         client._agent = MagicMock()  # Simulate existing agent
-        AppConfig.init(MagicMock(extensions=current_config))
+        client._app_config = MagicMock(extensions=current_config)
         with (
             patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
             patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)),
@@ -2220,8 +2226,7 @@ class TestGatewayConformance:
         model.supports_thinking = False
         mock_app_config.models = [model]
 
-        with patch.object(AppConfig, "current", return_value=mock_app_config):
-            client = DeerFlowClient()
+        client = DeerFlowClient(config=mock_app_config)
 
         result = client.list_models()
         parsed = ModelsListResponse(**result)
@@ -2239,8 +2244,7 @@ class TestGatewayConformance:
         mock_app_config.models = [model]
         mock_app_config.get_model_config.return_value = model
 
-        with patch.object(AppConfig, "current", return_value=mock_app_config):
-            client = DeerFlowClient()
+        client = DeerFlowClient(config=mock_app_config)
 
         result = client.get_model("test-model")
         assert result is not None
@@ -3076,7 +3080,7 @@ class TestBugAgentInvalidationInconsistency:
         config_file = Path(tmp) / "ext.json"
         config_file.write_text("{}")
 
-        AppConfig.init(MagicMock(extensions=current_config))
+        client._app_config = MagicMock(extensions=current_config)
         with (
             patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
             patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded)),
@@ -71,3 +71,25 @@ def test_config_model_rejects_mutation(model_cls: type[BaseModel]):
 
     with pytest.raises(ValidationError):
         setattr(instance, first_field, "MUTATED")
+
+
+def test_extensions_nested_dict_mutation_is_not_blocked_by_pydantic():
+    """Regression guard: Pydantic `frozen=True` does NOT deep-freeze container fields.
+
+    This test documents the trap — callers MUST compose a new dict and persist
+    it + reload AppConfig instead of reaching into `extensions.skills[x]`.
+    If you need the dict to be truly immutable, wrap with Mapping/frozendict.
+    """
+    from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
+
+    ext = ExtensionsConfig(mcp_servers={}, skills={"a": SkillStateConfig(enabled=True)})
+
+    # This is the pre-refactor anti-pattern: Pydantic lets it through because
+    # the outer model is frozen but the inner dict is a plain builtin. No error.
+    ext.skills["a"] = SkillStateConfig(enabled=False)
+    ext.skills["b"] = SkillStateConfig(enabled=True)
+
+    # The test asserts the leak exists so a future "add deep-freeze" change
+    # flips this expectation and forces call-site review.
+    assert ext.skills["a"].enabled is False
+    assert "b" in ext.skills
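The shallow-freeze trap the new test pins down can be reproduced with a stdlib frozen dataclass as well. This is an illustrative analogue, not deer-flow code: `Extensions` here is a hypothetical stand-in, but the mechanism is the same — freezing the outer object only blocks attribute rebinding, not mutation of a mutable container stored inside it.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Extensions:
    # The dict itself stays a plain mutable builtin.
    skills: dict = field(default_factory=dict)


ext = Extensions(skills={"a": True})

# No error: only rebinding `ext.skills` is blocked, not mutating its contents.
ext.skills["a"] = False
ext.skills["b"] = True
```

Rebinding the attribute (`ext.skills = {}`) does raise, which is exactly the asymmetry the regression guard documents.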
@@ -10,6 +10,9 @@ import yaml
 from fastapi.testclient import TestClient
 
 from deerflow.config.app_config import AppConfig
+from deerflow.config.memory_config import MemoryConfig
+
+_TEST_MEMORY_CONFIG = MemoryConfig()
 
 # ---------------------------------------------------------------------------
 # Helpers
@@ -335,7 +338,7 @@ class TestMemoryFilePath:
             patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
             patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
         ):
-            storage = FileMemoryStorage()
+            storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
             path = storage._get_memory_file_path(None)
             assert path == tmp_path / "memory.json"
 
@@ -348,7 +351,7 @@ class TestMemoryFilePath:
             patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
             patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
         ):
-            storage = FileMemoryStorage()
+            storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
             path = storage._get_memory_file_path("code-reviewer")
             assert path == tmp_path / "agents" / "code-reviewer" / "memory.json"
 
@@ -360,7 +363,7 @@ class TestMemoryFilePath:
             patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
             patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))),
         ):
-            storage = FileMemoryStorage()
+            storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
             path_global = storage._get_memory_file_path(None)
             path_a = storage._get_memory_file_path("agent-a")
             path_b = storage._get_memory_file_path("agent-b")
@@ -47,40 +47,16 @@ class TestResolveContext:
         runtime.context = ctx
         assert resolve_context(runtime) is ctx
 
-    def test_fallback_from_configurable(self):
-        """LangGraph Server path: runtime.context is None → construct from ContextVar + configurable."""
+    def test_raises_on_none_context(self):
+        """Without a typed DeerFlowContext, resolve_context refuses to guess."""
         runtime = MagicMock()
         runtime.context = None
-        config = _make_config()
-        with (
-            patch.object(AppConfig, "current", return_value=config),
-            patch("langgraph.config.get_config", return_value={"configurable": {"thread_id": "t2", "agent_name": "ag"}}),
-        ):
-            ctx = resolve_context(runtime)
-            assert ctx.thread_id == "t2"
-            assert ctx.agent_name == "ag"
-            assert ctx.app_config is config
+        with pytest.raises(RuntimeError, match="resolve_context: runtime.context is not a DeerFlowContext"):
+            resolve_context(runtime)
 
-    def test_fallback_empty_configurable(self):
-        """LangGraph Server path with no thread_id in configurable → empty string."""
-        runtime = MagicMock()
-        runtime.context = None
-        config = _make_config()
-        with (
-            patch.object(AppConfig, "current", return_value=config),
-            patch("langgraph.config.get_config", return_value={"configurable": {}}),
-        ):
-            ctx = resolve_context(runtime)
-            assert ctx.thread_id == ""
-            assert ctx.agent_name is None
-
-    def test_fallback_from_dict_context(self):
-        """Legacy path: runtime.context is a dict → extract from dict directly."""
+    def test_raises_on_dict_context(self):
+        """Legacy dict shape is no longer supported — we raise instead of lazily loading AppConfig."""
         runtime = MagicMock()
         runtime.context = {"thread_id": "old-dict", "agent_name": "from-dict"}
-        config = _make_config()
-        with patch.object(AppConfig, "current", return_value=config):
-            ctx = resolve_context(runtime)
-            assert ctx.thread_id == "old-dict"
-            assert ctx.agent_name == "from-dict"
-            assert ctx.app_config is config
+        with pytest.raises(RuntimeError, match="resolve_context: runtime.context is not a DeerFlowContext"):
+            resolve_context(runtime)
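The strict resolution rule these tests pin down can be sketched stand-alone. This is a minimal sketch under assumptions, not the real implementation: `DeerFlowContext` is a stand-in with the fields visible in the diff (`app_config`, `thread_id`, `agent_name`), and only the reject-loudly behavior is taken from the commit.

```python
from dataclasses import dataclass
from typing import Any, Optional


# Stand-in for deer-flow's typed context; fields assumed from the diff.
@dataclass(frozen=True)
class DeerFlowContext:
    app_config: Any
    thread_id: str
    agent_name: Optional[str] = None


def resolve_context(runtime: Any) -> DeerFlowContext:
    """Return runtime.context only if it is a typed DeerFlowContext."""
    ctx = getattr(runtime, "context", None)
    if not isinstance(ctx, DeerFlowContext):
        # dict/None shapes are rejected loudly instead of being papered
        # over with an ambient AppConfig lookup.
        raise RuntimeError(
            "resolve_context: runtime.context is not a DeerFlowContext "
            f"(got {type(ctx).__name__})"
        )
    return ctx
```

Every entry point is expected to attach the typed context up front, so the only recovery from this error is fixing the caller, not adding a fallback.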
@@ -7,6 +7,31 @@ import pytest
 
 from deerflow.config.app_config import AppConfig
 
+# --- Phase 2 test helper: injected runtime for community tools ---
+from types import SimpleNamespace as _P2NS
+from deerflow.config.app_config import AppConfig as _P2AppConfig
+from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
+from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
+
+_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
+_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
+
+
+def _runtime_with_config(config):
+    """Build a runtime carrying a custom (possibly mocked) app_config.
+
+    ``DeerFlowContext`` is a frozen dataclass typed as ``AppConfig`` but
+    dataclasses don't enforce the type at runtime — handing a Mock through
+    lets tests exercise the tool's ``get_tool_config`` lookup without going
+    via ``AppConfig.current``.
+    """
+    ctx = _P2Ctx.__new__(_P2Ctx)
+    object.__setattr__(ctx, "app_config", config)
+    object.__setattr__(ctx, "thread_id", "test-thread")
+    object.__setattr__(ctx, "agent_name", None)
+    return _P2NS(context=ctx)
+# -------------------------------------------------------------------
+
 
 @pytest.fixture
 def mock_app_config():
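The `_runtime_with_config` helper added above relies on a frozen-dataclass bypass worth seeing in isolation: `__new__` skips `__init__`, and `object.__setattr__` sidesteps the frozen guard. A self-contained sketch, with `DeerFlowContext` replaced by a hypothetical stand-in carrying the same fields:

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any


# Stand-in for the frozen DeerFlowContext; fields assumed from the diff.
@dataclass(frozen=True)
class DeerFlowContext:
    app_config: Any
    thread_id: str
    agent_name: Any = None


def runtime_with_config(config: Any) -> SimpleNamespace:
    # __new__ allocates the instance without running __init__, so no
    # type-shaped argument is ever required for the Mock we smuggle in.
    ctx = DeerFlowContext.__new__(DeerFlowContext)
    # object.__setattr__ bypasses the frozen dataclass's __setattr__ guard.
    object.__setattr__(ctx, "app_config", config)
    object.__setattr__(ctx, "thread_id", "test-thread")
    object.__setattr__(ctx, "agent_name", None)
    return SimpleNamespace(context=ctx)
```

After construction the instance behaves like any frozen dataclass: normal attribute assignment still raises `FrozenInstanceError`, so the escape hatch is confined to the builder.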
@@ -51,7 +76,7 @@ class TestWebSearchTool:
 
         from deerflow.community.exa.tools import web_search_tool
 
-        result = web_search_tool.invoke({"query": "test query"})
+        result = web_search_tool.func(query="test query", runtime=_P2_RUNTIME)
         parsed = json.loads(result)
 
         assert len(parsed) == 2
@@ -69,30 +94,30 @@ class TestWebSearchTool:
 
     def test_search_with_custom_config(self, mock_exa_client):
         """Test search respects custom configuration values."""
-        with patch.object(AppConfig, "current") as mock_config:
-            tool_config = MagicMock()
-            tool_config.model_extra = {
-                "max_results": 10,
-                "search_type": "neural",
-                "contents_max_characters": 2000,
-                "api_key": "test-key",
-            }
-            mock_config.return_value.get_tool_config.return_value = tool_config
+        tool_config = MagicMock()
+        tool_config.model_extra = {
+            "max_results": 10,
+            "search_type": "neural",
+            "contents_max_characters": 2000,
+            "api_key": "test-key",
+        }
+        fake_config = MagicMock()
+        fake_config.get_tool_config.return_value = tool_config
 
-            mock_response = MagicMock()
-            mock_response.results = []
-            mock_exa_client.search.return_value = mock_response
+        mock_response = MagicMock()
+        mock_response.results = []
+        mock_exa_client.search.return_value = mock_response
 
-            from deerflow.community.exa.tools import web_search_tool
+        from deerflow.community.exa.tools import web_search_tool
 
-            web_search_tool.invoke({"query": "neural search"})
+        web_search_tool.func(query="neural search", runtime=_runtime_with_config(fake_config))
 
-            mock_exa_client.search.assert_called_once_with(
-                "neural search",
-                type="neural",
-                num_results=10,
-                contents={"highlights": {"max_characters": 2000}},
-            )
+        mock_exa_client.search.assert_called_once_with(
+            "neural search",
+            type="neural",
+            num_results=10,
+            contents={"highlights": {"max_characters": 2000}},
+        )
 
     def test_search_with_no_highlights(self, mock_app_config, mock_exa_client):
         """Test search handles results with no highlights."""
@@ -107,7 +132,7 @@ class TestWebSearchTool:
 
         from deerflow.community.exa.tools import web_search_tool
 
-        result = web_search_tool.invoke({"query": "test"})
+        result = web_search_tool.func(query="test", runtime=_P2_RUNTIME)
         parsed = json.loads(result)
 
         assert parsed[0]["snippet"] == ""
@@ -120,7 +145,7 @@ class TestWebSearchTool:
 
         from deerflow.community.exa.tools import web_search_tool
 
-        result = web_search_tool.invoke({"query": "nothing"})
+        result = web_search_tool.func(query="nothing", runtime=_P2_RUNTIME)
         parsed = json.loads(result)
 
         assert parsed == []
@@ -131,7 +156,7 @@ class TestWebSearchTool:
 
         from deerflow.community.exa.tools import web_search_tool
 
-        result = web_search_tool.invoke({"query": "error"})
+        result = web_search_tool.func(query="error", runtime=_P2_RUNTIME)
 
         assert result == "Error: API rate limit exceeded"
 
@@ -149,7 +174,7 @@ class TestWebFetchTool:
 
         from deerflow.community.exa.tools import web_fetch_tool
 
-        result = web_fetch_tool.invoke({"url": "https://example.com"})
+        result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
 
         assert result == "# Fetched Page\n\nThis is the page content."
         mock_exa_client.get_contents.assert_called_once_with(
@@ -169,7 +194,7 @@ class TestWebFetchTool:
 
         from deerflow.community.exa.tools import web_fetch_tool
 
-        result = web_fetch_tool.invoke({"url": "https://example.com"})
+        result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
 
         assert result.startswith("# Untitled\n\n")
 
@@ -181,7 +206,7 @@ class TestWebFetchTool:
 
         from deerflow.community.exa.tools import web_fetch_tool
 
-        result = web_fetch_tool.invoke({"url": "https://example.com/404"})
+        result = web_fetch_tool.func(url="https://example.com/404", runtime=_P2_RUNTIME)
 
         assert result == "Error: No results found"
 
@@ -191,16 +216,44 @@ class TestWebFetchTool:
 
         from deerflow.community.exa.tools import web_fetch_tool
 
-        result = web_fetch_tool.invoke({"url": "https://example.com"})
+        result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
 
         assert result == "Error: Connection timeout"
 
     def test_fetch_reads_web_fetch_config(self, mock_exa_client):
         """Test that web_fetch_tool reads 'web_fetch' config, not 'web_search'."""
-        with patch.object(AppConfig, "current") as mock_config:
-            tool_config = MagicMock()
-            tool_config.model_extra = {"api_key": "exa-fetch-key"}
-            mock_config.return_value.get_tool_config.return_value = tool_config
+        tool_config = MagicMock()
+        tool_config.model_extra = {"api_key": "exa-fetch-key"}
+        fake_config = MagicMock()
+        fake_config.get_tool_config.return_value = tool_config
 
+        mock_result = MagicMock()
+        mock_result.title = "Page"
+        mock_result.text = "Content."
+        mock_response = MagicMock()
+        mock_response.results = [mock_result]
+        mock_exa_client.get_contents.return_value = mock_response
+
+        from deerflow.community.exa.tools import web_fetch_tool
+
+        web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
+
+        fake_config.get_tool_config.assert_any_call("web_fetch")
+
+    def test_fetch_uses_independent_api_key(self, mock_exa_client):
+        """Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
+        with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
+            mock_exa_cls.return_value = mock_exa_client
+            fetch_config = MagicMock()
+            fetch_config.model_extra = {"api_key": "exa-fetch-key"}
+
+            def get_tool_config(name):
+                if name == "web_fetch":
+                    return fetch_config
+                return None
+
+            fake_config = MagicMock()
+            fake_config.get_tool_config.side_effect = get_tool_config
+
+            mock_result = MagicMock()
+            mock_result.title = "Page"
@@ -211,37 +264,9 @@ class TestWebFetchTool:
 
         from deerflow.community.exa.tools import web_fetch_tool
 
-            web_fetch_tool.invoke({"url": "https://example.com"})
+        web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
 
-            mock_config.return_value.get_tool_config.assert_any_call("web_fetch")
-
-    def test_fetch_uses_independent_api_key(self, mock_exa_client):
-        """Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
-        with patch.object(AppConfig, "current") as mock_config:
-            with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
-                mock_exa_cls.return_value = mock_exa_client
-                fetch_config = MagicMock()
-                fetch_config.model_extra = {"api_key": "exa-fetch-key"}
-
-                def get_tool_config(name):
-                    if name == "web_fetch":
-                        return fetch_config
-                    return None
-
-                mock_config.return_value.get_tool_config.side_effect = get_tool_config
-
-                mock_result = MagicMock()
-                mock_result.title = "Page"
-                mock_result.text = "Content."
-                mock_response = MagicMock()
-                mock_response.results = [mock_result]
-                mock_exa_client.get_contents.return_value = mock_response
-
-                from deerflow.community.exa.tools import web_fetch_tool
-
-                web_fetch_tool.invoke({"url": "https://example.com"})
-
-                mock_exa_cls.assert_called_once_with(api_key="exa-fetch-key")
+            mock_exa_cls.assert_called_once_with(api_key="exa-fetch-key")
 
     def test_fetch_truncates_long_content(self, mock_app_config, mock_exa_client):
         """Test fetch truncates content to 4096 characters."""
@@ -255,7 +280,7 @@ class TestWebFetchTool:
 
         from deerflow.community.exa.tools import web_fetch_tool
 
-        result = web_fetch_tool.invoke({"url": "https://example.com"})
+        result = web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
 
         # "# Long Page\n\n" is 14 chars, content truncated to 4096
         content_after_header = result.split("\n\n", 1)[1]
@@ -3,16 +3,31 @@
 import json
 from unittest.mock import MagicMock, patch
 
-from deerflow.config.app_config import AppConfig
+from types import SimpleNamespace as _P2NS
+
+from deerflow.config.app_config import AppConfig as _P2AppConfig
+from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
+from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
+
+_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
+_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
+
+
+def _runtime_with_config(config):
+    ctx = _P2Ctx.__new__(_P2Ctx)
+    object.__setattr__(ctx, "app_config", config)
+    object.__setattr__(ctx, "thread_id", "test-thread")
+    object.__setattr__(ctx, "agent_name", None)
+    return _P2NS(context=ctx)
+
 
 class TestWebSearchTool:
     @patch("deerflow.community.firecrawl.tools.FirecrawlApp")
-    @patch.object(AppConfig, "current")
-    def test_search_uses_web_search_config(self, mock_get_app_config, mock_firecrawl_cls):
+    def test_search_uses_web_search_config(self, mock_firecrawl_cls):
         search_config = MagicMock()
         search_config.model_extra = {"api_key": "firecrawl-search-key", "max_results": 7}
-        mock_get_app_config.return_value.get_tool_config.return_value = search_config
+        fake_config = MagicMock()
+        fake_config.get_tool_config.return_value = search_config
 
         mock_result = MagicMock()
         mock_result.web = [
@@ -22,7 +37,7 @@ class TestWebSearchTool:
 
         from deerflow.community.firecrawl.tools import web_search_tool
 
-        result = web_search_tool.invoke({"query": "test query"})
+        result = web_search_tool.func(query="test query", runtime=_runtime_with_config(fake_config))
 
         assert json.loads(result) == [
             {
@@ -31,15 +46,14 @@ class TestWebSearchTool:
                 "snippet": "Snippet",
             }
         ]
-        mock_get_app_config.return_value.get_tool_config.assert_called_with("web_search")
+        fake_config.get_tool_config.assert_called_with("web_search")
         mock_firecrawl_cls.assert_called_once_with(api_key="firecrawl-search-key")
         mock_firecrawl_cls.return_value.search.assert_called_once_with("test query", limit=7)
 
 
 class TestWebFetchTool:
     @patch("deerflow.community.firecrawl.tools.FirecrawlApp")
-    @patch.object(AppConfig, "current")
-    def test_fetch_uses_web_fetch_config(self, mock_get_app_config, mock_firecrawl_cls):
+    def test_fetch_uses_web_fetch_config(self, mock_firecrawl_cls):
         fetch_config = MagicMock()
         fetch_config.model_extra = {"api_key": "firecrawl-fetch-key"}
 
@@ -48,7 +62,8 @@ class TestWebFetchTool:
                 return fetch_config
             return None
 
-        mock_get_app_config.return_value.get_tool_config.side_effect = get_tool_config
+        fake_config = MagicMock()
+        fake_config.get_tool_config.side_effect = get_tool_config
 
         mock_scrape_result = MagicMock()
         mock_scrape_result.markdown = "Fetched markdown"
@@ -57,10 +72,10 @@ class TestWebFetchTool:
 
         from deerflow.community.firecrawl.tools import web_fetch_tool
 
-        result = web_fetch_tool.invoke({"url": "https://example.com"})
+        result = web_fetch_tool.func(url="https://example.com", runtime=_runtime_with_config(fake_config))
 
         assert result == "# Fetched Page\n\nFetched markdown"
-        mock_get_app_config.return_value.get_tool_config.assert_any_call("web_fetch")
+        fake_config.get_tool_config.assert_any_call("web_fetch")
         mock_firecrawl_cls.assert_called_once_with(api_key="firecrawl-fetch-key")
         mock_firecrawl_cls.return_value.scrape.assert_called_once_with(
             "https://example.com",
@@ -7,6 +7,16 @@ from deerflow.community.infoquest import tools
 from deerflow.community.infoquest.infoquest_client import InfoQuestClient
 from deerflow.config.app_config import AppConfig
 
+# --- Phase 2 test helper: injected runtime for community tools ---
+from types import SimpleNamespace as _P2NS
+from deerflow.config.app_config import AppConfig as _P2AppConfig
+from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
+from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
+
+_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
+_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
+# -------------------------------------------------------------------
+
 
 class TestInfoQuestClient:
     def test_infoquest_client_initialization(self):
@@ -131,7 +141,7 @@ class TestInfoQuestClient:
         mock_client.web_search.return_value = json.dumps([])
         mock_get_client.return_value = mock_client
 
-        result = tools.web_search_tool.run("test query")
+        result = tools.web_search_tool.func(query="test query", runtime=_P2_RUNTIME)
 
         assert result == json.dumps([])
         mock_get_client.assert_called_once()
@@ -144,7 +154,7 @@ class TestInfoQuestClient:
         mock_client.fetch.return_value = "<html><body>Test content</body></html>"
         mock_get_client.return_value = mock_client
 
-        result = tools.web_fetch_tool.run("https://example.com")
+        result = tools.web_fetch_tool.func(url="https://example.com", runtime=_P2_RUNTIME)
 
         assert result == "# Untitled\n\nTest content"
         mock_get_client.assert_called_once()
@@ -162,7 +172,7 @@ class TestInfoQuestClient:
         ]
         mock_get.return_value = mock_config
 
-        client = tools._get_infoquest_client()
+        client = tools._get_infoquest_client(mock_config)
 
         assert client.search_time_range == 24
         assert client.fetch_time == 10
@@ -322,7 +332,7 @@ class TestImageSearch:
         mock_client.image_search.return_value = json.dumps([{"image_url": "https://example.com/image1.jpg"}])
         mock_get_client.return_value = mock_client
 
-        result = tools.image_search_tool.run({"query": "test query"})
+        result = tools.image_search_tool.func(query="test query", runtime=_P2_RUNTIME)
 
         # Check if result is a valid JSON string
         result_data = json.loads(result)
@@ -341,7 +351,7 @@ class TestImageSearch:
         mock_get_client.return_value = mock_client
 
         # Pass all parameters as a dictionary (extra parameters will be ignored)
-        tools.image_search_tool.run({"query": "sunset", "time_range": 30, "site": "unsplash.com", "image_size": "l"})
+        tools.image_search_tool.func(query="sunset", runtime=_P2_RUNTIME)
 
         mock_get_client.assert_called_once()
         # image_search_tool only passes query to client.image_search
 
@@ -685,5 +685,5 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo
         classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})),
     )
 
-    tools = get_available_tools(include_mcp=True, subagent_enabled=False)
+    tools = get_available_tools(include_mcp=True, subagent_enabled=False, app_config=AppConfig.current())
     assert "invoke_acp_agent" in [tool.name for tool in tools]
@ -11,6 +11,16 @@ from deerflow.community.jina_ai.jina_client import JinaClient
|
||||
from deerflow.community.jina_ai.tools import web_fetch_tool
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
# --- Phase 2 test helper: injected runtime for community tools ---
|
||||
from types import SimpleNamespace as _P2NS
|
||||
from deerflow.config.app_config import AppConfig as _P2AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig as _P2SandboxConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext as _P2Ctx
|
||||
_P2_APP_CONFIG = _P2AppConfig(sandbox=_P2SandboxConfig(use="test"))
|
||||
_P2_RUNTIME = _P2NS(context=_P2Ctx(app_config=_P2_APP_CONFIG, thread_id="test-thread"))
|
||||
# -------------------------------------------------------------------
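The `_P2_RUNTIME` helper above mirrors what every entry point now does: attach a typed context carrying the `AppConfig`, which the tool reads via `runtime.context.app_config`. A minimal sketch of the strict resolution behavior (the `SimpleNamespace` stand-ins and the `resolve_context` body here are illustrative assumptions, not deer-flow's actual implementation):

```python
from types import SimpleNamespace

# Illustrative stand-ins (assumed shapes, not deerflow's real classes).
app_config = SimpleNamespace(sandbox=SimpleNamespace(use="test"))
runtime = SimpleNamespace(context=SimpleNamespace(app_config=app_config, thread_id="test-thread"))

def resolve_context(runtime):
    # Phase 2 behavior: reject a None or dict-shaped context loudly instead
    # of papering over it with an ambient AppConfig lookup.
    ctx = getattr(runtime, "context", None)
    if ctx is None or not hasattr(ctx, "app_config"):
        raise TypeError("runtime.context must be a typed DeerFlowContext")
    return ctx

assert resolve_context(runtime).thread_id == "test-thread"
```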



@pytest.fixture
def jina_client():

@@ -157,7 +167,7 @@ async def test_web_fetch_tool_returns_error_on_crawl_failure(monkeypatch):
    mock_config.get_tool_config.return_value = None
    monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: mock_config))
    monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
-    result = await web_fetch_tool.ainvoke("https://example.com")
+    result = await web_fetch_tool.coroutine(url="https://example.com", runtime=_P2_RUNTIME)
    assert result.startswith("Error:")
    assert "429" in result

@@ -173,6 +183,6 @@ async def test_web_fetch_tool_returns_markdown_on_success(monkeypatch):
    mock_config.get_tool_config.return_value = None
    monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: mock_config))
    monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
-    result = await web_fetch_tool.ainvoke("https://example.com")
+    result = await web_fetch_tool.coroutine(url="https://example.com", runtime=_P2_RUNTIME)
    assert "Hello world" in result
    assert not result.startswith("Error:")

@@ -98,7 +98,8 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
                    "is_plan_mode": False,
                    "subagent_enabled": False,
                }
            }
        },
+        app_config=app_config,
    )

    assert captured["name"] == "safe-model"

@@ -122,7 +123,6 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
        ]
    )

    AppConfig.init(app_config)
-    monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
    monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda _ac: None)
    monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)

@@ -137,7 +137,6 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
    app_config = _make_app_config([_make_model("default", supports_thinking=False)])
    patched = app_config.model_copy(update={"summarization": SummarizationConfig(enabled=True, model_name="model-masswork")})
    AppConfig.init(patched)
-    monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: patched))

    from unittest.mock import MagicMock

@@ -8,22 +8,19 @@ from deerflow.config.app_config import AppConfig
from deerflow.skills.types import Skill


-def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
+def test_build_custom_mounts_section_returns_empty_when_no_mounts():
    config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
-    monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))

-    assert prompt_module._build_custom_mounts_section() == ""
+    assert prompt_module._build_custom_mounts_section(config) == ""


-def test_build_custom_mounts_section_lists_configured_mounts(monkeypatch):
+def test_build_custom_mounts_section_lists_configured_mounts():
    mounts = [
        SimpleNamespace(container_path="/home/user/shared", read_only=False),
        SimpleNamespace(container_path="/mnt/reference", read_only=True),
    ]
    config = SimpleNamespace(sandbox=SimpleNamespace(mounts=mounts))
-    monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))

-    section = prompt_module._build_custom_mounts_section()
+    section = prompt_module._build_custom_mounts_section(config)

    assert "**Custom Mounted Directories:**" in section
    assert "`/home/user/shared`" in section

@@ -37,15 +34,15 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
    config = SimpleNamespace(
        sandbox=SimpleNamespace(mounts=mounts),
        skills=SimpleNamespace(container_path="/mnt/skills"),
        skill_evolution=SimpleNamespace(enabled=False),
    )
-    monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
-    monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
-    monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
-    monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
-    monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
+    monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda *a, **k: [])
+    monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda app_config: "")
+    monkeypatch.setattr(prompt_module, "_build_acp_section", lambda app_config: "")
+    monkeypatch.setattr(prompt_module, "_get_memory_context", lambda app_config, agent_name=None: "")
    monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")

-    prompt = prompt_module.apply_prompt_template()
+    prompt = prompt_module.apply_prompt_template(config)

    assert "`/home/user/shared`" in prompt
    assert "Custom Mounted Directories" in prompt

@@ -55,15 +52,15 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
    config = SimpleNamespace(
        sandbox=SimpleNamespace(mounts=[]),
        skills=SimpleNamespace(container_path="/mnt/skills"),
        skill_evolution=SimpleNamespace(enabled=False),
    )
-    monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
-    monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
-    monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
-    monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
-    monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
+    monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda *a, **k: [])
+    monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda app_config: "")
+    monkeypatch.setattr(prompt_module, "_build_acp_section", lambda app_config: "")
+    monkeypatch.setattr(prompt_module, "_get_memory_context", lambda app_config, agent_name=None: "")
    monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")

-    prompt = prompt_module.apply_prompt_template()
+    prompt = prompt_module.apply_prompt_template(config)

    assert "Treat `/mnt/user-data/workspace` as your default current working directory" in prompt
    assert "`hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`" in prompt

@@ -84,7 +81,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
    )

    state = {"skills": [make_skill("first-skill")]}
-    monkeypatch.setattr(prompt_module, "load_skills", lambda enabled_only=True: list(state["skills"]))
+    monkeypatch.setattr(prompt_module, "load_skills", lambda *a, **kwargs: list(state["skills"]))
    prompt_module._reset_skills_system_prompt_cache_state()

    try:

@@ -120,7 +117,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
        enabled=True,
    )

-    def fake_load_skills(enabled_only=True):
+    def fake_load_skills(*a, **kwargs):
        nonlocal active_loads, max_active_loads, call_count
        with lock:
            active_loads += 1

@@ -157,7 +154,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa

def test_warm_enabled_skills_cache_logs_on_timeout(monkeypatch, caplog):
    event = threading.Event()
-    monkeypatch.setattr(prompt_module, "_ensure_enabled_skills_cache", lambda: event)
+    monkeypatch.setattr(prompt_module, "_ensure_enabled_skills_cache", lambda *a, **k: event)

    with caplog.at_level("WARNING"):
        warmed = prompt_module.warm_enabled_skills_cache(timeout_seconds=0.01)

@@ -20,27 +20,40 @@ def _make_skill(name: str) -> Skill:
    )


+_DEFAULT_SKILLS_CONFIG = SimpleNamespace(
+    skills=SimpleNamespace(container_path="/mnt/skills"),
+    skill_evolution=SimpleNamespace(enabled=False),
+)


+def _evolution_enabled_config() -> SimpleNamespace:
+    return SimpleNamespace(
+        skills=SimpleNamespace(container_path="/mnt/skills"),
+        skill_evolution=SimpleNamespace(enabled=True),
+    )


def test_get_skills_prompt_section_returns_empty_when_no_skills_match(monkeypatch):
    skills = [_make_skill("skill1"), _make_skill("skill2")]
-    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
+    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)

-    result = get_skills_prompt_section(available_skills={"non_existent_skill"})
+    result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills={"non_existent_skill"})
    assert result == ""


def test_get_skills_prompt_section_returns_empty_when_available_skills_empty(monkeypatch):
    skills = [_make_skill("skill1"), _make_skill("skill2")]
-    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
+    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)

-    result = get_skills_prompt_section(available_skills=set())
+    result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills=set())
    assert result == ""


def test_get_skills_prompt_section_returns_skills(monkeypatch):
    skills = [_make_skill("skill1"), _make_skill("skill2")]
-    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
+    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)

-    result = get_skills_prompt_section(available_skills={"skill1"})
+    result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills={"skill1"})
    assert "skill1" in result
    assert "skill2" not in result
    assert "[built-in]" in result

@@ -48,56 +61,41 @@ def test_get_skills_prompt_section_returns_skills(monkeypatch):

def test_get_skills_prompt_section_returns_all_when_available_skills_is_none(monkeypatch):
    skills = [_make_skill("skill1"), _make_skill("skill2")]
-    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
+    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)

-    result = get_skills_prompt_section(available_skills=None)
+    result = get_skills_prompt_section(_DEFAULT_SKILLS_CONFIG, available_skills=None)
    assert "skill1" in result
    assert "skill2" in result


def test_get_skills_prompt_section_includes_self_evolution_rules(monkeypatch):
    skills = [_make_skill("skill1")]
-    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
-    monkeypatch.setattr(
-        AppConfig, "current",
-        staticmethod(lambda: SimpleNamespace(
-            skills=SimpleNamespace(container_path="/mnt/skills"),
-            skill_evolution=SimpleNamespace(enabled=True),
-        )),
-    )
+    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)

-    result = get_skills_prompt_section(available_skills=None)
+    result = get_skills_prompt_section(_evolution_enabled_config(), available_skills=None)
    assert "Skill Self-Evolution" in result


def test_get_skills_prompt_section_includes_self_evolution_rules_without_skills(monkeypatch):
-    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [])
-    monkeypatch.setattr(
-        AppConfig, "current",
-        staticmethod(lambda: SimpleNamespace(
-            skills=SimpleNamespace(container_path="/mnt/skills"),
-            skill_evolution=SimpleNamespace(enabled=True),
-        )),
-    )
+    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: [])

-    result = get_skills_prompt_section(available_skills=None)
+    result = get_skills_prompt_section(_evolution_enabled_config(), available_skills=None)
    assert "Skill Self-Evolution" in result


def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeypatch):
    skills = [_make_skill("skill1")]
-    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: skills)
-    config = SimpleNamespace(
-        skills=SimpleNamespace(container_path="/mnt/skills"),
-        skill_evolution=SimpleNamespace(enabled=True),
-    )
-    monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
+    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda *a, **k: skills)
+    config = _evolution_enabled_config()

-    enabled_result = get_skills_prompt_section(available_skills=None)
+    enabled_result = get_skills_prompt_section(config, available_skills=None)
    assert "Skill Self-Evolution" in enabled_result

-    config.skill_evolution.enabled = False
-    disabled_result = get_skills_prompt_section(available_skills=None)
+    disabled_config = SimpleNamespace(
+        skills=SimpleNamespace(container_path="/mnt/skills"),
+        skill_evolution=SimpleNamespace(enabled=False),
+    )
+    disabled_result = get_skills_prompt_section(disabled_config, available_skills=None)
    assert "Skill Self-Evolution" not in disabled_result


@@ -123,7 +121,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):

    captured_skills = []

-    def mock_apply_prompt_template(**kwargs):
+    def mock_apply_prompt_template(_app_config, *args, **kwargs):
        captured_skills.append(kwargs.get("available_skills"))
        return "mock_prompt"

@@ -131,15 +129,15 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):

    # Case 1: Empty skills list
    monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=[]))
-    lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
+    lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
    assert captured_skills[-1] == set()

    # Case 2: None skills list
    monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
-    lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
+    lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
    assert captured_skills[-1] is None

    # Case 3: Some skills list
    monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
-    lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
+    lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}, app_config=mock_app_config)
    assert captured_skills[-1] == {"skill1"}

@@ -28,7 +28,7 @@ def test_get_available_tools_hides_bash_for_default_local_sandbox(monkeypatch):
        lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
    )

-    names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
+    names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=AppConfig.current())]

    assert "bash" not in names
    assert "ls" in names

@@ -41,7 +41,7 @@ def test_get_available_tools_keeps_bash_when_explicitly_enabled(monkeypatch):
        lambda use, _: SimpleNamespace(name="bash" if "bash" in use else "ls"),
    )

-    names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
+    names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=AppConfig.current())]

    assert "bash" in names
    assert "ls" in names

@@ -58,7 +58,7 @@ def test_get_available_tools_hides_renamed_host_bash_alias(monkeypatch):
        lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
    )

-    names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
+    names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=AppConfig.current())]

    assert "bash" not in names
    assert "shell" not in names

@@ -76,7 +76,7 @@ def test_get_available_tools_keeps_bash_for_aio_sandbox(monkeypatch):
        lambda use, _: SimpleNamespace(name="bash" if "bash_tool" in use else "ls"),
    )

-    names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False)]
+    names = [tool.name for tool in get_available_tools(include_mcp=False, subagent_enabled=False, app_config=AppConfig.current())]

    assert "bash" in names
    assert "ls" in names


@@ -314,7 +314,7 @@ class TestLocalSandboxProviderMounts:
        )

        with patch.object(AppConfig, "current", return_value=config):
-            provider = LocalSandboxProvider()
+            provider = LocalSandboxProvider(app_config=config)

        assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]

@@ -336,7 +336,7 @@ class TestLocalSandboxProviderMounts:
        )

        with patch.object(AppConfig, "current", return_value=config):
-            provider = LocalSandboxProvider()
+            provider = LocalSandboxProvider(app_config=config)

        assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]

@@ -360,7 +360,7 @@ class TestLocalSandboxProviderMounts:
        )

        with patch.object(AppConfig, "current", return_value=config):
-            provider = LocalSandboxProvider()
+            provider = LocalSandboxProvider(app_config=config)

        assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]

@@ -384,6 +384,6 @@ class TestLocalSandboxProviderMounts:
        )

        with patch.object(AppConfig, "current", return_value=config):
-            provider = LocalSandboxProvider()
+            provider = LocalSandboxProvider(app_config=config)

        assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]

@@ -6,12 +6,24 @@ from deerflow.config.memory_config import MemoryConfig
from deerflow.config.sandbox_config import SandboxConfig



+# --- Phase 2 config-refactor test helper ---
+# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
+# minimal config once and reuse it across call sites.
+from deerflow.config.app_config import AppConfig as _TestAppConfig
+from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
+from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
+
+_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
+_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
+# -------------------------------------------
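The `MemoryUpdateQueue(_TEST_APP_CONFIG)` call sites below rely on the config being captured when the queue is built, so the Timer callback that later fires on its own thread never needs an ambient ContextVar lookup. A rough sketch of that capture pattern (class body and attribute names are assumptions for illustration, not the real deer-flow code):

```python
from types import SimpleNamespace

class MemoryUpdateQueue:
    # Sketch of the capture pattern (assumed attribute names): the memory
    # config is read from the explicit app_config once, up front, so a
    # Timer callback firing later never crosses the ContextVar boundary.
    def __init__(self, app_config):
        self._memory_config = app_config.memory  # captured explicitly
        self._queue = []

    def add(self, thread_id, messages):
        if not self._memory_config.enabled:
            return  # memory disabled: nothing to enqueue
        self._queue.append((thread_id, messages))

q = MemoryUpdateQueue(SimpleNamespace(memory=SimpleNamespace(enabled=True)))
q.add("t1", ["hello"])
```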

def _make_config(**memory_overrides) -> AppConfig:
    return AppConfig(sandbox=SandboxConfig(use="test"), memory=MemoryConfig(**memory_overrides))


def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:
-    queue = MemoryUpdateQueue()
+    queue = MemoryUpdateQueue(_TEST_APP_CONFIG)

    with (
        patch.object(AppConfig, "current", return_value=_make_config(enabled=True)),

@@ -26,7 +38,7 @@ def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None:


def test_process_queue_forwards_correction_flag_to_updater() -> None:
-    queue = MemoryUpdateQueue()
+    queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
    queue._queue = [
        ConversationContext(
            thread_id="thread-1",

@@ -52,7 +64,7 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:


def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
-    queue = MemoryUpdateQueue()
+    queue = MemoryUpdateQueue(_TEST_APP_CONFIG)

    with (
        patch.object(AppConfig, "current", return_value=_make_config(enabled=True)),

@@ -67,7 +79,7 @@ def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> No


def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
-    queue = MemoryUpdateQueue()
+    queue = MemoryUpdateQueue(_TEST_APP_CONFIG)
    queue._queue = [
        ConversationContext(
            thread_id="thread-1",

@@ -1,3 +1,15 @@

+# --- Phase 2 config-refactor test helper ---
+# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
+# minimal config once and reuse it across call sites.
+from deerflow.config.app_config import AppConfig as _TestAppConfig
+from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
+from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
+
+_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
+_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
+# -------------------------------------------
+
"""Tests for user_id propagation through memory queue."""
from unittest.mock import MagicMock, patch

@@ -27,7 +39,7 @@ def test_conversation_context_user_id_default_none():


def test_queue_add_stores_user_id():
-    q = MemoryUpdateQueue()
+    q = MemoryUpdateQueue(_TEST_APP_CONFIG)
    with patch.object(q, "_reset_timer"):
        q.add(thread_id="t1", messages=["msg"], user_id="alice")
    assert len(q._queue) == 1

@@ -36,7 +48,7 @@ def test_queue_add_stores_user_id():


def test_queue_process_passes_user_id_to_updater():
-    q = MemoryUpdateQueue()
+    q = MemoryUpdateQueue(_TEST_APP_CONFIG)
    with patch.object(q, "_reset_timer"):
        q.add(thread_id="t1", messages=["msg"], user_id="alice")

@@ -4,6 +4,18 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.gateway.routers import memory
+from deerflow.config.app_config import AppConfig
+from deerflow.config.sandbox_config import SandboxConfig
+
+_TEST_APP_CONFIG = AppConfig(sandbox=SandboxConfig(use="test"))
+
+
+def _make_app() -> FastAPI:
+    """Build a memory-router app pre-populated with a minimal AppConfig."""
+    app = FastAPI()
+    app.state.config = _TEST_APP_CONFIG
+    app.include_router(memory.router)
+    return app
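The `_make_app` helper stores the config on `app.state` so route handlers can receive it through dependency injection rather than an ambient lookup. A framework-free sketch of that accessor (the `get_config` shape here is an assumption for illustration; deer-flow's actual `Depends(get_config)` dependency lives in the gateway layer):

```python
from types import SimpleNamespace

def get_config(request):
    # Assumed accessor shape: read the AppConfig that app startup stored on
    # app.state, instead of an ambient AppConfig.current() lookup.
    return request.app.state.config

# Minimal stand-in for a request object carrying the app reference.
request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace(config={"sandbox": "test"})))
assert get_config(request)["sandbox"] == "test"
```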


def _sample_memory(facts: list[dict] | None = None) -> dict:

@@ -25,8 +37,7 @@ def _sample_memory(facts: list[dict] | None = None) -> dict:


def test_export_memory_route_returns_current_memory() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()
    exported_memory = _sample_memory(
        facts=[
            {

@@ -49,8 +60,7 @@ def test_export_memory_route_returns_current_memory() -> None:


def test_import_memory_route_returns_imported_memory() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()
    imported_memory = _sample_memory(
        facts=[
            {

@@ -73,8 +83,7 @@ def test_import_memory_route_returns_imported_memory() -> None:


def test_export_memory_route_preserves_source_error() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()
    exported_memory = _sample_memory(
        facts=[
            {

@@ -98,8 +107,7 @@ def test_export_memory_route_preserves_source_error() -> None:


def test_import_memory_route_preserves_source_error() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()
    imported_memory = _sample_memory(
        facts=[
            {

@@ -123,8 +131,7 @@ def test_import_memory_route_preserves_source_error() -> None:


def test_clear_memory_route_returns_cleared_memory() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()

    with patch("app.gateway.routers.memory.clear_memory_data", return_value=_sample_memory()):
        with TestClient(app) as client:

@@ -135,8 +142,7 @@ def test_clear_memory_route_returns_cleared_memory() -> None:


def test_create_memory_fact_route_returns_updated_memory() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()
    updated_memory = _sample_memory(
        facts=[
            {

@@ -166,8 +172,7 @@ def test_create_memory_fact_route_returns_updated_memory() -> None:


def test_delete_memory_fact_route_returns_updated_memory() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()
    updated_memory = _sample_memory(
        facts=[
            {

@@ -190,8 +195,7 @@ def test_delete_memory_fact_route_returns_updated_memory() -> None:


def test_delete_memory_fact_route_returns_404_for_missing_fact() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()

    with patch("app.gateway.routers.memory.delete_memory_fact", side_effect=KeyError("fact_missing")):
        with TestClient(app) as client:

@@ -202,8 +206,7 @@ def test_delete_memory_fact_route_returns_404_for_missing_fact() -> None:


def test_update_memory_fact_route_returns_updated_memory() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()
    updated_memory = _sample_memory(
        facts=[
            {

@@ -233,8 +236,7 @@ def test_update_memory_fact_route_returns_updated_memory() -> None:


def test_update_memory_fact_route_preserves_omitted_fields() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()
    updated_memory = _sample_memory(
        facts=[
            {

@@ -269,8 +271,7 @@ def test_update_memory_fact_route_preserves_omitted_fields() -> None:


def test_update_memory_fact_route_returns_404_for_missing_fact() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()

    with patch("app.gateway.routers.memory.update_memory_fact", side_effect=KeyError("fact_missing")):
        with TestClient(app) as client:

@@ -288,8 +289,7 @@ def test_update_memory_fact_route_returns_404_for_missing_fact() -> None:


def test_update_memory_fact_route_returns_specific_error_for_invalid_confidence() -> None:
-    app = FastAPI()
-    app.include_router(memory.router)
+    app = _make_app()

    with patch("app.gateway.routers.memory.update_memory_fact", side_effect=ValueError("confidence")):
        with TestClient(app) as client:

@@ -1,3 +1,15 @@

+# --- Phase 2 config-refactor test helper ---
+# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
+# minimal config once and reuse it across call sites.
+from deerflow.config.app_config import AppConfig as _TestAppConfig
+from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
+from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
+
+_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
+_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
+# -------------------------------------------
+
"""Tests for memory storage providers."""

import threading

@@ -60,7 +72,7 @@ class TestFileMemoryStorage:

        with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
            with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
-                storage = FileMemoryStorage()
+                storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
                path = storage._get_memory_file_path(None)
                assert path == tmp_path / "memory.json"

@@ -73,14 +85,14 @@ class TestFileMemoryStorage:
            return mock_paths

        with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
-            storage = FileMemoryStorage()
+            storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
            path = storage._get_memory_file_path("test-agent")
            assert path == tmp_path / "agents" / "test-agent" / "memory.json"

    @pytest.mark.parametrize("invalid_name", ["", "../etc/passwd", "agent/name", "agent\\name", "agent name", "agent@123", "agent_name"])
    def test_validate_agent_name_invalid(self, invalid_name):
        """Should raise ValueError for invalid agent names that don't match the pattern."""
-        storage = FileMemoryStorage()
+        storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
        with pytest.raises(ValueError, match="Invalid agent name|Agent name must be a non-empty string"):
            storage._validate_agent_name(invalid_name)

@@ -94,7 +106,7 @@ class TestFileMemoryStorage:

        with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
            with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
-                storage = FileMemoryStorage()
+                storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
                memory = storage.load()
                assert isinstance(memory, dict)
                assert memory["version"] == "1.0"

@@ -110,7 +122,7 @@ class TestFileMemoryStorage:

        with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
            with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
-                storage = FileMemoryStorage()
+                storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
                test_memory = {"version": "1.0", "facts": [{"content": "test fact"}]}
                result = storage.save(test_memory)
                assert result is True

@@ -129,7 +141,7 @@ class TestFileMemoryStorage:

        with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
            with patch.object(AppConfig, "current", return_value=_app_config(storage_path="")):
-                storage = FileMemoryStorage()
+                storage = FileMemoryStorage(_TEST_MEMORY_CONFIG)
                # First load
                memory1 = storage.load()
                assert memory1["facts"][0]["content"] == "initial fact"

@@ -157,20 +169,20 @@ class TestGetMemoryStorage:
    def test_returns_file_memory_storage_by_default(self):
        """Should return FileMemoryStorage by default."""
        with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
-            storage = get_memory_storage()
+            storage = get_memory_storage(_TEST_MEMORY_CONFIG)
        assert isinstance(storage, FileMemoryStorage)

    def test_falls_back_to_file_memory_storage_on_error(self):
        """Should fall back to FileMemoryStorage if configured storage fails to load."""
        with patch.object(AppConfig, "current", return_value=_app_config(storage_class="non.existent.StorageClass")):
-            storage = get_memory_storage()
+            storage = get_memory_storage(_TEST_MEMORY_CONFIG)
        assert isinstance(storage, FileMemoryStorage)
|
||||
|
||||
def test_returns_singleton_instance(self):
|
||||
"""Should return the same instance on subsequent calls."""
|
||||
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
|
||||
storage1 = get_memory_storage()
|
||||
storage2 = get_memory_storage()
|
||||
storage1 = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
storage2 = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert storage1 is storage2
|
||||
|
||||
def test_get_memory_storage_thread_safety(self):
|
||||
@ -181,7 +193,7 @@ class TestGetMemoryStorage:
|
||||
# get_memory_storage is called concurrently from multiple threads while
|
||||
# AppConfig.get is patched once around thread creation. This verifies
|
||||
# that the singleton initialization remains thread-safe.
|
||||
results.append(get_memory_storage())
|
||||
results.append(get_memory_storage(_TEST_MEMORY_CONFIG))
|
||||
|
||||
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")):
|
||||
threads = [threading.Thread(target=get_storage) for _ in range(10)]
|
||||
@ -198,12 +210,12 @@ class TestGetMemoryStorage:
|
||||
"""Should fall back to FileMemoryStorage if the configured class is not actually a class."""
|
||||
# Using a built-in function instead of a class
|
||||
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="os.path.join")):
|
||||
storage = get_memory_storage()
|
||||
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
|
||||
def test_get_memory_storage_non_subclass_fallback(self):
|
||||
"""Should fall back to FileMemoryStorage if the configured class is not a subclass of MemoryStorage."""
|
||||
# Using 'dict' as a class that is not a MemoryStorage subclass
|
||||
with patch.object(AppConfig, "current", return_value=_app_config(storage_class="builtins.dict")):
|
||||
storage = get_memory_storage()
|
||||
storage = get_memory_storage(_TEST_MEMORY_CONFIG)
|
||||
assert isinstance(storage, FileMemoryStorage)
|
||||
|
||||
@@ -1,3 +1,15 @@
+
+# --- Phase 2 config-refactor test helper ---
+# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
+# minimal config once and reuse it across call sites.
+from deerflow.config.app_config import AppConfig as _TestAppConfig
+from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
+from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
+
+_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
+_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
+# -------------------------------------------
+
 """Tests for per-user memory storage isolation."""
 import pytest
 from pathlib import Path
@@ -21,7 +33,7 @@ def base_dir(tmp_path: Path) -> Path:
 
 @pytest.fixture
 def storage() -> FileMemoryStorage:
-    return FileMemoryStorage()
+    return FileMemoryStorage(_TEST_MEMORY_CONFIG)
 
 
 @pytest.fixture(autouse=True)
@@ -57,7 +69,7 @@ class TestUserIsolatedStorage:
 
         paths = Paths(base_dir)
         with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
-            s = FileMemoryStorage()
+            s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
             memory = create_empty_memory()
             s.save(memory, user_id="alice")
             expected_path = base_dir / "users" / "alice" / "memory.json"
@@ -68,7 +80,7 @@ class TestUserIsolatedStorage:
 
         paths = Paths(base_dir)
         with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
-            s = FileMemoryStorage()
+            s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
             memory_a = create_empty_memory()
             memory_a["user"]["workContext"]["summary"] = "A"
             s.save(memory_a, user_id="alice")
@@ -85,7 +97,7 @@ class TestUserIsolatedStorage:
 
         paths = Paths(base_dir)
         with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
-            s = FileMemoryStorage()
+            s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
             memory = create_empty_memory()
             s.save(memory, user_id=None)
             expected_path = base_dir / "memory.json"
@@ -97,7 +109,7 @@ class TestUserIsolatedStorage:
 
         paths = Paths(base_dir)
         with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
-            s = FileMemoryStorage()
+            s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
 
             legacy_mem = create_empty_memory()
             legacy_mem["user"]["workContext"]["summary"] = "legacy"
@@ -116,7 +128,7 @@ class TestUserIsolatedStorage:
 
         paths = Paths(base_dir)
         with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
-            s = FileMemoryStorage()
+            s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
             memory = create_empty_memory()
             memory["user"]["workContext"]["summary"] = "agent scoped"
             s.save(memory, "test-agent", user_id="alice")
@@ -129,7 +141,7 @@ class TestUserIsolatedStorage:
 
         paths = Paths(base_dir)
         with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
-            s = FileMemoryStorage()
+            s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
             memory = create_empty_memory()
             s.save(memory, user_id="alice")
             # After save, cache should have tuple key
@@ -141,7 +153,7 @@ class TestUserIsolatedStorage:
 
         paths = Paths(base_dir)
         with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
-            s = FileMemoryStorage()
+            s = FileMemoryStorage(_TEST_MEMORY_CONFIG)
             memory = create_empty_memory()
             memory["user"]["workContext"]["summary"] = "initial"
             s.save(memory, user_id="alice")
@@ -6,6 +6,17 @@ the in-memory LangGraph Store backend used when database.backend=memory.
 
 from __future__ import annotations
 
+# --- Phase 2 config-refactor test helper ---
+# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
+# minimal config once and reuse it across call sites.
+from deerflow.config.app_config import AppConfig as _TestAppConfig
+from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
+from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
+
+_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
+_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
+# -------------------------------------------
+
 from types import SimpleNamespace
 
 import pytest
@@ -15,6 +15,18 @@ from deerflow.config.memory_config import MemoryConfig
 from deerflow.config.sandbox_config import SandboxConfig
 
 
+# --- Phase 2 config-refactor test helper ---
+# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
+# minimal config once and reuse it across call sites.
+from deerflow.config.app_config import AppConfig as _TestAppConfig
+from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
+from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
+
+_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
+_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
+# -------------------------------------------
+
+
 def _make_memory(facts: list[dict[str, object]] | None = None) -> dict[str, object]:
     return {
         "version": "1.0",
@@ -38,7 +50,7 @@ def _memory_config(**overrides: object) -> AppConfig:
 
 
 def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None:
-    updater = MemoryUpdater()
+    updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
     current_memory = _make_memory(
         facts=[
             {
@@ -65,18 +77,14 @@ def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None
             {"content": "User likes Python", "category": "preference", "confidence": 0.95},
         ],
     }
 
-    with patch.object(AppConfig, "current",
-        return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
-    ):
-        result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
+    result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
 
     assert [fact["content"] for fact in result["facts"]] == ["User likes Python"]
     assert all(fact["id"] != "fact_remove" for fact in result["facts"])
 
 
 def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -> None:
-    updater = MemoryUpdater()
+    updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
     current_memory = _make_memory()
     update_data = {
         "newFacts": [
@@ -85,11 +93,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
             {"content": "User works on DeerFlow", "category": "context", "confidence": 0.87},
         ],
     }
 
-    with patch.object(AppConfig, "current",
-        return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
-    ):
-        result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
+    result = updater._apply_updates(current_memory, update_data, thread_id="thread-42")
 
     assert [fact["content"] for fact in result["facts"]] == [
         "User prefers dark mode",
@@ -100,7 +104,7 @@ def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -
 
 
 def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
-    updater = MemoryUpdater()
+    updater = MemoryUpdater(_memory_config(max_facts=2, fact_confidence_threshold=0.7))
     current_memory = _make_memory(
         facts=[
             {
@@ -128,11 +132,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
             {"content": "User likes noisy logs", "category": "behavior", "confidence": 0.6},
         ],
     }
 
-    with patch.object(AppConfig, "current",
-        return_value=_memory_config(max_facts=2, fact_confidence_threshold=0.7),
-    ):
-        result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
+    result = updater._apply_updates(current_memory, update_data, thread_id="thread-9")
 
     assert [fact["content"] for fact in result["facts"]] == [
         "User likes Python",
@@ -143,7 +143,7 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None:
 
 
 def test_apply_updates_preserves_source_error() -> None:
-    updater = MemoryUpdater()
+    updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
     current_memory = _make_memory()
     update_data = {
         "newFacts": [
@@ -155,18 +155,14 @@ def test_apply_updates_preserves_source_error() -> None:
             }
         ]
     }
 
-    with patch.object(AppConfig, "current",
-        return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
-    ):
-        result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
+    result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
 
     assert result["facts"][0]["sourceError"] == "The agent previously suggested npm start."
     assert result["facts"][0]["category"] == "correction"
 
 
 def test_apply_updates_ignores_empty_source_error() -> None:
-    updater = MemoryUpdater()
+    updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
     current_memory = _make_memory()
     update_data = {
         "newFacts": [
@@ -178,18 +174,14 @@ def test_apply_updates_ignores_empty_source_error() -> None:
             }
         ]
     }
 
-    with patch.object(AppConfig, "current",
-        return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
-    ):
-        result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
+    result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction")
 
     assert "sourceError" not in result["facts"][0]
 
 
 def test_clear_memory_data_resets_all_sections() -> None:
     with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True):
-        result = clear_memory_data()
+        result = clear_memory_data(_TEST_MEMORY_CONFIG)
 
     assert result["version"] == "1.0"
     assert result["facts"] == []
@@ -223,7 +215,7 @@ def test_delete_memory_fact_removes_only_matching_fact() -> None:
         patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
         patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
     ):
-        result = delete_memory_fact("fact_delete")
+        result = delete_memory_fact(_TEST_MEMORY_CONFIG, "fact_delete")
 
     assert [fact["id"] for fact in result["facts"]] == ["fact_keep"]
 
@@ -233,7 +225,7 @@ def test_create_memory_fact_appends_manual_fact() -> None:
         patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
         patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
     ):
-        result = create_memory_fact(
+        result = create_memory_fact(_TEST_MEMORY_CONFIG,
             content=" User prefers concise code reviews. ",
             category="preference",
             confidence=0.88,
@@ -248,7 +240,7 @@ def test_create_memory_fact_appends_manual_fact() -> None:
 
 def test_create_memory_fact_rejects_empty_content() -> None:
     try:
-        create_memory_fact(content=" ")
+        create_memory_fact(_TEST_MEMORY_CONFIG, content=" ")
     except ValueError as exc:
         assert exc.args == ("content",)
     else:
@@ -258,7 +250,7 @@ def test_create_memory_fact_rejects_empty_content() -> None:
 
 def test_create_memory_fact_rejects_invalid_confidence() -> None:
     for confidence in (-0.1, 1.1, float("nan"), float("inf"), float("-inf")):
         try:
-            create_memory_fact(content="User likes tests", confidence=confidence)
+            create_memory_fact(_TEST_MEMORY_CONFIG, content="User likes tests", confidence=confidence)
         except ValueError as exc:
             assert exc.args == ("confidence",)
         else:
@@ -268,7 +260,7 @@ def test_create_memory_fact_rejects_invalid_confidence() -> None:
 
 def test_delete_memory_fact_raises_for_unknown_id() -> None:
     with patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()):
         try:
-            delete_memory_fact("fact_missing")
+            delete_memory_fact(_TEST_MEMORY_CONFIG, "fact_missing")
         except KeyError as exc:
             assert exc.args == ("fact_missing",)
         else:
@@ -293,7 +285,7 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
     mock_storage.load.return_value = imported_memory
 
     with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
-        result = import_memory_data(imported_memory)
+        result = import_memory_data(_TEST_MEMORY_CONFIG, imported_memory)
 
     mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
     mock_storage.load.assert_called_once_with(None, user_id=None)
@@ -326,7 +318,7 @@ def test_update_memory_fact_updates_only_matching_fact() -> None:
         patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
         patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
     ):
-        result = update_memory_fact(
+        result = update_memory_fact(_TEST_MEMORY_CONFIG,
             fact_id="fact_edit",
             content="User prefers spaces",
             category="workflow",
@@ -359,7 +351,7 @@ def test_update_memory_fact_preserves_omitted_fields() -> None:
         patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory),
         patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True),
     ):
-        result = update_memory_fact(
+        result = update_memory_fact(_TEST_MEMORY_CONFIG,
             fact_id="fact_edit",
             content="User prefers spaces",
         )
@@ -372,7 +364,7 @@ def test_update_memory_fact_preserves_omitted_fields() -> None:
 
 def test_update_memory_fact_raises_for_unknown_id() -> None:
     with patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()):
         try:
-            update_memory_fact(
+            update_memory_fact(_TEST_MEMORY_CONFIG,
                 fact_id="fact_missing",
                 content="User prefers concise code reviews.",
                 category="preference",
@@ -404,7 +396,7 @@ def test_update_memory_fact_rejects_invalid_confidence() -> None:
             return_value=current_memory,
         ):
             try:
-                update_memory_fact(
+                update_memory_fact(_TEST_MEMORY_CONFIG,
                     fact_id="fact_edit",
                     content="User prefers spaces",
                     confidence=confidence,
@@ -521,7 +513,7 @@ class TestUpdateMemoryStructuredResponse:
         return model
 
     def test_string_response_parses(self):
-        updater = MemoryUpdater()
+        updater = MemoryUpdater(_TEST_APP_CONFIG)
         valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
 
         with (
@@ -543,7 +535,7 @@ class TestUpdateMemoryStructuredResponse:
 
     def test_list_content_response_parses(self):
         """LLM response as list-of-blocks should be extracted, not repr'd."""
-        updater = MemoryUpdater()
+        updater = MemoryUpdater(_TEST_APP_CONFIG)
         valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
         list_content = [{"type": "text", "text": valid_json}]
 
@@ -565,7 +557,7 @@ class TestUpdateMemoryStructuredResponse:
         assert result is True
 
     def test_correction_hint_injected_when_detected(self):
-        updater = MemoryUpdater()
+        updater = MemoryUpdater(_TEST_APP_CONFIG)
         valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
         model = self._make_mock_model(valid_json)
 
@@ -590,7 +582,7 @@ class TestUpdateMemoryStructuredResponse:
         assert "Explicit correction signals were detected" in prompt
 
     def test_correction_hint_empty_when_not_detected(self):
-        updater = MemoryUpdater()
+        updater = MemoryUpdater(_TEST_APP_CONFIG)
         valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
         model = self._make_mock_model(valid_json)
 
@@ -619,7 +611,7 @@ class TestFactDeduplicationCaseInsensitive:
     """Tests that fact deduplication is case-insensitive."""
 
     def test_duplicate_fact_different_case_not_stored(self):
-        updater = MemoryUpdater()
+        updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
         current_memory = _make_memory(
             facts=[
                 {
@@ -639,18 +631,14 @@ class TestFactDeduplicationCaseInsensitive:
                 {"content": "user prefers python", "category": "preference", "confidence": 0.95},
             ],
         }
 
-        with patch.object(AppConfig, "current",
-            return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
-        ):
-            result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
+        result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
 
         # Should still have only 1 fact (duplicate rejected)
         assert len(result["facts"]) == 1
         assert result["facts"][0]["content"] == "User prefers Python"
 
     def test_unique_fact_different_case_and_content_stored(self):
-        updater = MemoryUpdater()
+        updater = MemoryUpdater(_memory_config(max_facts=100, fact_confidence_threshold=0.7))
         current_memory = _make_memory(
             facts=[
                 {
@@ -669,11 +657,7 @@ class TestFactDeduplicationCaseInsensitive:
                 {"content": "User prefers Go", "category": "preference", "confidence": 0.85},
             ],
         }
 
-        with patch.object(AppConfig, "current",
-            return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
-        ):
-            result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
+        result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
 
         assert len(result["facts"]) == 2
 
@@ -690,7 +674,7 @@ class TestReinforcementHint:
         return model
 
    def test_reinforcement_hint_injected_when_detected(self):
-        updater = MemoryUpdater()
+        updater = MemoryUpdater(_TEST_APP_CONFIG)
         valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
         model = self._make_mock_model(valid_json)
 
@@ -715,7 +699,7 @@ class TestReinforcementHint:
         assert "Positive reinforcement signals were detected" in prompt
 
     def test_reinforcement_hint_absent_when_not_detected(self):
-        updater = MemoryUpdater()
+        updater = MemoryUpdater(_TEST_APP_CONFIG)
         valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
         model = self._make_mock_model(valid_json)
 
@@ -740,7 +724,7 @@ class TestReinforcementHint:
         assert "Positive reinforcement signals were detected" not in prompt
 
     def test_both_hints_present_when_both_detected(self):
-        updater = MemoryUpdater()
+        updater = MemoryUpdater(_TEST_APP_CONFIG)
         valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
         model = self._make_mock_model(valid_json)
@@ -1,3 +1,15 @@
+
+# --- Phase 2 config-refactor test helper ---
+# Memory APIs now take MemoryConfig / AppConfig explicitly. Tests construct a
+# minimal config once and reuse it across call sites.
+from deerflow.config.app_config import AppConfig as _TestAppConfig
+from deerflow.config.memory_config import MemoryConfig as _TestMemoryConfig
+from deerflow.config.sandbox_config import SandboxConfig as _TestSandboxConfig
+
+_TEST_MEMORY_CONFIG = _TestMemoryConfig(enabled=True)
+_TEST_APP_CONFIG = _TestAppConfig(sandbox=_TestSandboxConfig(use="test"), memory=_TEST_MEMORY_CONFIG)
+# -------------------------------------------
+
 """Tests for user_id propagation in memory updater."""
 from unittest.mock import MagicMock, patch
 
@@ -8,7 +20,7 @@ def test_get_memory_data_passes_user_id():
     mock_storage = MagicMock()
     mock_storage.load.return_value = {"version": "1.0"}
     with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
-        get_memory_data(user_id="alice")
+        get_memory_data(_TEST_MEMORY_CONFIG, user_id="alice")
     mock_storage.load.assert_called_once_with(None, user_id="alice")
 
 
@@ -16,7 +28,7 @@ def test_save_memory_passes_user_id():
     mock_storage = MagicMock()
     mock_storage.save.return_value = True
     with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
-        _save_memory_to_file({"version": "1.0"}, user_id="bob")
+        _save_memory_to_file(_TEST_MEMORY_CONFIG, {"version": "1.0"}, user_id="bob")
     mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
 
 
@@ -24,6 +36,6 @@ def test_clear_memory_data_passes_user_id():
     mock_storage = MagicMock()
     mock_storage.save.return_value = True
     with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
-        clear_memory_data(user_id="charlie")
+        clear_memory_data(_TEST_MEMORY_CONFIG, user_id="charlie")
     # Verify save was called with user_id
     assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"
@@ -88,7 +88,7 @@ def test_uses_first_model_when_name_is_none(monkeypatch):
     _patch_factory(monkeypatch, cfg)
 
     FakeChatModel.captured_kwargs = {}
-    factory_module.create_chat_model(name=None)
+    factory_module.create_chat_model(name=None, app_config=AppConfig.current())
 
     # resolve_class is called — if we reach here without ValueError, the correct model was used
     assert FakeChatModel.captured_kwargs.get("model") == "alpha"
@@ -100,7 +100,7 @@ def test_raises_when_model_not_found(monkeypatch):
     monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: [])
 
     with pytest.raises(ValueError, match="ghost-model"):
-        factory_module.create_chat_model(name="ghost-model")
+        factory_module.create_chat_model(name="ghost-model", app_config=AppConfig.current())
 
 
 def test_appends_all_tracing_callbacks(monkeypatch):
@@ -109,7 +109,7 @@ def test_appends_all_tracing_callbacks(monkeypatch):
     monkeypatch.setattr(factory_module, "build_tracing_callbacks", lambda: ["smith-callback", "langfuse-callback"])
 
     FakeChatModel.captured_kwargs = {}
-    model = factory_module.create_chat_model(name="alpha")
+    model = factory_module.create_chat_model(name="alpha", app_config=AppConfig.current())
 
     assert model.callbacks == ["smith-callback", "langfuse-callback"]
 
@@ -127,7 +127,7 @@ def test_thinking_enabled_raises_when_not_supported_but_when_thinking_enabled_is
     _patch_factory(monkeypatch, cfg)
 
     with pytest.raises(ValueError, match="does not support thinking"):
-        factory_module.create_chat_model(name="no-think", thinking_enabled=True)
+        factory_module.create_chat_model(name="no-think", thinking_enabled=True, app_config=AppConfig.current())
 
 
 def test_thinking_enabled_raises_for_empty_when_thinking_enabled_explicitly_set(monkeypatch):
@@ -138,7 +138,7 @@ def test_thinking_enabled_raises_for_empty_when_thinking_enabled_explicitly_set(
     _patch_factory(monkeypatch, cfg)
 
     with pytest.raises(ValueError, match="does not support thinking"):
-        factory_module.create_chat_model(name="no-think-empty", thinking_enabled=True)
+        factory_module.create_chat_model(name="no-think-empty", thinking_enabled=True, app_config=AppConfig.current())
 
 
 def test_thinking_enabled_merges_when_thinking_enabled_settings(monkeypatch):
@@ -147,7 +147,7 @@ def test_thinking_enabled_merges_when_thinking_enabled_settings(monkeypatch):
     _patch_factory(monkeypatch, cfg)
 
     FakeChatModel.captured_kwargs = {}
-    factory_module.create_chat_model(name="thinker", thinking_enabled=True)
+    factory_module.create_chat_model(name="thinker", thinking_enabled=True, app_config=AppConfig.current())
 
     assert FakeChatModel.captured_kwargs.get("temperature") == 1.0
     assert FakeChatModel.captured_kwargs.get("max_tokens") == 16000
@@ -183,7 +183,7 @@ def test_thinking_disabled_openai_gateway_format(monkeypatch):
 
     monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
 
-    factory_module.create_chat_model(name="openai-gw", thinking_enabled=False)
+    factory_module.create_chat_model(name="openai-gw", thinking_enabled=False, app_config=AppConfig.current())
 
     assert captured.get("extra_body") == {"thinking": {"type": "disabled"}}
     assert captured.get("reasoning_effort") == "minimal"
@@ -216,7 +216,7 @@ def test_thinking_disabled_langchain_anthropic_format(monkeypatch):
 
     monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
 
-    factory_module.create_chat_model(name="anthropic-native", thinking_enabled=False)
+    factory_module.create_chat_model(name="anthropic-native", thinking_enabled=False, app_config=AppConfig.current())
 
     assert captured.get("thinking") == {"type": "disabled"}
     assert "extra_body" not in captured
@@ -238,7 +238,7 @@ def test_thinking_disabled_no_when_thinking_enabled_does_nothing(monkeypatch):
 
     monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
 
-    factory_module.create_chat_model(name="plain", thinking_enabled=False)
+    factory_module.create_chat_model(name="plain", thinking_enabled=False, app_config=AppConfig.current())
 
     assert "extra_body" not in captured
     assert "thinking" not in captured
@@ -278,7 +278,7 @@ def test_when_thinking_disabled_takes_precedence_over_hardcoded_disable(monkeypa
 
     monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
 
-    factory_module.create_chat_model(name="custom-disable", thinking_enabled=False)
+    factory_module.create_chat_model(name="custom-disable", thinking_enabled=False, app_config=AppConfig.current())
 
     assert captured.get("extra_body") == {"thinking": {"type": "disabled"}}
     # User overrode the hardcoded "minimal" with "low"
@@ -310,7 +310,7 @@ def test_when_thinking_disabled_not_used_when_thinking_enabled(monkeypatch):
 
     monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
 
-    factory_module.create_chat_model(name="wtd-ignored", thinking_enabled=True)
+    factory_module.create_chat_model(name="wtd-ignored", thinking_enabled=True, app_config=AppConfig.current())
 
     # when_thinking_enabled should apply, NOT when_thinking_disabled
     assert captured.get("extra_body") == {"thinking": {"type": "enabled"}}
@@ -339,7 +339,7 @@ def test_when_thinking_disabled_without_when_thinking_enabled_still_applies(monk
 
     monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
 
-    factory_module.create_chat_model(name="wtd-only", thinking_enabled=False)
+    factory_module.create_chat_model(name="wtd-only", thinking_enabled=False, app_config=AppConfig.current())
 
     # when_thinking_disabled is now gated independently of has_thinking_settings
     assert captured.get("reasoning_effort") == "low"
@@ -370,7 +370,7 @@ def test_when_thinking_disabled_excluded_from_model_dump(monkeypatch):
 
     monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
 
-    factory_module.create_chat_model(name="no-leak-wtd", thinking_enabled=True)
+    factory_module.create_chat_model(name="no-leak-wtd", thinking_enabled=True, app_config=AppConfig.current())
 
     # when_thinking_disabled value must NOT appear as a raw key
     assert "when_thinking_disabled" not in captured
@@ -394,7 +394,7 @@ def test_reasoning_effort_cleared_when_not_supported(monkeypatch):
 
     monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
 
-    factory_module.create_chat_model(name="no-effort", thinking_enabled=False)
+    factory_module.create_chat_model(name="no-effort", thinking_enabled=False, app_config=AppConfig.current())
 
     assert captured.get("reasoning_effort") is None
 
@@ -422,7 +422,7 @@ def test_reasoning_effort_preserved_when_supported(monkeypatch):
 
     monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="effort-model", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="effort-model", thinking_enabled=False, app_config=AppConfig.current())
|
||||
|
||||
# When supports_reasoning_effort=True, it should NOT be cleared to None
|
||||
# The disable path sets it to "minimal"; supports_reasoning_effort=True keeps it
|
||||
@ -458,7 +458,7 @@ def test_thinking_shortcut_enables_thinking_when_thinking_enabled(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="shortcut-model", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="shortcut-model", thinking_enabled=True, app_config=AppConfig.current())
|
||||
|
||||
assert captured.get("thinking") == thinking_settings
|
||||
|
||||
@ -488,7 +488,7 @@ def test_thinking_shortcut_disables_thinking_when_thinking_disabled(monkeypatch)
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="shortcut-disable", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="shortcut-disable", thinking_enabled=False, app_config=AppConfig.current())
|
||||
|
||||
assert captured.get("thinking") == {"type": "disabled"}
|
||||
assert "extra_body" not in captured
|
||||
@ -520,7 +520,7 @@ def test_thinking_shortcut_merges_with_when_thinking_enabled(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="merge-model", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="merge-model", thinking_enabled=True, app_config=AppConfig.current())
|
||||
|
||||
# Both the thinking shortcut and when_thinking_enabled settings should be applied
|
||||
assert captured.get("thinking") == thinking_settings
|
||||
@ -552,7 +552,7 @@ def test_thinking_shortcut_not_leaked_into_model_when_disabled(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="no-leak", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="no-leak", thinking_enabled=False, app_config=AppConfig.current())
|
||||
|
||||
# The disable path should have set thinking to disabled (not the raw enabled shortcut)
|
||||
assert captured.get("thinking") == {"type": "disabled"}
|
||||
@ -590,7 +590,7 @@ def test_openai_compatible_provider_passes_base_url(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="minimax-m2.5")
|
||||
factory_module.create_chat_model(name="minimax-m2.5", app_config=AppConfig.current())
|
||||
|
||||
assert captured.get("model") == "MiniMax-M2.5"
|
||||
assert captured.get("base_url") == "https://api.minimax.io/v1"
|
||||
@ -638,11 +638,11 @@ def test_openai_compatible_provider_multiple_models(monkeypatch):
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
# Create first model
|
||||
factory_module.create_chat_model(name="minimax-m2.5")
|
||||
factory_module.create_chat_model(name="minimax-m2.5", app_config=AppConfig.current())
|
||||
assert captured.get("model") == "MiniMax-M2.5"
|
||||
|
||||
# Create second model
|
||||
factory_module.create_chat_model(name="minimax-m2.5-highspeed")
|
||||
factory_module.create_chat_model(name="minimax-m2.5-highspeed", app_config=AppConfig.current())
|
||||
assert captured.get("model") == "MiniMax-M2.5-highspeed"
|
||||
|
||||
|
||||
@ -670,7 +670,7 @@ def test_codex_provider_disables_reasoning_when_thinking_disabled(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=False, app_config=AppConfig.current())
|
||||
|
||||
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "none"
|
||||
|
||||
@ -690,7 +690,7 @@ def test_codex_provider_preserves_explicit_reasoning_effort(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, reasoning_effort="high")
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, reasoning_effort="high", app_config=AppConfig.current())
|
||||
|
||||
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "high"
|
||||
|
||||
@ -710,7 +710,7 @@ def test_codex_provider_defaults_reasoning_effort_to_medium(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, app_config=AppConfig.current())
|
||||
|
||||
assert FakeChatModel.captured_kwargs.get("reasoning_effort") == "medium"
|
||||
|
||||
@ -731,7 +731,7 @@ def test_codex_provider_strips_unsupported_max_tokens(monkeypatch):
|
||||
monkeypatch.setattr(codex_provider_module, "CodexChatModel", FakeCodexChatModel)
|
||||
|
||||
FakeChatModel.captured_kwargs = {}
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True)
|
||||
factory_module.create_chat_model(name="codex", thinking_enabled=True, app_config=AppConfig.current())
|
||||
|
||||
assert "max_tokens" not in FakeChatModel.captured_kwargs
|
||||
|
||||
@ -757,7 +757,7 @@ def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False, app_config=AppConfig.current())
|
||||
|
||||
assert captured.get("extra_body") == {"top_k": 20, "chat_template_kwargs": {"thinking": False}}
|
||||
assert captured.get("reasoning_effort") is None
|
||||
@ -784,7 +784,7 @@ def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False, app_config=AppConfig.current())
|
||||
|
||||
assert captured.get("extra_body") == {
|
||||
"top_k": 20,
|
||||
@ -818,7 +818,7 @@ def test_stream_usage_injected_for_openai_compatible_model(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="deepseek")
|
||||
factory_module.create_chat_model(name="deepseek", app_config=AppConfig.current())
|
||||
|
||||
assert captured.get("stream_usage") is True
|
||||
|
||||
@ -837,7 +837,7 @@ def test_stream_usage_not_injected_for_non_openai_model(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="claude")
|
||||
factory_module.create_chat_model(name="claude", app_config=AppConfig.current())
|
||||
|
||||
assert "stream_usage" not in captured
|
||||
|
||||
@ -867,7 +867,7 @@ def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="deepseek")
|
||||
factory_module.create_chat_model(name="deepseek", app_config=AppConfig.current())
|
||||
|
||||
assert captured.get("stream_usage") is False
|
||||
|
||||
@ -897,7 +897,7 @@ def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
|
||||
|
||||
factory_module.create_chat_model(name="gpt-5-responses")
|
||||
factory_module.create_chat_model(name="gpt-5-responses", app_config=AppConfig.current())
|
||||
|
||||
assert captured.get("use_responses_api") is True
|
||||
assert captured.get("output_version") == "responses/v1"
|
||||
@ -938,7 +938,7 @@ def test_no_duplicate_kwarg_when_reasoning_effort_in_config_and_thinking_disable
|
||||
_patch_factory(monkeypatch, cfg, model_class=CapturingModel)
|
||||
|
||||
# Must not raise TypeError
|
||||
factory_module.create_chat_model(name="doubao-model", thinking_enabled=False)
|
||||
factory_module.create_chat_model(name="doubao-model", thinking_enabled=False, app_config=AppConfig.current())
|
||||
|
||||
# kwargs (runtime) takes precedence: thinking-disabled path sets reasoning_effort=minimal
|
||||
assert captured.get("reasoning_effort") == "minimal"
|
||||
|
||||
@@ -15,6 +15,10 @@ def _make_runtime(tmp_path):
     workspace.mkdir()
     uploads.mkdir()
     outputs.mkdir()
+    from deerflow.config.app_config import AppConfig
+    from deerflow.config.deer_flow_context import DeerFlowContext
+    from deerflow.config.sandbox_config import SandboxConfig
+
     return SimpleNamespace(
         state={
             "sandbox": {"sandbox_id": "local"},
@@ -24,7 +28,10 @@
                 "outputs_path": str(outputs),
             },
         },
-        context={"thread_id": "thread-1"},
+        context=DeerFlowContext(
+            app_config=AppConfig(sandbox=SandboxConfig(use="test")),
+            thread_id="thread-1",
+        ),
     )
 
 
@@ -35,6 +35,53 @@ _THREAD_DATA = {
 }
 
 
+def _make_app_config(
+    *,
+    skills_container_path: str = "/mnt/skills",
+    skills_host_path: str | None = None,
+    mounts=None,
+    mcp_servers=None,
+    tool_config_map=None,
+) -> SimpleNamespace:
+    """Build a lightweight AppConfig stand-in used by tests.
+
+    Only the attributes accessed by the helpers under test are populated;
+    everything else is omitted to keep the fake minimal and explicit.
+    """
+    skills_path = Path(skills_host_path) if skills_host_path is not None else None
+    skills_cfg = SimpleNamespace(
+        container_path=skills_container_path,
+        get_skills_path=lambda: skills_path if skills_path is not None else Path("/nonexistent-skills-root-12345"),
+    )
+    sandbox_cfg = SimpleNamespace(mounts=list(mounts) if mounts else [], bash_output_max_chars=20000)
+    extensions_cfg = SimpleNamespace(mcp_servers=dict(mcp_servers) if mcp_servers else {})
+    tool_config_map = dict(tool_config_map or {})
+    return SimpleNamespace(
+        skills=skills_cfg,
+        sandbox=sandbox_cfg,
+        extensions=extensions_cfg,
+        get_tool_config=lambda name: tool_config_map.get(name),
+    )
+
+
+_DEFAULT_APP_CONFIG = _make_app_config()
+
+
+def _make_ctx(thread_id: str = "thread-1", *, app_config=_DEFAULT_APP_CONFIG, sandbox_key: str | None = None):
+    """Build a DeerFlowContext-like object with extra attributes allowed.
+
+    ``resolve_context`` only checks ``isinstance(ctx, DeerFlowContext)``; for
+    tests that need additional attributes (``sandbox_key``) we use a subclass
+    created at runtime.
+    """
+    from deerflow.config.deer_flow_context import DeerFlowContext as _DFC
+
+    ctx = _DFC(app_config=app_config, thread_id=thread_id)
+    if sandbox_key is not None:
+        object.__setattr__(ctx, "sandbox_key", sandbox_key)
+    return ctx
+
+
 # ---------- replace_virtual_path ----------
 
 
@@ -86,7 +133,7 @@ def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing
 def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
     """Trailing slash on a virtual path inside a command must be preserved."""
     cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
-    result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
+    result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
     assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"
 
 
@@ -95,7 +142,7 @@
 
 def test_mask_local_paths_in_output_hides_host_paths() -> None:
     output = "Created: /tmp/deer-flow/threads/t1/user-data/workspace/result.txt"
-    masked = mask_local_paths_in_output(output, _THREAD_DATA)
+    masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
     assert "/tmp/deer-flow/threads/t1/user-data" not in masked
     assert "/mnt/user-data/workspace/result.txt" in masked
@@ -108,7 +155,7 @@ def test_mask_local_paths_in_output_hides_skills_host_paths() -> None:
         patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
     ):
         output = "Reading: /home/user/deer-flow/skills/public/bootstrap/SKILL.md"
-        masked = mask_local_paths_in_output(output, _THREAD_DATA)
+        masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
         assert "/home/user/deer-flow/skills" not in masked
         assert "/mnt/skills/public/bootstrap/SKILL.md" in masked
@@ -144,12 +191,12 @@ def test_reject_path_traversal_allows_normal_paths() -> None:
 
 def test_validate_local_tool_path_rejects_non_virtual_path() -> None:
     with pytest.raises(PermissionError, match="Only paths under"):
-        validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
+        validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_tool_path_rejects_non_virtual_path_mentions_configured_mounts() -> None:
     with pytest.raises(PermissionError, match="configured mount paths"):
-        validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
+        validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -> None:
@@ -159,42 +206,41 @@ def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -
         VolumeMountConfig(host_path="/tmp/host-user-data", container_path=VIRTUAL_PATH_PREFIX, read_only=False),
     ]
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
-        validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=True)
+        validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
 
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
         with pytest.raises(PermissionError, match="path traversal"):
-            validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, read_only=True)
+            validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
 
 
 def test_validate_local_tool_path_rejects_bare_virtual_root() -> None:
     """The bare /mnt/user-data root without trailing slash is not a valid sub-path."""
     with pytest.raises(PermissionError, match="Only paths under"):
-        validate_local_tool_path(VIRTUAL_PATH_PREFIX, _THREAD_DATA)
+        validate_local_tool_path(VIRTUAL_PATH_PREFIX, _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_tool_path_allows_user_data_paths() -> None:
     # Should not raise — user-data paths are always allowed
-    validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA)
-    validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/uploads/doc.pdf", _THREAD_DATA)
-    validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/outputs/result.csv", _THREAD_DATA)
+    validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG)
+    validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/uploads/doc.pdf", _THREAD_DATA, _DEFAULT_APP_CONFIG)
+    validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/outputs/result.csv", _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_tool_path_allows_user_data_write() -> None:
     # read_only=False (default) should still work for user-data paths
-    validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=False)
+    validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
 
 
 def test_validate_local_tool_path_rejects_traversal_in_user_data() -> None:
     """Path traversal via .. in user-data paths must be rejected."""
     with pytest.raises(PermissionError, match="path traversal"):
-        validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA)
+        validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_tool_path_rejects_traversal_in_skills() -> None:
     """Path traversal via .. in skills paths must be rejected."""
-    with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
-        with pytest.raises(PermissionError, match="path traversal"):
-            validate_local_tool_path("/mnt/skills/../../etc/passwd", _THREAD_DATA, read_only=True)
+    with pytest.raises(PermissionError, match="path traversal"):
+        validate_local_tool_path("/mnt/skills/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
 
 
 def test_validate_local_tool_path_rejects_none_thread_data() -> None:
@@ -202,7 +248,7 @@ def test_validate_local_tool_path_rejects_none_thread_data() -> None:
     from deerflow.sandbox.exceptions import SandboxRuntimeError
 
     with pytest.raises(SandboxRuntimeError):
-        validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", None)
+        validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", None, _DEFAULT_APP_CONFIG)
 
 
 # ---------- _resolve_skills_path ----------
@@ -210,32 +256,26 @@
 
 def test_resolve_skills_path_resolves_correctly() -> None:
     """Skills virtual path should resolve to host path."""
-    with (
-        patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
-        patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
-    ):
-        resolved = _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md")
-        assert resolved == "/home/user/deer-flow/skills/public/bootstrap/SKILL.md"
+    cfg = _make_app_config(skills_host_path="/home/user/deer-flow/skills")
+    # Force get_skills_path().exists() to be True without touching the FS
+    cfg.skills.get_skills_path = lambda: type("_P", (), {"exists": lambda self: True, "__str__": lambda self: "/home/user/deer-flow/skills"})()
+    resolved = _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md", cfg)
+    assert resolved == "/home/user/deer-flow/skills/public/bootstrap/SKILL.md"
 
 
 def test_resolve_skills_path_resolves_root() -> None:
     """Skills container root should resolve to host skills directory."""
-    with (
-        patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
-        patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
-    ):
-        resolved = _resolve_skills_path("/mnt/skills")
-        assert resolved == "/home/user/deer-flow/skills"
+    cfg = _make_app_config(skills_host_path="/home/user/deer-flow/skills")
+    cfg.skills.get_skills_path = lambda: type("_P", (), {"exists": lambda self: True, "__str__": lambda self: "/home/user/deer-flow/skills"})()
+    resolved = _resolve_skills_path("/mnt/skills", cfg)
+    assert resolved == "/home/user/deer-flow/skills"
 
 
 def test_resolve_skills_path_raises_when_not_configured() -> None:
     """Should raise FileNotFoundError when skills directory is not available."""
-    with (
-        patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
-        patch("deerflow.sandbox.tools._get_skills_host_path", return_value=None),
-    ):
-        with pytest.raises(FileNotFoundError, match="Skills directory not available"):
-            _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md")
+    # Default app config has no host path configured → _get_skills_host_path returns None
+    with pytest.raises(FileNotFoundError, match="Skills directory not available"):
+        _resolve_skills_path("/mnt/skills/public/bootstrap/SKILL.md", _DEFAULT_APP_CONFIG)
 
 
 # ---------- _resolve_and_validate_user_data_path ----------
@@ -250,7 +290,7 @@ def test_resolve_and_validate_user_data_path_resolves_correctly(tmp_path: Path)
         "uploads_path": str(tmp_path / "uploads"),
         "outputs_path": str(tmp_path / "outputs"),
     }
-    resolved = _resolve_and_validate_user_data_path("/mnt/user-data/workspace/hello.txt", thread_data)
+    resolved = _resolve_and_validate_user_data_path("/mnt/user-data/workspace/hello.txt", thread_data, _DEFAULT_APP_CONFIG)
    assert resolved == str(workspace / "hello.txt")
 
 
@@ -265,7 +305,7 @@ def test_resolve_and_validate_user_data_path_blocks_traversal(tmp_path: Path) ->
     }
     # This path resolves outside the allowed roots
     with pytest.raises(PermissionError):
-        _resolve_and_validate_user_data_path("/mnt/user-data/workspace/../../../etc/passwd", thread_data)
+        _resolve_and_validate_user_data_path("/mnt/user-data/workspace/../../../etc/passwd", thread_data, _DEFAULT_APP_CONFIG)
 
 
 # ---------- replace_virtual_paths_in_command ----------
@@ -278,7 +318,7 @@ def test_replace_virtual_paths_in_command_replaces_skills_paths() -> None:
         patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/deer-flow/skills"),
     ):
         cmd = "cat /mnt/skills/public/bootstrap/SKILL.md"
-        result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
+        result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
         assert "/mnt/skills" not in result
         assert "/home/user/deer-flow/skills/public/bootstrap/SKILL.md" in result
 
@@ -290,7 +330,7 @@ def test_replace_virtual_paths_in_command_replaces_both() -> None:
         patch("deerflow.sandbox.tools._get_skills_host_path", return_value="/home/user/skills"),
     ):
         cmd = "cat /mnt/skills/public/SKILL.md > /mnt/user-data/workspace/out.txt"
-        result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
+        result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
         assert "/mnt/skills" not in result
         assert "/mnt/user-data" not in result
         assert "/home/user/skills/public/SKILL.md" in result
@@ -302,30 +342,27 @@
 
 def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
     with pytest.raises(PermissionError, match="Unsafe absolute paths"):
-        validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)
+        validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_allows_https_urls() -> None:
     """URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
     validate_local_bash_command_paths(
         "cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
-        _THREAD_DATA,
-    )
+        _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_allows_http_urls() -> None:
     """HTTP URLs must not be flagged as unsafe absolute paths."""
     validate_local_bash_command_paths(
         "curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
-        _THREAD_DATA,
-    )
+        _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
     validate_local_bash_command_paths(
         "/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",
-        _THREAD_DATA,
-    )
+        _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_blocks_traversal_in_user_data() -> None:
@@ -333,8 +370,7 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_user_data() -> No
     with pytest.raises(PermissionError, match="path traversal"):
         validate_local_bash_command_paths(
             "cat /mnt/user-data/workspace/../../etc/passwd",
-            _THREAD_DATA,
-        )
+            _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_blocks_traversal_in_skills() -> None:
@@ -343,14 +379,13 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_skills() -> None:
     with pytest.raises(PermissionError, match="path traversal"):
         validate_local_bash_command_paths(
             "cat /mnt/skills/../../etc/passwd",
-            _THREAD_DATA,
-        )
+            _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_bash_tool_rejects_host_bash_when_local_sandbox_default(monkeypatch) -> None:
     runtime = SimpleNamespace(
         state={"sandbox": {"sandbox_id": "local"}, "thread_data": _THREAD_DATA.copy()},
-        context={"thread_id": "thread-1"},
+        context=_make_ctx("thread-1"),
     )
 
     monkeypatch.setattr(
@@ -372,33 +407,32 @@ def test_bash_tool_rejects_host_bash_when_local_sandbox_default(monkeypatch) ->
 
 
 def test_is_skills_path_recognises_default_prefix() -> None:
-    with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
-        assert _is_skills_path("/mnt/skills") is True
-        assert _is_skills_path("/mnt/skills/public/bootstrap/SKILL.md") is True
-        assert _is_skills_path("/mnt/skills-extra/foo") is False
-        assert _is_skills_path("/mnt/user-data/workspace") is False
+    assert _is_skills_path("/mnt/skills", _DEFAULT_APP_CONFIG) is True
+    assert _is_skills_path("/mnt/skills/public/bootstrap/SKILL.md", _DEFAULT_APP_CONFIG) is True
+    assert _is_skills_path("/mnt/skills-extra/foo", _DEFAULT_APP_CONFIG) is False
+    assert _is_skills_path("/mnt/user-data/workspace", _DEFAULT_APP_CONFIG) is False
 
 
 def test_validate_local_tool_path_allows_skills_read_only() -> None:
     """read_file / ls should be able to access /mnt/skills paths."""
-    with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
-        # Should not raise
-        validate_local_tool_path(
-            "/mnt/skills/public/bootstrap/SKILL.md",
-            _THREAD_DATA,
-            read_only=True,
-        )
+    # Should not raise — default app config uses /mnt/skills as container path
+    validate_local_tool_path(
+        "/mnt/skills/public/bootstrap/SKILL.md",
+        _THREAD_DATA,
+        _DEFAULT_APP_CONFIG,
+        read_only=True,
+    )
 
 
 def test_validate_local_tool_path_blocks_skills_write() -> None:
     """write_file / str_replace must NOT write to skills paths."""
-    with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
-        with pytest.raises(PermissionError, match="Write access to skills path is not allowed"):
-            validate_local_tool_path(
-                "/mnt/skills/public/bootstrap/SKILL.md",
-                _THREAD_DATA,
-                read_only=False,
-            )
+    with pytest.raises(PermissionError, match="Write access to skills path is not allowed"):
+        validate_local_tool_path(
+            "/mnt/skills/public/bootstrap/SKILL.md",
+            _THREAD_DATA,
+            _DEFAULT_APP_CONFIG,
+            read_only=False,
+        )
 
 
 def test_validate_local_bash_command_paths_allows_skills_path() -> None:
@@ -406,8 +440,7 @@ def test_validate_local_bash_command_paths_allows_skills_path() -> None:
     with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
         validate_local_bash_command_paths(
             "cat /mnt/skills/public/bootstrap/SKILL.md",
-            _THREAD_DATA,
-        )
+            _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_allows_urls() -> None:
@@ -415,40 +448,35 @@ def test_validate_local_bash_command_paths_allows_urls() -> None:
     # HTTPS URLs
     validate_local_bash_command_paths(
         "curl -X POST https://example.com/api/v1/risk/check",
-        _THREAD_DATA,
-    )
+        _THREAD_DATA, _DEFAULT_APP_CONFIG)
     # HTTP URLs
     validate_local_bash_command_paths(
         "curl http://localhost:8080/health",
-        _THREAD_DATA,
-    )
+        _THREAD_DATA, _DEFAULT_APP_CONFIG)
     # URLs with query strings
     validate_local_bash_command_paths(
         "curl https://api.example.com/v2/search?q=test",
-        _THREAD_DATA,
-    )
+        _THREAD_DATA, _DEFAULT_APP_CONFIG)
     # FTP URLs
     validate_local_bash_command_paths(
         "curl ftp://ftp.example.com/pub/file.tar.gz",
-        _THREAD_DATA,
-    )
+        _THREAD_DATA, _DEFAULT_APP_CONFIG)
     # URL mixed with valid virtual path
     validate_local_bash_command_paths(
         "curl https://example.com/data -o /mnt/user-data/workspace/data.json",
-        _THREAD_DATA,
-    )
+        _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_blocks_file_urls() -> None:
     """file:// URLs should be treated as unsafe and blocked."""
     with pytest.raises(PermissionError):
-        validate_local_bash_command_paths("curl file:///etc/passwd", _THREAD_DATA)
+        validate_local_bash_command_paths("curl file:///etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_blocks_file_urls_case_insensitive() -> None:
     """file:// URL detection should be case-insensitive."""
     with pytest.raises(PermissionError):
-        validate_local_bash_command_paths("curl FILE:///etc/shadow", _THREAD_DATA)
+        validate_local_bash_command_paths("curl FILE:///etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_blocks_file_urls_mixed_with_valid() -> None:
@ -456,35 +484,36 @@ def test_validate_local_bash_command_paths_blocks_file_urls_mixed_with_valid() -
|
||||
with pytest.raises(PermissionError):
|
||||
validate_local_bash_command_paths(
|
||||
"curl file:///etc/passwd -o /mnt/user-data/workspace/out.txt",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
_THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_still_blocks_other_paths() -> None:
|
||||
"""Paths outside virtual and system prefixes must still be blocked."""
|
||||
with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"):
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
|
||||
|
||||
|
||||
 def test_validate_local_tool_path_skills_custom_container_path() -> None:
     """Skills with a custom container_path in config should also work."""
-    with patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/custom/skills"):
-        # Should not raise
-        validate_local_tool_path(
-            "/custom/skills/public/my-skill/SKILL.md",
-            _THREAD_DATA,
-            read_only=True,
-        )
-
-        # The default /mnt/skills should not match since container path is /custom/skills
-        with pytest.raises(PermissionError, match="Only paths under"):
-            validate_local_tool_path(
-                "/mnt/skills/public/bootstrap/SKILL.md",
-                _THREAD_DATA,
-                read_only=True,
-            )
+    custom_config = _make_app_config(skills_container_path="/custom/skills")
+    # Should not raise
+    validate_local_tool_path(
+        "/custom/skills/public/my-skill/SKILL.md",
+        _THREAD_DATA,
+        custom_config,
+        read_only=True,
+    )
+
+    # The default /mnt/skills should not match since container path is /custom/skills
+    with pytest.raises(PermissionError, match="Only paths under"):
+        validate_local_tool_path(
+            "/mnt/skills/public/bootstrap/SKILL.md",
+            _THREAD_DATA,
+            custom_config,
+            read_only=True,
+        )
 
 
 # ---------- ACP workspace path tests ----------
 
@@ -501,6 +530,7 @@ def test_validate_local_tool_path_allows_acp_workspace_read_only() -> None:
     validate_local_tool_path(
         "/mnt/acp-workspace/hello_world.py",
         _THREAD_DATA,
+        _DEFAULT_APP_CONFIG,
         read_only=True,
     )
 
@@ -511,6 +541,7 @@ def test_validate_local_tool_path_blocks_acp_workspace_write() -> None:
         validate_local_tool_path(
             "/mnt/acp-workspace/hello_world.py",
             _THREAD_DATA,
+            _DEFAULT_APP_CONFIG,
             read_only=False,
         )
 
@@ -519,8 +550,7 @@ def test_validate_local_bash_command_paths_allows_acp_workspace() -> None:
     """bash commands referencing /mnt/acp-workspace should be allowed."""
     validate_local_bash_command_paths(
         "cp /mnt/acp-workspace/hello_world.py /mnt/user-data/outputs/hello_world.py",
-        _THREAD_DATA,
-    )
+        _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_blocks_traversal_in_acp_workspace() -> None:
@@ -528,8 +558,7 @@ def test_validate_local_bash_command_paths_blocks_traversal_in_acp_workspace() -
     with pytest.raises(PermissionError, match="path traversal"):
         validate_local_bash_command_paths(
             "cat /mnt/acp-workspace/../../etc/passwd",
-            _THREAD_DATA,
-        )
+            _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_resolve_acp_workspace_path_resolves_correctly(tmp_path: Path) -> None:
@@ -571,7 +600,7 @@ def test_replace_virtual_paths_in_command_replaces_acp_workspace() -> None:
     acp_host = "/home/user/.deer-flow/acp-workspace"
     with patch("deerflow.sandbox.tools._get_acp_workspace_host_path", return_value=acp_host):
         cmd = "cp /mnt/acp-workspace/hello.py /mnt/user-data/outputs/hello.py"
-        result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
+        result = replace_virtual_paths_in_command(cmd, _THREAD_DATA, _DEFAULT_APP_CONFIG)
         assert "/mnt/acp-workspace" not in result
         assert f"{acp_host}/hello.py" in result
         assert "/tmp/deer-flow/threads/t1/user-data/outputs/hello.py" in result
@@ -582,7 +611,7 @@ def test_mask_local_paths_in_output_hides_acp_workspace_host_paths() -> None:
     acp_host = "/home/user/.deer-flow/acp-workspace"
     with patch("deerflow.sandbox.tools._get_acp_workspace_host_path", return_value=acp_host):
         output = f"Copied: {acp_host}/hello.py"
-        masked = mask_local_paths_in_output(output, _THREAD_DATA)
+        masked = mask_local_paths_in_output(output, _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
         assert acp_host not in masked
         assert "/mnt/acp-workspace/hello.py" in masked
@@ -622,7 +651,7 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
     from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
     from deerflow.config.sandbox_config import SandboxConfig
 
-    def _make_app_config(enabled: bool) -> AppConfig:
+    def _mcp_app_config(enabled: bool) -> AppConfig:
         return AppConfig(
             sandbox=SandboxConfig(use="test"),
             extensions=ExtensionsConfig(
@@ -636,19 +665,19 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
             ),
         )
 
-    with patch.object(AppConfig, "current", return_value=_make_app_config(True)):
-        # Should not raise - MCP filesystem paths are allowed
-        validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
-        validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA)
+    enabled_cfg = _mcp_app_config(True)
+    # Should not raise - MCP filesystem paths are allowed
+    validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA, enabled_cfg)
+    validate_local_bash_command_paths("cat /mnt/d/workspace/subdir/file.txt", _THREAD_DATA, enabled_cfg)
 
-        # Path traversal should still be blocked
-        with pytest.raises(PermissionError, match="path traversal"):
-            validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA)
+    # Path traversal should still be blocked
+    with pytest.raises(PermissionError, match="path traversal"):
+        validate_local_bash_command_paths("cat /mnt/d/workspace/../../etc/passwd", _THREAD_DATA, enabled_cfg)
 
     # Disabled servers should not expose paths
-    with patch.object(AppConfig, "current", return_value=_make_app_config(False)):
-        with pytest.raises(PermissionError, match="Unsafe absolute paths"):
-            validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
+    disabled_cfg = _mcp_app_config(False)
+    with pytest.raises(PermissionError, match="Unsafe absolute paths"):
+        validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA, disabled_cfg)
 
 
 # ---------- Custom mount path tests ----------
@@ -666,12 +695,12 @@ def _mock_custom_mounts():
 
 def test_is_custom_mount_path_recognises_configured_mounts() -> None:
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
-        assert _is_custom_mount_path("/mnt/code-read") is True
-        assert _is_custom_mount_path("/mnt/code-read/src/main.py") is True
-        assert _is_custom_mount_path("/mnt/data") is True
-        assert _is_custom_mount_path("/mnt/data/file.txt") is True
-        assert _is_custom_mount_path("/mnt/code-read-extra/foo") is False
-        assert _is_custom_mount_path("/mnt/other") is False
+        assert _is_custom_mount_path("/mnt/code-read", _DEFAULT_APP_CONFIG) is True
+        assert _is_custom_mount_path("/mnt/code-read/src/main.py", _DEFAULT_APP_CONFIG) is True
+        assert _is_custom_mount_path("/mnt/data", _DEFAULT_APP_CONFIG) is True
+        assert _is_custom_mount_path("/mnt/data/file.txt", _DEFAULT_APP_CONFIG) is True
+        assert _is_custom_mount_path("/mnt/code-read-extra/foo", _DEFAULT_APP_CONFIG) is False
+        assert _is_custom_mount_path("/mnt/other", _DEFAULT_APP_CONFIG) is False
 
 
 def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
@@ -682,7 +711,7 @@ def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
         VolumeMountConfig(host_path="/home/user/code", container_path="/mnt/code", read_only=True),
     ]
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
-        mount = _get_custom_mount_for_path("/mnt/code/file.py")
+        mount = _get_custom_mount_for_path("/mnt/code/file.py", _DEFAULT_APP_CONFIG)
         assert mount is not None
         assert mount.container_path == "/mnt/code"
 
@@ -690,90 +719,72 @@
 def test_validate_local_tool_path_allows_custom_mount_read() -> None:
     """read_file / ls should be able to access custom mount paths."""
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
-        validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=True)
-        validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=True)
+        validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
+        validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
 
 
 def test_validate_local_tool_path_blocks_read_only_mount_write() -> None:
     """write_file / str_replace must NOT write to read-only custom mounts."""
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
         with pytest.raises(PermissionError, match="Write access to read-only mount is not allowed"):
-            validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=False)
+            validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
 
 
 def test_validate_local_tool_path_allows_writable_mount_write() -> None:
     """write_file / str_replace should succeed on writable custom mounts."""
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
-        validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=False)
+        validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=False)
 
 
 def test_validate_local_tool_path_blocks_traversal_in_custom_mount() -> None:
     """Path traversal via .. in custom mount paths must be rejected."""
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
         with pytest.raises(PermissionError, match="path traversal"):
-            validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, read_only=True)
+            validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG, read_only=True)
 
 
 def test_validate_local_bash_command_paths_allows_custom_mount() -> None:
     """bash commands referencing custom mount paths should be allowed."""
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
-        validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA)
-        validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA)
+        validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA, _DEFAULT_APP_CONFIG)
+        validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_blocks_traversal_in_custom_mount() -> None:
     """Bash commands with traversal in custom mount paths should be blocked."""
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
         with pytest.raises(PermissionError, match="path traversal"):
-            validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA)
+            validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
 def test_validate_local_bash_command_paths_still_blocks_non_mount_paths() -> None:
     """Paths not matching any custom mount should still be blocked."""
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
         with pytest.raises(PermissionError, match="Unsafe absolute paths"):
-            validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
+            validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA, _DEFAULT_APP_CONFIG)
 
 
-def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
-    """_get_custom_mounts should cache after first successful load."""
-    # Clear any existing cache
-    if hasattr(_get_custom_mounts, "_cached"):
-        monkeypatch.delattr(_get_custom_mounts, "_cached")
-
-    # Use real directories so host_path.exists() filtering passes
+def test_get_custom_mounts_reads_from_app_config(tmp_path) -> None:
+    """_get_custom_mounts should read directly from the supplied AppConfig."""
     dir_a = tmp_path / "code-read"
     dir_a.mkdir()
     dir_b = tmp_path / "data"
     dir_b.mkdir()
 
-    from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
+    from deerflow.config.sandbox_config import VolumeMountConfig
 
     mounts = [
         VolumeMountConfig(host_path=str(dir_a), container_path="/mnt/code-read", read_only=True),
         VolumeMountConfig(host_path=str(dir_b), container_path="/mnt/data", read_only=False),
     ]
-    mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
-    mock_config = SimpleNamespace(sandbox=mock_sandbox)
-
-    with patch.object(AppConfig, "current", return_value=mock_config):
-        result = _get_custom_mounts()
-        assert len(result) == 2
-
-    # After caching, should return cached value even without mock
-    assert hasattr(_get_custom_mounts, "_cached")
-    assert len(_get_custom_mounts()) == 2
-
-    # Cleanup
-    monkeypatch.delattr(_get_custom_mounts, "_cached")
+    cfg = _make_app_config(mounts=mounts)
+    result = _get_custom_mounts(cfg)
+    assert len(result) == 2
 
 
-def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path) -> None:
+def test_get_custom_mounts_filters_nonexistent_host_path(tmp_path) -> None:
    """_get_custom_mounts should only return mounts whose host_path exists."""
-    if hasattr(_get_custom_mounts, "_cached"):
-        monkeypatch.delattr(_get_custom_mounts, "_cached")
-
-    from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
+    from deerflow.config.sandbox_config import VolumeMountConfig
 
     existing_dir = tmp_path / "existing"
     existing_dir.mkdir()
@@ -782,22 +793,16 @@ def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path)
         VolumeMountConfig(host_path=str(existing_dir), container_path="/mnt/existing", read_only=True),
         VolumeMountConfig(host_path="/nonexistent/path/12345", container_path="/mnt/ghost", read_only=False),
     ]
-    mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
-    mock_config = SimpleNamespace(sandbox=mock_sandbox)
-
-    with patch.object(AppConfig, "current", return_value=mock_config):
-        result = _get_custom_mounts()
-        assert len(result) == 1
-        assert result[0].container_path == "/mnt/existing"
-
-    # Cleanup
-    monkeypatch.delattr(_get_custom_mounts, "_cached")
+    cfg = _make_app_config(mounts=mounts)
+    result = _get_custom_mounts(cfg)
+    assert len(result) == 1
+    assert result[0].container_path == "/mnt/existing"
 
 
 def test_get_custom_mount_for_path_boundary_no_false_prefix_match() -> None:
     """_get_custom_mount_for_path must not match /mnt/code-read-extra for /mnt/code-read."""
     with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
-        mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo")
+        mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo", _DEFAULT_APP_CONFIG)
         assert mount is None
 
 
@@ -828,8 +833,8 @@ def test_str_replace_parallel_updates_should_preserve_both_edits(monkeypatch) ->
 
     sandbox = SharedSandbox()
     runtimes = [
-        SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
-        SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
+        SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
+        SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
     ]
     failures: list[BaseException] = []
 
@@ -904,14 +909,14 @@ def test_str_replace_parallel_updates_in_isolated_sandboxes_should_not_share_pat
         "sandbox-b": IsolatedSandbox("sandbox-b", shared_state),
     }
     runtimes = [
-        SimpleNamespace(state={}, context={"thread_id": "thread-1", "sandbox_key": "sandbox-a"}, config={}),
-        SimpleNamespace(state={}, context={"thread_id": "thread-2", "sandbox_key": "sandbox-b"}, config={}),
+        SimpleNamespace(state={}, context=_make_ctx("thread-1", sandbox_key="sandbox-a"), config={}),
+        SimpleNamespace(state={}, context=_make_ctx("thread-2", sandbox_key="sandbox-b"), config={}),
     ]
     failures: list[BaseException] = []
 
     monkeypatch.setattr(
         "deerflow.sandbox.tools.ensure_sandbox_initialized",
-        lambda runtime: sandboxes[runtime.context["sandbox_key"]],
+        lambda runtime: sandboxes[runtime.context.sandbox_key],
     )
     monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
     monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
@@ -971,8 +976,8 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey
 
     sandbox = SharedSandbox()
     runtimes = [
-        SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
-        SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}),
+        SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
+        SimpleNamespace(state={}, context=_make_ctx("thread-1"), config={}),
     ]
     failures: list[BaseException] = []
 
@@ -12,7 +12,7 @@ async def test_scan_skill_content_blocks_when_model_unavailable(monkeypatch):
     monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
     monkeypatch.setattr("deerflow.skills.security_scanner.create_chat_model", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
 
-    result = await scan_skill_content("---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
+    result = await scan_skill_content(config, "---\nname: demo-skill\ndescription: demo\n---\n", executable=False)
 
     assert result.decision == "block"
     assert "manual review required" in result.reason
 
@@ -11,9 +11,9 @@ from deerflow.config.sandbox_config import SandboxConfig
 skill_manage_module = importlib.import_module("deerflow.tools.skill_manage_tool")
 
 
-def _make_context(thread_id: str) -> DeerFlowContext:
+def _make_context(thread_id: str, app_config: object | None = None) -> DeerFlowContext:
     return DeerFlowContext(
-        app_config=AppConfig(sandbox=SandboxConfig(use="test")),
+        app_config=app_config if app_config is not None else AppConfig(sandbox=SandboxConfig(use="test")),
         thread_id=thread_id,
     )
 
@@ -37,13 +37,13 @@ def test_skill_manage_create_and_patch(monkeypatch, tmp_path):
     monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
     refresh_calls = []
 
-    async def _refresh():
+    async def _refresh(*a, **k):
         refresh_calls.append("refresh")
 
     monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
     monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
 
-    runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
+    runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
 
     result = anyio.run(
         skill_manage_module.skill_manage_tool.coroutine,
@@ -78,13 +78,13 @@ def test_skill_manage_patch_replaces_single_occurrence_by_default(monkeypatch, t
     )
     monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
 
-    async def _refresh():
+    async def _refresh(*a, **k):
         return None
 
     monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
     monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
 
-    runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
+    runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
     content = _skill_content("demo-skill", "Demo skill") + "\nRepeated: Demo skill\n"
 
     anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", content)
@@ -116,7 +116,7 @@ def test_skill_manage_rejects_public_skill_patch(monkeypatch, tmp_path):
     )
     monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
 
-    runtime = SimpleNamespace(context=_make_context(""), config={"configurable": {}})
+    runtime = SimpleNamespace(context=_make_context("", config), config={"configurable": {}})
 
     with pytest.raises(ValueError, match="built-in skill"):
         anyio.run(
@@ -140,13 +140,13 @@ def test_skill_manage_sync_wrapper_supported(monkeypatch, tmp_path):
     monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
     refresh_calls = []
 
-    async def _refresh():
+    async def _refresh(*a, **k):
         refresh_calls.append("refresh")
 
     monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
     monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
 
-    runtime = SimpleNamespace(context=_make_context("thread-sync"), config={"configurable": {"thread_id": "thread-sync"}})
+    runtime = SimpleNamespace(context=_make_context("thread-sync", config), config={"configurable": {"thread_id": "thread-sync"}})
     result = skill_manage_module.skill_manage_tool.func(
         runtime=runtime,
         action="create",
@@ -166,13 +166,13 @@ def test_skill_manage_rejects_support_path_traversal(monkeypatch, tmp_path):
     )
     monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
 
-    async def _refresh():
+    async def _refresh(*a, **k):
         return None
 
     monkeypatch.setattr(skill_manage_module, "refresh_skills_system_prompt_cache_async", _refresh)
     monkeypatch.setattr(skill_manage_module, "scan_skill_content", lambda *args, **kwargs: _async_result("allow", "ok"))
 
-    runtime = SimpleNamespace(context=_make_context("thread-1"), config={"configurable": {"thread_id": "thread-1"}})
+    runtime = SimpleNamespace(context=_make_context("thread-1", config), config={"configurable": {"thread_id": "thread-1"}})
     anyio.run(skill_manage_module.skill_manage_tool.coroutine, runtime, "create", "demo-skill", _skill_content("demo-skill"))
 
     with pytest.raises(ValueError, match="parent-directory traversal|selected support directory"):
 
@@ -50,12 +50,13 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
     monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
     refresh_calls = []
 
-    async def _refresh():
+    async def _refresh(*a, **k):
         refresh_calls.append("refresh")
 
     monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
 
     app = FastAPI()
+    app.state.config = config
     app.include_router(skills_router.router)
 
     with TestClient(app) as client:
@@ -96,12 +97,12 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
         skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
     )
     monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: config))
-    get_skill_history_file("demo-skill").write_text(
+    get_skill_history_file("demo-skill", config).write_text(
         '{"action":"human_edit","prev_content":' + json.dumps(original_content) + ',"new_content":' + json.dumps(edited_content) + "}\n",
         encoding="utf-8",
     )
 
-    async def _refresh():
+    async def _refresh(*a, **k):
         return None
 
     monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
@@ -114,6 +115,7 @@ def test_custom_skill_rollback_blocked_by_scanner(monkeypatch, tmp_path):
     monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", _scan)
 
     app = FastAPI()
+    app.state.config = config
     app.include_router(skills_router.router)
 
     with TestClient(app) as client:
@@ -140,12 +142,13 @@ def test_custom_skill_delete_preserves_history_and_allows_restore(monkeypatch, t
     monkeypatch.setattr("app.gateway.routers.skills.scan_skill_content", lambda *args, **kwargs: _async_scan("allow", "ok"))
     refresh_calls = []
 
-    async def _refresh():
+    async def _refresh(*a, **k):
         refresh_calls.append("refresh")
 
     monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
 
     app = FastAPI()
+    app.state.config = config
     app.include_router(skills_router.router)
 
     with TestClient(app) as client:
@@ -169,13 +172,13 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
     enabled_state = {"value": True}
     refresh_calls = []
 
-    def _load_skills(*, enabled_only: bool):
+    def _load_skills(*a, enabled_only: bool = False, **k):
         skill = _make_skill("demo-skill", enabled=enabled_state["value"])
         if enabled_only and not skill.enabled:
             return []
         return [skill]
 
-    async def _refresh():
+    async def _refresh(*a, **k):
         refresh_calls.append("refresh")
         enabled_state["value"] = False
 
@@ -183,7 +186,6 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
 
     monkeypatch.setattr("app.gateway.routers.skills.load_skills", _load_skills)
-    monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: _app_cfg))
     monkeypatch.setattr(AppConfig, "init", staticmethod(lambda _cfg: None))
     monkeypatch.setattr(AppConfig, "from_file", staticmethod(lambda: _app_cfg))
     monkeypatch.setattr(skills_router.ExtensionsConfig, "resolve_config_path", staticmethod(lambda: config_path))
     monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
 
@@ -27,7 +27,7 @@ def test_load_skills_discovers_nested_skills_and_sets_container_paths(tmp_path:
     _write_skill(skills_root / "public" / "parent" / "child-skill", "child-skill", "Child skill")
     _write_skill(skills_root / "custom" / "team" / "helper", "team-helper", "Team helper")
 
-    skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
+    skills = load_skills(skills_path=skills_root, enabled_only=False)
     by_name = {skill.name: skill for skill in skills}
 
     assert {"root-skill", "child-skill", "team-helper"} <= set(by_name)
@@ -57,7 +57,7 @@ def test_load_skills_skips_hidden_directories(tmp_path: Path):
         "Hidden skill",
     )
 
-    skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
+    skills = load_skills(skills_path=skills_root, enabled_only=False)
     names = {skill.name for skill in skills}
 
     assert "ok-skill" in names
@@ -69,7 +69,7 @@ def test_load_skills_prefers_custom_over_public_with_same_name(tmp_path: Path):
     _write_skill(skills_root / "public" / "shared-skill", "shared-skill", "Public version")
     _write_skill(skills_root / "custom" / "shared-skill", "shared-skill", "Custom version")
 
-    skills = load_skills(skills_path=skills_root, use_config=False, enabled_only=False)
+    skills = load_skills(skills_path=skills_root, enabled_only=False)
     shared = next(skill for skill in skills if skill.name == "shared-skill")
 
     assert shared.category == "custom"
 
@@ -6,6 +6,7 @@ import re
 import anyio
 import pytest
 
+from deerflow.config.app_config import AppConfig
 from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge
 
 # ---------------------------------------------------------------------------
@@ -331,6 +332,9 @@ async def test_concurrent_tasks_end_sentinel():
 
 @pytest.mark.anyio
 async def test_make_stream_bridge_defaults():
-    """make_stream_bridge() with no config yields a MemoryStreamBridge."""
-    async with make_stream_bridge() as bridge:
+    """make_stream_bridge with a config lacking stream_bridge yields a MemoryStreamBridge."""
+    from deerflow.config.sandbox_config import SandboxConfig
+
+    config = AppConfig(sandbox=SandboxConfig(use="test"))
+    async with make_stream_bridge(config) as bridge:
         assert isinstance(bridge, MemoryStreamBridge)
 
@@ -21,6 +21,8 @@ from unittest.mock import MagicMock, patch

 import pytest

+_TEST_APP_CONFIG = MagicMock(name="TestAppConfig")
+
 # Module names that need to be mocked to break circular imports
 _MOCKED_MODULE_NAMES = [
     "deerflow.agents",
@@ -203,7 +205,7 @@ class TestAsyncExecutionPath:
             config=base_config,
             tools=[],
             thread_id="test-thread",
-            trace_id="test-trace",
+            trace_id="test-trace", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -232,7 +234,7 @@ class TestAsyncExecutionPath:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -259,7 +261,7 @@ class TestAsyncExecutionPath:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -285,7 +287,7 @@ class TestAsyncExecutionPath:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -306,7 +308,7 @@ class TestAsyncExecutionPath:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
        )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -327,7 +329,7 @@ class TestAsyncExecutionPath:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -348,7 +350,7 @@ class TestAsyncExecutionPath:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -384,7 +386,7 @@ class TestSyncExecutionPath:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -419,7 +421,7 @@ class TestSyncExecutionPath:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -456,7 +458,7 @@ class TestSyncExecutionPath:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -477,7 +479,7 @@ class TestSyncExecutionPath:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_aexecute") as mock_aexecute:
@@ -511,7 +513,7 @@ class TestSyncExecutionPath:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -565,7 +567,7 @@ class TestAsyncToolSupport:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -602,7 +604,7 @@ class TestAsyncToolSupport:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -648,7 +650,7 @@ class TestThreadSafety:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id=f"thread-{task_id}",
+            thread_id=f"thread-{task_id}", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -858,7 +860,7 @@ class TestCooperativeCancellation:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -898,7 +900,7 @@ class TestCooperativeCancellation:
         executor = SubagentExecutor(
             config=base_config,
             tools=[],
-            thread_id="test-thread",
+            thread_id="test-thread", app_config=_TEST_APP_CONFIG,
         )

         with patch.object(executor, "_create_agent", return_value=mock_agent):
@@ -977,7 +979,7 @@ class TestCooperativeCancellation:
             config=short_config,
             tools=[],
             thread_id="test-thread",
-            trace_id="test-trace",
+            trace_id="test-trace", app_config=_TEST_APP_CONFIG,
         )

         # Wrap _scheduler_pool.submit so we know when run_task finishes
@@ -1,29 +1,35 @@
 """Tests for subagent availability and prompt exposure under local bash hardening."""

 from deerflow.agents.lead_agent import prompt as prompt_module
+from deerflow.config.app_config import AppConfig
+from deerflow.config.sandbox_config import SandboxConfig
 from deerflow.subagents import registry as registry_module


-def test_get_available_subagent_names_hides_bash_when_host_bash_disabled(monkeypatch) -> None:
-    monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda: False)
+def _config() -> AppConfig:
+    return AppConfig(sandbox=SandboxConfig(use="test"))

-    names = registry_module.get_available_subagent_names()
+
+def test_get_available_subagent_names_hides_bash_when_host_bash_disabled(monkeypatch) -> None:
+    monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda *a, **k: False)
+
+    names = registry_module.get_available_subagent_names(_config())

     assert names == ["general-purpose"]


 def test_get_available_subagent_names_keeps_bash_when_allowed(monkeypatch) -> None:
-    monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda: True)
+    monkeypatch.setattr(registry_module, "is_host_bash_allowed", lambda *a, **k: True)

-    names = registry_module.get_available_subagent_names()
+    names = registry_module.get_available_subagent_names(_config())

     assert names == ["general-purpose", "bash"]


 def test_build_subagent_section_hides_bash_examples_when_unavailable(monkeypatch) -> None:
-    monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda: ["general-purpose"])
+    monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose"])

-    section = prompt_module._build_subagent_section(3)
+    section = prompt_module._build_subagent_section(3, _config())

     assert "Not available in the current sandbox configuration" in section
     assert 'bash("npm test")' not in section
@@ -32,9 +38,9 @@ def test_build_subagent_section_hides_bash_examples_when_unavailable(monkeypatch


 def test_build_subagent_section_includes_bash_when_available(monkeypatch) -> None:
-    monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda: ["general-purpose", "bash"])
+    monkeypatch.setattr(prompt_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose", "bash"])

-    section = prompt_module._build_subagent_section(3)
+    section = prompt_module._build_subagent_section(3, _config())

     assert "For command execution (git, build, test, deploy operations)" in section
     assert 'bash("npm test")' in section
@@ -93,11 +93,11 @@ class _DummyScheduledTask:


 def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None)
-    monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda: ["general-purpose"])
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: None)
+    monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *a, **k: ["general-purpose"])

     result = _run_task_tool(
-        runtime=None,
+        runtime=_make_runtime(),
         description="执行任务",
         prompt="do work",
         subagent_type="general-purpose",
@@ -108,7 +108,7 @@ def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):


 def test_task_tool_rejects_bash_subagent_when_host_bash_disabled(monkeypatch):
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: _make_subagent_config())
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: _make_subagent_config())
     monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", lambda *a, **k: False)

     result = _run_task_tool(
@@ -152,9 +152,9 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):

     monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
     monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "Skills Appendix")
-    monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "Skills Appendix")
+    monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda *a, **k: next(responses))
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
     # task_tool lazily imports from deerflow.tools at call time, so patch that module-level function.
@@ -177,7 +179,9 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
     assert captured["executor_kwargs"]["config"].max_turns == 7
     assert "Skills Appendix" in captured["executor_kwargs"]["config"].system_prompt

-    get_available_tools.assert_called_once_with(model_name="ark-model", subagent_enabled=False)
+    from unittest.mock import ANY
+
+    get_available_tools.assert_called_once_with(model_name="ark-model", subagent_enabled=False, app_config=ANY)

     event_types = [e["type"] for e in events]
     assert event_types == ["task_started", "task_running", "task_running", "task_completed"]
@@ -194,12 +196,12 @@ def test_task_tool_returns_failed_message(monkeypatch):
         "SubagentExecutor",
         type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
     )
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
     monkeypatch.setattr(
         task_tool_module,
         "get_background_task_result",
-        lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
+        lambda *a, **k: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
     )
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -228,12 +230,12 @@ def test_task_tool_returns_timed_out_message(monkeypatch):
         "SubagentExecutor",
         type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
     )
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
     monkeypatch.setattr(
         task_tool_module,
         "get_background_task_result",
-        lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
+        lambda *a, **k: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
     )
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -264,12 +266,12 @@ def test_task_tool_polling_safety_timeout(monkeypatch):
         "SubagentExecutor",
         type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
     )
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
     monkeypatch.setattr(
         task_tool_module,
         "get_background_task_result",
-        lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
+        lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
     )
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -300,12 +302,12 @@ def test_cleanup_called_on_completed(monkeypatch):
         "SubagentExecutor",
         type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
     )
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
     monkeypatch.setattr(
         task_tool_module,
         "get_background_task_result",
-        lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
+        lambda *a, **k: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
     )
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -340,12 +342,12 @@ def test_cleanup_called_on_failed(monkeypatch):
         "SubagentExecutor",
         type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
     )
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
     monkeypatch.setattr(
         task_tool_module,
         "get_background_task_result",
-        lambda _: _make_result(FakeSubagentStatus.FAILED, error="error"),
+        lambda *a, **k: _make_result(FakeSubagentStatus.FAILED, error="error"),
     )
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -380,12 +382,12 @@ def test_cleanup_called_on_timed_out(monkeypatch):
         "SubagentExecutor",
         type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
     )
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
     monkeypatch.setattr(
         task_tool_module,
         "get_background_task_result",
-        lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
+        lambda *a, **k: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
     )
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -427,12 +429,12 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
         "SubagentExecutor",
         type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
     )
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
     monkeypatch.setattr(
         task_tool_module,
         "get_background_task_result",
-        lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
+        lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
     )
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
@@ -480,8 +482,8 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch):
         "SubagentExecutor",
         type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
     )
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
     monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
@@ -531,12 +533,12 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
         "SubagentExecutor",
         type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
     )
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
     monkeypatch.setattr(
         task_tool_module,
         "get_background_task_result",
-        lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
+        lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
     )
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
@@ -586,12 +588,12 @@ def test_cancellation_calls_request_cancel(monkeypatch):
         "SubagentExecutor",
         type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
     )
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
     monkeypatch.setattr(
         task_tool_module,
         "get_background_task_result",
-        lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
+        lambda *a, **k: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
     )
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
@@ -644,9 +646,9 @@ def test_task_tool_returns_cancelled_message(monkeypatch):
         "SubagentExecutor",
         type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
     )
-    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
-    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
-    monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
+    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda *a, **k: config)
+    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda _cfg: "")
+    monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda *a, **k: next(responses))
     monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
     monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
     monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
@@ -77,11 +77,13 @@ class TestTitleMiddlewareCoreLogic:
                 AIMessage(content="好的,先确认需求"),
             ]
         }
-        result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=12)))
+        result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=12))))
         title = result["title"]

         assert title == "短标题"
-        title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
+        from unittest.mock import ANY
+
+        title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False, app_config=ANY)
         model.ainvoke.assert_awaited_once()

     def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
@@ -97,7 +99,7 @@ class TestTitleMiddlewareCoreLogic:
             ]
         }

-        result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
+        result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=20))))
         title = result["title"]

         assert title == "请帮我总结这段代码"
@@ -114,7 +116,7 @@ class TestTitleMiddlewareCoreLogic:
                 AIMessage(content="收到"),
             ]
         }
-        result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
+        result = asyncio.run(middleware._agenerate_title_result(state, AppConfig(sandbox=SandboxConfig(use="test"), title=_make_title_config(max_chars=20))))
        title = result["title"]

         # Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
@@ -270,27 +270,27 @@ class TestDeferredToolsPromptSection:
         from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section

         # tool_search.enabled defaults to False
-        section = get_deferred_tools_prompt_section()
+        section = get_deferred_tools_prompt_section(AppConfig.current())
         assert section == ""

     def test_empty_when_enabled_but_no_registry(self, monkeypatch):
         from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
         AppConfig.current().tool_search = ToolSearchConfig(enabled=True)
-        section = get_deferred_tools_prompt_section()
+        section = get_deferred_tools_prompt_section(AppConfig.current())
         assert section == ""

     def test_empty_when_enabled_but_empty_registry(self, monkeypatch):
         from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
         AppConfig.current().tool_search = ToolSearchConfig(enabled=True)
         set_deferred_registry(DeferredToolRegistry())
-        section = get_deferred_tools_prompt_section()
+        section = get_deferred_tools_prompt_section(AppConfig.current())
         assert section == ""

     def test_lists_tool_names(self, registry, monkeypatch):
         from deerflow.agents.lead_agent.prompt import get_deferred_tools_prompt_section
         AppConfig.current().tool_search = ToolSearchConfig(enabled=True)
         set_deferred_registry(registry)
-        section = get_deferred_tools_prompt_section()
+        section = get_deferred_tools_prompt_section(AppConfig.current())
         assert "<available-deferred-tools>" in section
         assert "</available-deferred-tools>" in section
         assert "github_create_issue" in section
 backend/uv.lock | 5129 (generated; file diff suppressed because it is too large)
 docs/CONFIG_DESIGN.zh.md | 301 (new file)
@@ -0,0 +1,301 @@
# DeerFlow Configuration System Design

> Implementation: [PR #2271](https://github.com/bytedance/deer-flow/pull/2271) · RFC [#1811](https://github.com/bytedance/deer-flow/issues/1811) · Archived spec: [config-refactor-design](./plans/2026-04-12-config-refactor-design.md)

## 1. Why Refactor

Before the refactor, `deerflow/config/` had three structural problems that together formed the classic "global mutable state + side-effect coupling" anti-pattern:

| Problem | Symptom |
|------|----------|
| Dual sources of truth | Every sub-config was both an `AppConfig` field **and** a module-level global (`_memory_config` / `_title_config` …). Consumers never knew which one to trust |
| Side-effect coupling | `AppConfig.from_file()` also mutated the globals of eight sub-modules (via `load_*_from_dict()`) |
| Incomplete isolation | The existing `ContextVar` only covered the `AppConfig` object itself; the eight sub-config globals leaked outside it |

From a type-theory perspective, config should be a **pure value object**: constructed once, immutable, copyable. The design above turned it into a "live object with global state", so test mutation, async boundaries, and hot reloads all contaminated one another.

## 2. Core Design Principle

> **Config is a value object, not live shared state.**
> Construct once, never mutate, no reload. New config = new object + rebuilt agent.

Every later decision follows from this one principle:

- All config models are `frozen=True` → illegal states are unrepresentable
- `from_file()` is a pure function → no side effects
- There are no "hot reload" semantics → changing configuration means "getting a new object", and the caller decides whether to swap in the process-wide global
## 3. The Four Layers

```mermaid
graph TB
    subgraph L1 ["Layer 1 Data models — frozen ADTs"]
        direction LR
        AppConfig["AppConfig frozen=True"]
        Sub["MemoryConfig TitleConfig SummarizationConfig ... all frozen"]
        AppConfig --> Sub
    end

    subgraph L2 ["Layer 2 Lifecycle — AppConfig.current"]
        direction LR
        Override["_override ContextVar per-context"]
        Global["_global ClassVar process-singleton"]
        Auto["auto-load from file with warning"]
        Override --> Global
        Global --> Auto
    end

    subgraph L3 ["Layer 3 Per-invocation context — DeerFlowContext"]
        direction LR
        Ctx["frozen dataclass app_config thread_id agent_name"]
        Resolve["resolve_context legacy bridge"]
        Ctx --> Resolve
    end

    subgraph L4 ["Layer 4 Access patterns — split by caller type"]
        direction LR
        Typed["typed middleware runtime.context.app_config.xxx"]
        Legacy["dict-legacy resolve_context runtime"]
        NonAgent["non-agent paths AppConfig.current"]
    end

    L1 --> L2
    L2 --> L3
    L3 --> L4

    classDef morandiBlue fill:#B5C4D1,stroke:#6A7A8C,color:#2E3A47
    classDef morandiGreen fill:#C4D1B5,stroke:#7A8C6A,color:#2E3A47
    classDef morandiPurple fill:#C9BED1,stroke:#7E6A8C,color:#2E3A47
    classDef morandiGrey fill:#CFCFCF,stroke:#7A7A7A,color:#2E3A47
    class L1 morandiBlue
    class L2 morandiGreen
    class L3 morandiPurple
    class L4 morandiGrey
```

### 3.1 Layer 1: Frozen ADTs

All config models are Pydantic `frozen=True`.

```python
class MemoryConfig(BaseModel):
    model_config = ConfigDict(frozen=True)
    enabled: bool = True
    storage_path: str | None = None
    ...

class AppConfig(BaseModel):
    model_config = ConfigDict(extra="allow", frozen=True)
    memory: MemoryConfig
    title: TitleConfig
    ...
```

Modifying config is copy-on-write:

```python
new_config = config.model_copy(update={"memory": new_memory_config})
```

**From a type-theory perspective**: this is a product type (a record); only the combination of all fields makes a complete `AppConfig`. Freezing makes `AppConfig` **referentially transparent**: the same input always yields the same object.
### 3.2 第 2 层:Lifecycle — `AppConfig.current()`
|
||||
|
||||
这层是整个设计最值得讲的一块。它不是一个简单的单 `ContextVar`,而是**三层 fallback**:
|
||||
|
||||
```python
|
||||
class AppConfig(BaseModel):
|
||||
...
|
||||
|
||||
# 进程级单例。GIL 下原子指针交换,无需锁
|
||||
_global: ClassVar[AppConfig | None] = None
|
||||
|
||||
# Per-context override,用于测试隔离和多 client
|
||||
_override: ClassVar[ContextVar[AppConfig]] = ContextVar("deerflow_app_config_override")
|
||||
|
||||
@classmethod
|
||||
def init(cls, config: AppConfig) -> None:
|
||||
"""设置进程全局。对所有后续 async task 可见"""
|
||||
cls._global = config
|
||||
|
||||
@classmethod
|
||||
def set_override(cls, config: AppConfig) -> Token[AppConfig]:
|
||||
"""Per-context 覆盖。返回 Token 给 reset_override()"""
|
||||
return cls._override.set(config)
|
||||
|
||||
@classmethod
|
||||
def reset_override(cls, token: Token[AppConfig]) -> None:
|
||||
cls._override.reset(token)
|
||||
|
||||
@classmethod
|
||||
def current(cls) -> AppConfig:
|
||||
"""优先级:per-context override > 进程全局 > 自动从文件加载(warning)"""
|
||||
try:
|
||||
return cls._override.get()
|
||||
except LookupError:
|
||||
pass
|
||||
if cls._global is not None:
|
||||
return cls._global
|
||||
logger.warning("AppConfig.current() called before init(); auto-loading from file. ...")
|
||||
config = cls.from_file()
|
||||
cls._global = config
|
||||
return config
|
||||
```

**Why three levels instead of one?**

| Reason | Explanation |
|------|------|
| A single ContextVar doesn't work | When the Gateway handles `PUT /mcp/config` and reloads config, the next request runs in a **brand-new async context**; the ContextVar value cannot reach it. Only a process-level variable can. |
| Keep the ContextVar override | Tests need per-test-scope config, and `Token`-based reset guarantees clean restoration. If a real multi-client scenario ever shows up, it covers that too. |
| Auto-load fallback | Some call sites historically never called `init()` (internal scripts, tests triggered at import time). The warning keeps the signal from being lost without hard-crashing. |

**Mapping to a Scala view**:

- `_global` = a process-level `var`: dirty, but there is no alternative
- `_override` = the reader-monad layer, in `Option[ContextVar]` form
- `current()` = the fallback chain `override.orElse(global).orElse(autoLoad)`, the same idea as `Option.orElse`

**Why is `_global` not protected by a lock?**

Because both the read and the write are a single pointer assignment (assignment of a class attribute), which is atomic under CPython's GIL. If this ever becomes a read-modify-write (say, "init to X if not yet initialized"), add a `threading.Lock` then. It is omitted now for one reason: it is not needed.
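The override-beats-global precedence and the `Token`-based restoration can be exercised with a self-contained stub (the `Cfg` class below is a stand-in for the real `AppConfig`; the auto-load-from-file branch is elided):

```python
from contextvars import ContextVar, Token

class Cfg:
    """Stub standing in for AppConfig; lifecycle logic only."""
    _global = None
    _override: ContextVar["Cfg"] = ContextVar("cfg_override")

    def __init__(self, name: str) -> None:
        self.name = name

    @classmethod
    def init(cls, cfg: "Cfg") -> None:
        cls._global = cfg                      # process-wide, visible to all tasks

    @classmethod
    def set_override(cls, cfg: "Cfg") -> Token:
        return cls._override.set(cfg)          # per-context, Token-based reset

    @classmethod
    def reset_override(cls, token: Token) -> None:
        cls._override.reset(token)

    @classmethod
    def current(cls) -> "Cfg":
        try:
            return cls._override.get()         # 1) per-context override wins
        except LookupError:
            pass
        if cls._global is not None:
            return cls._global                 # 2) then the process global
        raise RuntimeError("not initialized")  # (auto-load fallback elided)

Cfg.init(Cfg("global"))
assert Cfg.current().name == "global"

token = Cfg.set_override(Cfg("override"))
assert Cfg.current().name == "override"        # override shadows the global

Cfg.reset_override(token)
assert Cfg.current().name == "global"          # clean restoration via Token
```

In an async server, each task gets its own `ContextVar` view, so an override set in a test or one request never leaks into another.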

### 3.3 Layer 3: `DeerFlowContext` — per-invocation typed context

```python
# deerflow/config/deer_flow_context.py
@dataclass(frozen=True)
class DeerFlowContext:
    """Typed, immutable, per-invocation context injected via LangGraph Runtime"""
    app_config: AppConfig
    thread_id: str
    agent_name: str | None = None
```

Why not put `thread_id` into `AppConfig` as well?

- `AppConfig` is **configuration**: fixed at process startup, shared by all requests
- `thread_id` is **runtime identity that changes on every call**: it must be per-invocation

They belong to different categories; mixing them would couple static configuration with dynamic identity.

**Injection paths**:

```python
# Gateway worker (the main path)
deer_flow_context = DeerFlowContext(
    app_config=AppConfig.current(),
    thread_id=thread_id,
)
agent.astream(input, config=config, context=deer_flow_context)

# DeerFlowClient
AppConfig.init(AppConfig.from_file(config_path))
context = DeerFlowContext(app_config=AppConfig.current(), thread_id=thread_id)
agent.stream(input, config=config, context=context)
```

LangGraph's `Runtime` injects the `context=...` value into `Runtime[DeerFlowContext].context`, so what middleware receives is a typed `DeerFlowContext`.

**What does not go into context**: `sandbox_id`. It is **mutable runtime state** acquired mid-execution, so its proper home is `ThreadState.sandbox` (a state channel with a reducer), not context. All three `runtime.context["sandbox_id"] = ...` writes in the old `sandbox/tools.py` were removed.

### 3.4 Layer 4: access patterns split by caller type

Three caller types, three patterns:

| Caller type | Access pattern | Examples |
|-------------|----------|------|
| Typed middleware (signature declares `Runtime[DeerFlowContext]`) | read `runtime.context.app_config.xxx` directly, no wrapper | `memory_middleware` / `title_middleware` / `thread_data_middleware`, etc. |
| Tools that may receive a dict context | `resolve_context(runtime).xxx` | `sandbox/tools.py` (dict-legacy path) / `task_tool.py` (bash subagent gate) |
| Non-agent paths (Gateway routers, CLI, factories) | `AppConfig.current().xxx` | `app/gateway/routers/*` / `reset_admin.py` / `models/factory.py` |

**Key simplification** (commit `a934a822`): originally every middleware went through `resolve_context()`. Since the signature is already `Runtime[DeerFlowContext]`, the wrapper is redundant defense; reading `runtime.context.app_config.xxx` directly is enough. The `title_config=None` fallback in each of `title_middleware`'s helpers was deleted at the same time: **a required parameter gets no default**, so the type system forces callers to pass the right thing.

This matches two Scala / FP tenets:

- **Make illegal states unrepresentable** (`Option[TitleConfig]` becomes a required `TitleConfig`)
- **Let it crash** (a config parse failure is a real bug; surfacing it beats swallowing it into degraded behavior)

## 4. The three branches of `resolve_context()`

`resolve_context()` itself remains, handling three shapes of `runtime.context`:

```python
def resolve_context(runtime: Any) -> DeerFlowContext:
    ctx = getattr(runtime, "context", None)

    # 1. Typed path (Gateway, Client): return as-is
    if isinstance(ctx, DeerFlowContext):
        return ctx

    # 2. Dict-legacy path (old tests, third-party invoke): bridge
    if isinstance(ctx, dict):
        thread_id = ctx.get("thread_id", "")
        if not thread_id:
            logger.warning("...empty thread_id...")
        return DeerFlowContext(
            app_config=AppConfig.current(),
            thread_id=thread_id,
            agent_name=ctx.get("agent_name"),
        )

    # 3. No context at all: fall back to LangGraph configurable
    cfg = get_config().get("configurable", {})
    return DeerFlowContext(
        app_config=AppConfig.current(),
        thread_id=cfg.get("thread_id", ""),
        agent_name=cfg.get("agent_name"),
    )
```

An empty `thread_id` warns rather than hard-crashing. Warning is the right call here: a missing `thread_id` only affects file paths (things land in an empty-string directory), it does not bring the whole agent run down.

## 5. Gateway config hot-update flow

Historically the Gateway used `reload_*_config()` with mtime detection. Now it is:

```
write extensions_config.json → AppConfig.init(AppConfig.from_file()) → next request sees the new values
```

**Gone**: mtime detection, automatic refresh, the `reload_*()` functions.

The philosophy is simple: **structural changes (models, tools, the middleware chain) require rebuilding the agent; runtime changes (flags like `memory.enabled`) take effect automatically because the next invocation reads them via `AppConfig.current()`**. Config does not need "live object" semantics.
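The write-then-reinit flow can be sketched end to end with a stub config class and a temp file standing in for `extensions_config.json` (all names here are illustrative, not the real API):

```python
import json
import pathlib
import tempfile

class Cfg:
    """Stub AppConfig: from_file() is a pure load, init() swaps the process global."""
    _global = None

    def __init__(self, data: dict) -> None:
        self.data = data

    @classmethod
    def from_file(cls, path: pathlib.Path) -> "Cfg":
        return cls(json.loads(path.read_text()))   # pure: no ambient state touched

    @classmethod
    def init(cls, cfg: "Cfg") -> None:
        cls._global = cfg                          # atomic pointer swap under the GIL

    @classmethod
    def current(cls) -> "Cfg":
        assert cls._global is not None
        return cls._global

path = pathlib.Path(tempfile.mkdtemp()) / "extensions_config.json"
path.write_text(json.dumps({"memory": {"enabled": True}}))
Cfg.init(Cfg.from_file(path))
assert Cfg.current().data["memory"]["enabled"] is True

# Hot update: write the new file, then re-init. No mtime polling, no reload_*().
path.write_text(json.dumps({"memory": {"enabled": False}}))
Cfg.init(Cfg.from_file(path))
assert Cfg.current().data["memory"]["enabled"] is False  # next request sees new value
```

In-flight requests keep whatever config object they already hold; only requests that start after `init()` observe the new value.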

## 6. Divergence from the original plan

Three key divergences (details in [archived spec §7](./plans/2026-04-12-config-refactor-design.md#7-divergence-from-original-plan)):

| Divergence | Original plan | Shipped | Reason |
|------|--------|---------|------|
| Lifecycle storage | single ContextVar, hard crash via `ConfigNotInitializedError` | 3-level fallback, auto-load + warning | ContextVar values do not cross async boundaries |
| Module location | new `context.py` | lifecycle lives as classmethods on `AppConfig` itself | one less layer of module coupling |
| Middleware access | `resolve_context()` everywhere | typed middleware reads `runtime.context.xxx` directly | with tightened types, defensive wrapping is noise |

## 7. Observations from a Scala / Actor perspective

- **`AppConfig` is a case class / ADT.** `frozen=True` is the equivalent of a Scala final case class: once constructed, it never changes. Changes go through `model_copy(update=…)`, the counterpart of Scala's `copy(…)`.
- **`DeerFlowContext` is a typed reader.** Middleware receiving `Runtime[DeerFlowContext]` is essentially `Kleisli[DeerFlowContext, State, Result]`: dependency injection, typed. Far stronger than `RunnableConfig.configurable: dict[str, Any]`.
- **`resolve_context()` is an adapter layer.** It exists because there are three differently shaped upstream inputs; to a pure-FP eye it is a total function `X => DeerFlowContext` that pattern-matches three cases and funnels the world back onto the typed path.
- **Let-it-crash in practice**: commit `a934a822` removed the `try/except resolve_context(...)` in middleware and the defensive `TitleConfig | None` fallback. If config parsing fails, let it throw instead of swallowing it into a "degraded mode": actor supervision will handle it, and swallowing errors hides bugs.
- **The process-global compromise**: `_global: ClassVar` is the one place this design betrays pure values. But in a Python async + HTTP server setting there is no other way to hand the "new config" to every task across requests. Admit the compromise, confine its scope (a single variable in the lifecycle layer), keep everything around it immutable: that is a reasonable compromise in the engineering sense.

## 8. Cheat sheet

Want to access config? Go by where your code lives:

| What I'm writing | What to use |
|------------|--------|
| Typed middleware (signature `Runtime[DeerFlowContext]`) | `runtime.context.app_config.xxx` |
| Typed tool (`ToolRuntime[DeerFlowContext]`) | `runtime.context.xxx` |
| A tool that legacy callers may invoke with a dict context | `resolve_context(runtime).xxx` |
| Gateway router, CLI, factory, test helper | `AppConfig.current().xxx` |
| Startup initialization | `AppConfig.init(AppConfig.from_file(path))` |
| Temporarily changing config in a test | `token = AppConfig.set_override(cfg)` / `AppConfig.reset_override(token)` |
| After the Gateway writes a new `extensions_config.json` | `AppConfig.init(AppConfig.from_file())`, then rebuild the agent (if structure changed) |

Don't use:

- ~~`get_memory_config()` / `get_title_config()` and the other old getters~~ (deleted)
- ~~`reload_app_config()` / `reset_app_config()`~~ (deleted)
- ~~module-level globals like `_memory_config`~~ (deleted)
- ~~`runtime.context["sandbox_id"] = ...`~~ (use `runtime.state["sandbox"]`)
- ~~defensive `try/except resolve_context(...)`~~ (let it crash)