mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
refactor(config): thread AppConfig through lead agent construction
Phase 2 Task P2-4 (Category B): make_lead_agent, _build_middlewares, _resolve_model_name, _create_summarization_middleware, and build_lead_runtime_middlewares all accept AppConfig as an explicit parameter instead of looking up the process-global. Each entry point has an optional app_config parameter that falls back to AppConfig.current() for backward compatibility with the LangGraph Server registration path (which can only pass RunnableConfig). New callers — DeerFlowClient, Gateway worker — pass self._app_config / app.state.config explicitly. prompt.py's apply_prompt_template helpers still read via AppConfig.current() in defensive try/except; they are tightened in P2-10 where the fallback is removed wholesale. Test updates: _resolve_model_name and _create_summarization_middleware signature changes propagate to unit tests. _build_middlewares mocks gain the leading app_config parameter. All 233 targeted tests pass (lead_agent_prompt, lead_agent_model_resolution, lead_agent_skills, guardrail_middleware, client, client_e2e, multi_isolation).
This commit is contained in:
parent
f8738d1e3e
commit
23b424e7fc
@ -24,9 +24,8 @@ from deerflow.models import create_chat_model
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_model_name(requested_model_name: str | None = None) -> str:
|
||||
def _resolve_model_name(app_config: AppConfig, requested_model_name: str | None = None) -> str:
|
||||
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
|
||||
app_config = AppConfig.current()
|
||||
default_model_name = app_config.models[0].name if app_config.models else None
|
||||
if default_model_name is None:
|
||||
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
|
||||
@ -39,9 +38,9 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
|
||||
return default_model_name
|
||||
|
||||
|
||||
def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
||||
def _create_summarization_middleware(app_config: AppConfig) -> SummarizationMiddleware | None:
|
||||
"""Create and configure the summarization middleware from config."""
|
||||
config = AppConfig.current().summarization
|
||||
config = app_config.summarization
|
||||
|
||||
if not config.enabled:
|
||||
return None
|
||||
@ -208,10 +207,17 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
|
||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None):
|
||||
def _build_middlewares(
|
||||
app_config: AppConfig,
|
||||
config: RunnableConfig,
|
||||
model_name: str | None,
|
||||
agent_name: str | None = None,
|
||||
custom_middlewares: list[AgentMiddleware] | None = None,
|
||||
):
|
||||
"""Build middleware chain based on runtime configuration.
|
||||
|
||||
Args:
|
||||
app_config: Resolved application config.
|
||||
config: Runtime configuration containing configurable options like is_plan_mode.
|
||||
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
|
||||
custom_middlewares: Optional list of custom middlewares to inject into the chain.
|
||||
@ -219,10 +225,10 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
||||
Returns:
|
||||
List of middleware instances.
|
||||
"""
|
||||
middlewares = build_lead_runtime_middlewares(lazy_init=True)
|
||||
middlewares = build_lead_runtime_middlewares(app_config=app_config, lazy_init=True)
|
||||
|
||||
# Add summarization middleware if enabled
|
||||
summarization_middleware = _create_summarization_middleware()
|
||||
summarization_middleware = _create_summarization_middleware(app_config)
|
||||
if summarization_middleware is not None:
|
||||
middlewares.append(summarization_middleware)
|
||||
|
||||
@ -233,7 +239,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
||||
middlewares.append(todo_list_middleware)
|
||||
|
||||
# Add TokenUsageMiddleware when token_usage tracking is enabled
|
||||
if AppConfig.current().token_usage.enabled:
|
||||
if app_config.token_usage.enabled:
|
||||
middlewares.append(TokenUsageMiddleware())
|
||||
|
||||
# Add TitleMiddleware
|
||||
@ -244,7 +250,6 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
||||
|
||||
# Add ViewImageMiddleware only if the current model supports vision.
|
||||
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
|
||||
app_config = AppConfig.current()
|
||||
model_config = app_config.get_model_config(model_name) if model_name else None
|
||||
if model_config is not None and model_config.supports_vision:
|
||||
middlewares.append(ViewImageMiddleware())
|
||||
@ -273,11 +278,27 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
||||
return middlewares
|
||||
|
||||
|
||||
def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
|
||||
def make_lead_agent(
|
||||
config: RunnableConfig,
|
||||
app_config: AppConfig | None = None,
|
||||
) -> CompiledStateGraph:
|
||||
"""Build the lead agent from runtime config.
|
||||
|
||||
Args:
|
||||
config: LangGraph ``RunnableConfig`` carrying per-invocation options
|
||||
(``thinking_enabled``, ``model_name``, ``is_plan_mode``, etc.).
|
||||
app_config: Resolved application config. When omitted falls back to
|
||||
:meth:`AppConfig.current`, preserving backward compatibility with
|
||||
callers that do not thread config explicitly (LangGraph Server
|
||||
registration path). New callers should pass this parameter.
|
||||
"""
|
||||
# Lazy import to avoid circular dependency
|
||||
from deerflow.tools import get_available_tools
|
||||
from deerflow.tools.builtins import setup_agent
|
||||
|
||||
if app_config is None:
|
||||
app_config = AppConfig.current()
|
||||
|
||||
cfg = config.get("configurable", {})
|
||||
|
||||
thinking_enabled = cfg.get("thinking_enabled", True)
|
||||
@ -294,9 +315,8 @@ def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
|
||||
agent_model_name = agent_config.model if agent_config and agent_config.model else None
|
||||
|
||||
# Final model name resolution: request → agent config → global default, with fallback for unknown names
|
||||
model_name = _resolve_model_name(requested_model_name or agent_model_name)
|
||||
model_name = _resolve_model_name(app_config, requested_model_name or agent_model_name)
|
||||
|
||||
app_config = AppConfig.current()
|
||||
model_config = app_config.get_model_config(model_name)
|
||||
|
||||
if model_config is None:
|
||||
@ -336,7 +356,7 @@ def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
|
||||
middleware=_build_middlewares(config, model_name=model_name),
|
||||
middleware=_build_middlewares(app_config, config, model_name=model_name),
|
||||
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
|
||||
state_schema=ThreadState,
|
||||
context_schema=DeerFlowContext,
|
||||
@ -346,7 +366,7 @@ def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
|
||||
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
|
||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
|
||||
middleware=_build_middlewares(app_config, config, model_name=model_name, agent_name=agent_name),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
|
||||
),
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
"""Tool error handling middleware and shared runtime middleware builders."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
@ -11,6 +13,9 @@ from langgraph.errors import GraphBubbleUp
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
||||
@ -67,6 +72,7 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
|
||||
def _build_runtime_middlewares(
|
||||
*,
|
||||
app_config: "AppConfig | None" = None,
|
||||
include_uploads: bool,
|
||||
include_dangling_tool_call_patch: bool,
|
||||
lazy_init: bool = True,
|
||||
@ -75,6 +81,10 @@ def _build_runtime_middlewares(
|
||||
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from deerflow.sandbox.middleware import SandboxMiddleware
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
if app_config is None:
|
||||
app_config = AppConfig.current()
|
||||
|
||||
middlewares: list[AgentMiddleware] = [
|
||||
ThreadDataMiddleware(lazy_init=lazy_init),
|
||||
@ -94,9 +104,7 @@ def _build_runtime_middlewares(
|
||||
middlewares.append(LLMErrorHandlingMiddleware())
|
||||
|
||||
# Guardrail middleware (if configured)
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
guardrails_config = AppConfig.current().guardrails
|
||||
guardrails_config = app_config.guardrails
|
||||
if guardrails_config.enabled and guardrails_config.provider:
|
||||
import inspect
|
||||
|
||||
@ -125,9 +133,10 @@ def _build_runtime_middlewares(
|
||||
return middlewares
|
||||
|
||||
|
||||
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||
def build_lead_runtime_middlewares(*, app_config: "AppConfig | None" = None, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
|
||||
return _build_runtime_middlewares(
|
||||
app_config=app_config,
|
||||
include_uploads=True,
|
||||
include_dangling_tool_call_patch=True,
|
||||
lazy_init=lazy_init,
|
||||
|
||||
@ -256,7 +256,7 @@ class DeerFlowClient:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
|
||||
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
|
||||
"middleware": _build_middlewares(self._app_config, config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
|
||||
"system_prompt": apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled,
|
||||
max_concurrent_subagents=max_concurrent_subagents,
|
||||
|
||||
@ -32,7 +32,7 @@ def _make_model(name: str, *, supports_thinking: bool) -> ModelConfig:
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
|
||||
def test_resolve_model_name_falls_back_to_default(caplog):
|
||||
app_config = _make_app_config(
|
||||
[
|
||||
_make_model("default-model", supports_thinking=False),
|
||||
@ -40,16 +40,14 @@ def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
resolved = lead_agent_module._resolve_model_name("missing-model")
|
||||
resolved = lead_agent_module._resolve_model_name(app_config, "missing-model")
|
||||
|
||||
assert resolved == "default-model"
|
||||
assert "fallback to default model 'default-model'" in caplog.text
|
||||
|
||||
|
||||
def test_resolve_model_name_uses_default_when_none(monkeypatch):
|
||||
def test_resolve_model_name_uses_default_when_none():
|
||||
app_config = _make_app_config(
|
||||
[
|
||||
_make_model("default-model", supports_thinking=False),
|
||||
@ -57,23 +55,19 @@ def test_resolve_model_name_uses_default_when_none(monkeypatch):
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
|
||||
|
||||
resolved = lead_agent_module._resolve_model_name(None)
|
||||
resolved = lead_agent_module._resolve_model_name(app_config, None)
|
||||
|
||||
assert resolved == "default-model"
|
||||
|
||||
|
||||
def test_resolve_model_name_raises_when_no_models_configured(monkeypatch):
|
||||
def test_resolve_model_name_raises_when_no_models_configured():
|
||||
app_config = _make_app_config([])
|
||||
|
||||
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="No chat models are configured",
|
||||
):
|
||||
lead_agent_module._resolve_model_name("missing-model")
|
||||
lead_agent_module._resolve_model_name(app_config, "missing-model")
|
||||
|
||||
|
||||
def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkeypatch):
|
||||
@ -83,7 +77,7 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
|
||||
|
||||
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
|
||||
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda app_config, config, model_name, agent_name=None: [])
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
@ -130,10 +124,10 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
|
||||
AppConfig.init(app_config)
|
||||
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: app_config))
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda _ac: None)
|
||||
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
|
||||
|
||||
middlewares = lead_agent_module._build_middlewares({"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
|
||||
middlewares = lead_agent_module._build_middlewares(app_config, {"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
|
||||
|
||||
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||
# verify the custom middleware is injected correctly
|
||||
@ -161,7 +155,7 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||
monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs)
|
||||
|
||||
middleware = lead_agent_module._create_summarization_middleware()
|
||||
middleware = lead_agent_module._create_summarization_middleware(patched)
|
||||
|
||||
assert captured["name"] == "model-masswork"
|
||||
assert captured["thinking_enabled"] is False
|
||||
|
||||
@ -108,7 +108,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
|
||||
|
||||
# Mock dependencies
|
||||
monkeypatch.setattr(AppConfig, "current", staticmethod(lambda: MagicMock()))
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda app_config=None, x=None: "default-model")
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
|
||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user