mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
refactor(config): migrate gateway routers and channels to Depends(get_config)
Phase 2 Task P2-2 (Category G): replace AppConfig.current() with the typed Depends(get_config) FastAPI dependency in every gateway router. - routers/models.py: list_models / get_model take config via Depends - routers/mcp.py: get_mcp_configuration / update_mcp_configuration via Depends; reload path now swaps app.state.config alongside AppConfig.init() so both the new primitive and legacy current() callers see the fresh config - routers/memory.py: get_memory_config_endpoint / get_memory_status via Depends - routers/skills.py: update_skill via Depends; reload swaps app.state.config - deps.py: get_run_context and langgraph_runtime read from app.state.config instead of calling AppConfig.current() - auth/reset_admin.py: CLI constructs AppConfig.from_file() explicitly at the top (it is a standalone entry point, not a request handler) - channels/service.py: from_app_config accepts optional AppConfig parameter; legacy fallback to AppConfig.current() preserved until P2-10 Test fix: test_update_skill_refreshes_prompt_cache_before_return now sets app.state.config on the test FastAPI instance so Depends(get_config) resolves. All 2379+ tests pass (one pre-existing flaky test_client_e2e unrelated).
This commit is contained in:
parent
c45157e067
commit
70323e052a
@ -4,13 +4,16 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
||||
from app.channels.message_bus import MessageBus
|
||||
from app.channels.store import ChannelStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Channel name → import path for lazy loading
|
||||
@ -64,14 +67,20 @@ class ChannelService:
|
||||
self._running = False
|
||||
|
||||
@classmethod
|
||||
def from_app_config(cls) -> ChannelService:
|
||||
"""Create a ChannelService from the application config."""
|
||||
from deerflow.config.app_config import AppConfig
|
||||
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
|
||||
"""Create a ChannelService from the application config.
|
||||
|
||||
config = AppConfig.current()
|
||||
Pass ``app_config`` explicitly when available (e.g. from Gateway
|
||||
startup). Falls back to ``AppConfig.current()`` for legacy callers;
|
||||
that fallback is removed in Phase 2 task P2-10.
|
||||
"""
|
||||
if app_config is None:
|
||||
from deerflow.config.app_config import AppConfig as _AppConfig
|
||||
|
||||
app_config = _AppConfig.current()
|
||||
channels_config = {}
|
||||
# extra fields are allowed by AppConfig (extra="allow")
|
||||
extra = config.model_extra or {}
|
||||
extra = app_config.model_extra or {}
|
||||
if "channels" in extra:
|
||||
channels_config = extra["channels"]
|
||||
return cls(channels_config=channels_config)
|
||||
|
||||
@ -32,7 +32,8 @@ async def _run(email: str | None) -> int:
|
||||
init_engine_from_config,
|
||||
)
|
||||
|
||||
config = AppConfig.current()
|
||||
# CLI entry: load config explicitly at the top, pass down through the closure.
|
||||
config = AppConfig.from_file()
|
||||
await init_engine_from_config(config.database)
|
||||
try:
|
||||
sf = get_session_factory()
|
||||
|
||||
@ -47,7 +47,6 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
async with langgraph_runtime(app):
|
||||
yield
|
||||
"""
|
||||
from deerflow.config import AppConfig
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||
from deerflow.runtime import make_store, make_stream_bridge
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
@ -58,7 +57,8 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
# Initialize persistence engine BEFORE checkpointer so that
|
||||
# auto-create-database logic runs first (postgres backend).
|
||||
config = AppConfig.current()
|
||||
# Use app.state.config which was populated earlier in lifespan().
|
||||
config = app.state.config
|
||||
await init_engine_from_config(config.database)
|
||||
|
||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
||||
@ -142,13 +142,12 @@ def get_run_context(request: Request) -> RunContext:
|
||||
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
|
||||
to :func:`run_agent`.
|
||||
"""
|
||||
from deerflow.config import AppConfig
|
||||
|
||||
config = get_config(request)
|
||||
return RunContext(
|
||||
checkpointer=get_checkpointer(request),
|
||||
store=get_store(request),
|
||||
event_store=get_run_event_store(request),
|
||||
run_events_config=getattr(AppConfig.current(), "run_events", None),
|
||||
run_events_config=getattr(config, "run_events", None),
|
||||
thread_store=get_thread_store(request),
|
||||
)
|
||||
|
||||
|
||||
@ -3,9 +3,10 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
|
||||
@ -70,7 +71,7 @@ class McpConfigUpdateRequest(BaseModel):
|
||||
summary="Get MCP Configuration",
|
||||
description="Retrieve the current Model Context Protocol (MCP) server configurations.",
|
||||
)
|
||||
async def get_mcp_configuration() -> McpConfigResponse:
|
||||
async def get_mcp_configuration(config: AppConfig = Depends(get_config)) -> McpConfigResponse:
|
||||
"""Get the current MCP configuration.
|
||||
|
||||
Returns:
|
||||
@ -91,7 +92,7 @@ async def get_mcp_configuration() -> McpConfigResponse:
|
||||
}
|
||||
```
|
||||
"""
|
||||
ext = AppConfig.current().extensions
|
||||
ext = config.extensions
|
||||
|
||||
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in ext.mcp_servers.items()})
|
||||
|
||||
@ -102,7 +103,11 @@ async def get_mcp_configuration() -> McpConfigResponse:
|
||||
summary="Update MCP Configuration",
|
||||
description="Update Model Context Protocol (MCP) server configurations and save to file.",
|
||||
)
|
||||
async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfigResponse:
|
||||
async def update_mcp_configuration(
|
||||
request: McpConfigUpdateRequest,
|
||||
http_request: Request,
|
||||
config: AppConfig = Depends(get_config),
|
||||
) -> McpConfigResponse:
|
||||
"""Update the MCP configuration.
|
||||
|
||||
This will:
|
||||
@ -143,8 +148,8 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
|
||||
config_path = Path.cwd().parent / "extensions_config.json"
|
||||
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
|
||||
|
||||
# Load current config to preserve skills configuration
|
||||
current_ext = AppConfig.current().extensions
|
||||
# Use injected config to preserve skills configuration
|
||||
current_ext = config.extensions
|
||||
|
||||
# Convert request to dict format for JSON serialization
|
||||
config_data = {
|
||||
@ -161,10 +166,13 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
|
||||
# NOTE: No need to reload/reset cache here - LangGraph Server (separate process)
|
||||
# will detect config file changes via mtime and reinitialize MCP tools automatically
|
||||
|
||||
# Reload the configuration and update the global cache
|
||||
AppConfig.init(AppConfig.from_file())
|
||||
reloaded_ext = AppConfig.current().extensions
|
||||
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_ext.mcp_servers.items()})
|
||||
# Reload the configuration. Swap app.state.config (new primitive) and
|
||||
# AppConfig.init() (legacy) so both Depends(get_config) and the not-yet-migrated
|
||||
# AppConfig.current() callers see the new config.
|
||||
reloaded = AppConfig.from_file()
|
||||
http_request.app.state.config = reloaded
|
||||
AppConfig.init(reloaded)
|
||||
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded.extensions.mcp_servers.items()})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
"""Memory API router for retrieving and managing global memory data."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from deerflow.agents.memory.updater import (
|
||||
clear_memory_data,
|
||||
create_memory_fact,
|
||||
@ -295,7 +296,9 @@ async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
||||
summary="Get Memory Configuration",
|
||||
description="Retrieve the current memory system configuration.",
|
||||
)
|
||||
async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
async def get_memory_config_endpoint(
|
||||
app_config: AppConfig = Depends(get_config),
|
||||
) -> MemoryConfigResponse:
|
||||
"""Get the memory system configuration.
|
||||
|
||||
Returns:
|
||||
@ -314,7 +317,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
}
|
||||
```
|
||||
"""
|
||||
config = AppConfig.current().memory
|
||||
config = app_config.memory
|
||||
return MemoryConfigResponse(
|
||||
enabled=config.enabled,
|
||||
storage_path=config.storage_path,
|
||||
@ -333,13 +336,15 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
summary="Get Memory Status",
|
||||
description="Retrieve both memory configuration and current data in a single request.",
|
||||
)
|
||||
async def get_memory_status() -> MemoryStatusResponse:
|
||||
async def get_memory_status(
|
||||
app_config: AppConfig = Depends(get_config),
|
||||
) -> MemoryStatusResponse:
|
||||
"""Get the memory system status including configuration and data.
|
||||
|
||||
Returns:
|
||||
Combined memory configuration and current data.
|
||||
"""
|
||||
config = AppConfig.current().memory
|
||||
config = app_config.memory
|
||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||
|
||||
return MemoryStatusResponse(
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["models"])
|
||||
@ -29,7 +30,7 @@ class ModelsListResponse(BaseModel):
|
||||
summary="List All Models",
|
||||
description="Retrieve a list of all available AI models configured in the system.",
|
||||
)
|
||||
async def list_models() -> ModelsListResponse:
|
||||
async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResponse:
|
||||
"""List all available models from configuration.
|
||||
|
||||
Returns model information suitable for frontend display,
|
||||
@ -58,7 +59,6 @@ async def list_models() -> ModelsListResponse:
|
||||
}
|
||||
```
|
||||
"""
|
||||
config = AppConfig.current()
|
||||
models = [
|
||||
ModelResponse(
|
||||
name=model.name,
|
||||
@ -79,7 +79,7 @@ async def list_models() -> ModelsListResponse:
|
||||
summary="Get Model Details",
|
||||
description="Retrieve detailed information about a specific AI model by its name.",
|
||||
)
|
||||
async def get_model(model_name: str) -> ModelResponse:
|
||||
async def get_model(model_name: str, config: AppConfig = Depends(get_config)) -> ModelResponse:
|
||||
"""Get a specific model by name.
|
||||
|
||||
Args:
|
||||
@ -101,7 +101,6 @@ async def get_model(model_name: str) -> ModelResponse:
|
||||
}
|
||||
```
|
||||
"""
|
||||
config = AppConfig.current()
|
||||
model = config.get_model_config(model_name)
|
||||
if model is None:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
@ -3,9 +3,10 @@ import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from app.gateway.path_utils import resolve_thread_virtual_path
|
||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||
from deerflow.config.app_config import AppConfig
|
||||
@ -313,7 +314,12 @@ async def get_skill(skill_name: str) -> SkillResponse:
|
||||
summary="Update Skill",
|
||||
description="Update a skill's enabled status by modifying the extensions_config.json file.",
|
||||
)
|
||||
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
|
||||
async def update_skill(
|
||||
skill_name: str,
|
||||
request: SkillUpdateRequest,
|
||||
http_request: Request,
|
||||
app_config: AppConfig = Depends(get_config),
|
||||
) -> SkillResponse:
|
||||
try:
|
||||
skills = load_skills(enabled_only=False)
|
||||
skill = next((s for s in skills if s.name == skill_name), None)
|
||||
@ -326,7 +332,7 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
|
||||
config_path = Path.cwd().parent / "extensions_config.json"
|
||||
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
|
||||
|
||||
ext = AppConfig.current().extensions
|
||||
ext = app_config.extensions
|
||||
ext.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
|
||||
|
||||
config_data = {
|
||||
@ -338,7 +344,11 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
|
||||
json.dump(config_data, f, indent=2)
|
||||
|
||||
logger.info(f"Skills configuration updated and saved to: {config_path}")
|
||||
AppConfig.init(AppConfig.from_file())
|
||||
# Swap both app.state.config and AppConfig._global so Depends(get_config)
|
||||
# and legacy AppConfig.current() callers see the new config.
|
||||
reloaded = AppConfig.from_file()
|
||||
http_request.app.state.config = reloaded
|
||||
AppConfig.init(reloaded)
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
|
||||
skills = load_skills(enabled_only=False)
|
||||
|
||||
@ -189,6 +189,7 @@ def test_update_skill_refreshes_prompt_cache_before_return(monkeypatch, tmp_path
|
||||
monkeypatch.setattr("app.gateway.routers.skills.refresh_skills_system_prompt_cache_async", _refresh)
|
||||
|
||||
app = FastAPI()
|
||||
app.state.config = _app_cfg
|
||||
app.include_router(skills_router.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user