refactor(config): simplify runtime context access in middlewares and tools

- title_middleware: drop the _resolve_title_config() try/except wrapper and the
  optional title_config=None fallback on every helper. after_model/aafter_model
  read runtime.context.app_config.title directly; helpers take TitleConfig as a
  required parameter. Matches the typed Runtime[DeerFlowContext] signature.
- memory_middleware: drop resolve_context() call; use runtime.context directly
  since the type is already declared.
- sandbox/tools.py: drop three layers of try/except Exception around
  resolve_context(runtime).app_config.sandbox. If the config can't be resolved
  that's a real bug that should surface, not be swallowed with a default.
- task_tool.py: same — drop the try/except around resolve_context().
- client.py: drop the set_override() call in __init__ and _reload_config().
  It was leaking overrides across test boundaries and the leak-free path
  (init() alone) is enough for the single-Client case.
- conftest: autouse fixture that initializes a minimal AppConfig for every
  test so AppConfig.current() doesn't try to auto-load config.yaml.
- test_title_middleware_core_logic: pass TitleConfig explicitly to helpers
  instead of patching AppConfig.current globally.
This commit is contained in:
greatmengqi 2026-04-16 13:30:10 +08:00
parent faec3bf9f2
commit a934a822df
7 changed files with 152 additions and 187 deletions

View File

@ -9,7 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
@ -203,12 +203,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
Returns:
None (no state changes needed from this middleware).
"""
ctx = resolve_context(runtime)
memory_config = ctx.app_config.memory
memory_config = runtime.context.app_config.memory
if not memory_config.enabled:
return None
thread_id = ctx.thread_id
thread_id = runtime.context.thread_id
if not thread_id:
logger.debug("No thread_id in context, skipping memory update")
return None

View File

@ -8,8 +8,7 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.title_config import TitleConfig
from deerflow.models import create_chat_model
@ -46,10 +45,8 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return ""
def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> bool:
def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig) -> bool:
"""Check if we should generate a title for this thread."""
if title_config is None:
title_config = AppConfig.current().title
if not title_config.enabled:
return False
@ -69,13 +66,11 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
# Generate title after first complete exchange
return len(user_messages) == 1 and len(assistant_messages) >= 1
def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> tuple[str, str]:
def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig) -> tuple[str, str]:
"""Extract user/assistant messages and build the title prompt.
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
"""
if title_config is None:
title_config = AppConfig.current().title
messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
@ -91,17 +86,13 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
)
return prompt, user_msg
def _parse_title(self, content: object, title_config: TitleConfig | None = None) -> str:
def _parse_title(self, content: object, title_config: TitleConfig) -> str:
"""Normalize model output into a clean title string."""
if title_config is None:
title_config = AppConfig.current().title
title_content = self._normalize_content(content)
title = title_content.strip().strip('"').strip("'")
return title[: title_config.max_chars] if len(title) > title_config.max_chars else title
def _fallback_title(self, user_msg: str, title_config: TitleConfig | None = None) -> str:
if title_config is None:
title_config = AppConfig.current().title
def _fallback_title(self, user_msg: str, title_config: TitleConfig) -> str:
fallback_chars = min(title_config.max_chars, 50)
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..."
@ -121,20 +112,16 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
return config
def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> dict | None:
def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
"""Generate a local fallback title without blocking on an LLM call."""
if title_config is None:
title_config = AppConfig.current().title
if not self._should_generate_title(state, title_config):
return None
_, user_msg = self._build_title_prompt(state, title_config)
return {"title": self._fallback_title(user_msg, title_config)}
async def _agenerate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> dict | None:
async def _agenerate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
"""Generate a title asynchronously and fall back locally on failure."""
if title_config is None:
title_config = AppConfig.current().title
if not self._should_generate_title(state, title_config):
return None
@ -153,21 +140,10 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
return {"title": self._fallback_title(user_msg, title_config)}
def _resolve_title_config(self, runtime: Runtime[DeerFlowContext]) -> TitleConfig | None:
"""Resolve TitleConfig from the runtime context when possible.
Returns None on any failure so callers can fall back to
``AppConfig.current()`` without raising.
"""
try:
return resolve_context(runtime).app_config.title
except Exception:
return None
@override
def after_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._generate_title_result(state, self._resolve_title_config(runtime))
return self._generate_title_result(state, runtime.context.app_config.title)
@override
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return await self._agenerate_title_result(state, self._resolve_title_config(runtime))
return await self._agenerate_title_result(state, runtime.context.app_config.title)

View File

@ -143,12 +143,8 @@ class DeerFlowClient:
middlewares: Optional list of custom middlewares to inject into the agent.
"""
if config_path is not None:
config = AppConfig.from_file(config_path)
AppConfig.init(config)
AppConfig.init(AppConfig.from_file(config_path))
self._app_config = AppConfig.current()
# Scope this client's config to the current context so it doesn't
# leak into unrelated async tasks when multiple clients coexist.
self._config_token = AppConfig.set_override(self._app_config)
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
@ -177,11 +173,9 @@ class DeerFlowClient:
self._agent_config_key = None
def _reload_config(self) -> None:
"""Reload config from file, update both process-global and context override."""
config = AppConfig.from_file()
AppConfig.init(config)
self._app_config = config
self._config_token = AppConfig.set_override(config)
"""Reload config from file and refresh the cached reference."""
AppConfig.init(AppConfig.from_file())
self._app_config = AppConfig.current()
# ------------------------------------------------------------------
# Internal helpers

View File

@ -987,12 +987,11 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
"""
try:
sandbox = ensure_sandbox_initialized(runtime)
app_config = resolve_context(runtime).app_config
sandbox_cfg = app_config.sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
if is_local_sandbox(runtime):
try:
host_bash_config = resolve_context(runtime).app_config
except Exception:
host_bash_config = None
if not is_host_bash_allowed(host_bash_config):
if not is_host_bash_allowed(app_config):
return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}"
ensure_thread_directories_exist(runtime)
thread_data = get_thread_data(runtime)
@ -1000,18 +999,8 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
command = replace_virtual_paths_in_command(command, thread_data)
command = _apply_cwd_prefix(command, thread_data)
output = sandbox.execute_command(command)
try:
sandbox_cfg = resolve_context(runtime).app_config.sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
ensure_thread_directories_exist(runtime)
try:
sandbox_cfg = resolve_context(runtime).app_config.sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(sandbox.execute_command(command), max_chars)
except SandboxError as e:
return f"Error: {e}"
@ -1047,11 +1036,8 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
if not children:
return "(empty)"
output = "\n".join(children)
try:
sandbox_cfg = resolve_context(runtime).app_config.sandbox
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
sandbox_cfg = resolve_context(runtime).app_config.sandbox
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
return _truncate_ls_output(output, max_chars)
except SandboxError as e:
return f"Error: {e}"
@ -1218,11 +1204,8 @@ def read_file_tool(
return "(empty)"
if start_line is not None and end_line is not None:
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
try:
sandbox_cfg = resolve_context(runtime).app_config.sandbox
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
except Exception:
max_chars = 50000
sandbox_cfg = resolve_context(runtime).app_config.sandbox
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
return _truncate_read_file_output(content, max_chars)
except SandboxError as e:
return f"Error: {e}"

View File

@ -67,13 +67,8 @@ async def task_tool(
if config is None:
available = ", ".join(available_subagent_names)
return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
if subagent_type == "bash":
try:
host_bash_config = resolve_context(runtime).app_config
except Exception:
host_bash_config = None
if not is_host_bash_allowed(host_bash_config):
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
if subagent_type == "bash" and not is_host_bash_allowed(resolve_context(runtime).app_config):
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
# Build config overrides
overrides: dict = {}

View File

@ -68,6 +68,28 @@ def provisioner_module():
# context should mark themselves ``@pytest.mark.no_auto_user``.
@pytest.fixture(autouse=True)
def _auto_app_config():
"""Initialize a minimal AppConfig for tests so ``AppConfig.current()`` never tries to auto-load config.yaml.
Individual tests can still override via ``patch.object(AppConfig, "current", ...)``
or by calling ``AppConfig.init()`` with a different config.
"""
try:
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
except ImportError:
yield
return
previous_global = AppConfig._global
AppConfig._global = AppConfig(sandbox=SandboxConfig(use="test"))
try:
yield
finally:
AppConfig._global = previous_global
@pytest.fixture(autouse=True)
def _auto_user_context(request):
"""Inject a default ``test-user-autouse`` into the contextvar.

View File

@ -1,173 +1,169 @@
"""Core behavior tests for TitleMiddleware."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares import title_middleware as title_middleware_module
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.title_config import TitleConfig
def _make_config(**title_overrides) -> AppConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides))
def _make_title_config(**overrides) -> TitleConfig:
return TitleConfig(**overrides)
def _patch_app_config(**title_overrides):
return patch.object(AppConfig, "current", return_value=_make_config(**title_overrides))
def _make_runtime(**title_overrides) -> SimpleNamespace:
"""Build a runtime whose context carries a DeerFlowContext with the given TitleConfig."""
app_config = AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides))
ctx = DeerFlowContext(app_config=app_config, thread_id="t1")
return SimpleNamespace(context=ctx)
class TestTitleMiddlewareCoreLogic:
def test_should_generate_title_for_first_complete_exchange(self):
with _patch_app_config(enabled=True):
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="帮我总结这段代码"),
AIMessage(content="好的,我先看结构"),
]
}
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="帮我总结这段代码"),
AIMessage(content="好的,我先看结构"),
]
}
assert middleware._should_generate_title(state) is True
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is True
def test_should_not_generate_title_when_disabled_or_already_set(self):
middleware = TitleMiddleware()
with _patch_app_config(enabled=False):
disabled_state = {
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"title": None,
}
assert middleware._should_generate_title(disabled_state) is False
disabled_state = {
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"title": None,
}
assert middleware._should_generate_title(disabled_state, _make_title_config(enabled=False)) is False
with _patch_app_config(enabled=True):
titled_state = {
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"title": "Existing Title",
}
assert middleware._should_generate_title(titled_state) is False
titled_state = {
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"title": "Existing Title",
}
assert middleware._should_generate_title(titled_state, _make_title_config(enabled=True)) is False
def test_should_not_generate_title_after_second_user_turn(self):
with _patch_app_config(enabled=True):
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="第一问"),
AIMessage(content="第一答"),
HumanMessage(content="第二问"),
AIMessage(content="第二答"),
]
}
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="第一问"),
AIMessage(content="第一答"),
HumanMessage(content="第二问"),
AIMessage(content="第二答"),
]
}
assert middleware._should_generate_title(state) is False
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is False
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
with _patch_app_config(max_chars=12):
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
HumanMessage(content="请帮我写一个很长很长的脚本标题"),
AIMessage(content="好的,先确认需求"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
state = {
"messages": [
HumanMessage(content="请帮我写一个很长很长的脚本标题"),
AIMessage(content="好的,先确认需求"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=12)))
title = result["title"]
assert title == "短标题"
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
model.ainvoke.assert_awaited_once()
assert title == "短标题"
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
model.ainvoke.assert_awaited_once()
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
with _patch_app_config(max_chars=20):
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
HumanMessage(content=[{"type": "text", "text": "请帮我总结这段代码"}]),
AIMessage(content=[{"type": "text", "text": "好的,先看结构"}]),
]
}
state = {
"messages": [
HumanMessage(content=[{"type": "text", "text": "请帮我总结这段代码"}]),
AIMessage(content=[{"type": "text", "text": "好的,先看结构"}]),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
title = result["title"]
assert title == "请帮我总结这段代码"
assert title == "请帮我总结这段代码"
def test_generate_title_fallback_for_long_message(self, monkeypatch):
with _patch_app_config(max_chars=20):
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
HumanMessage(content="这是一个非常长的问题描述需要被截断以形成fallback标题"),
AIMessage(content="收到"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
state = {
"messages": [
HumanMessage(content="这是一个非常长的问题描述需要被截断以形成fallback标题"),
AIMessage(content="收到"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
title = result["title"]
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
assert title.endswith("...")
assert title.startswith("这是一个非常长的问题描述")
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
assert title.endswith("...")
assert title.startswith("这是一个非常长的问题描述")
def test_aafter_model_delegates_to_async_helper(self, monkeypatch):
middleware = TitleMiddleware()
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"}))
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock()))
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime()))
assert result == {"title": "异步标题"}
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None))
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime())) is None
def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch):
middleware = TitleMiddleware()
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"}))
result = middleware.after_model({"messages": []}, runtime=MagicMock())
result = middleware.after_model({"messages": []}, runtime=_make_runtime())
assert result == {"title": "同步标题"}
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
assert middleware.after_model({"messages": []}, runtime=_make_runtime()) is None
def test_sync_generate_title_uses_fallback_without_model(self):
"""Sync path avoids LLM calls and derives a local fallback title."""
with _patch_app_config(max_chars=20):
middleware = TitleMiddleware()
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="请帮我写测试"),
AIMessage(content="好的"),
]
}
result = middleware._generate_title_result(state)
assert result == {"title": "请帮我写测试"}
state = {
"messages": [
HumanMessage(content="请帮我写测试"),
AIMessage(content="好的"),
]
}
result = middleware._generate_title_result(state, _make_title_config(max_chars=20))
assert result == {"title": "请帮我写测试"}
def test_sync_generate_title_respects_fallback_truncation(self):
"""Sync fallback path still respects max_chars truncation rules."""
with _patch_app_config(max_chars=50):
middleware = TitleMiddleware()
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(content="这是一个非常长的问题描述需要被截断以形成fallback标题而且这里继续补充更多上下文确保超过本地fallback截断阈值"),
AIMessage(content="回复"),
]
}
result = middleware._generate_title_result(state)
assert result["title"].endswith("...")
assert result["title"].startswith("这是一个非常长的问题描述")
state = {
"messages": [
HumanMessage(content="这是一个非常长的问题描述需要被截断以形成fallback标题而且这里继续补充更多上下文确保超过本地fallback截断阈值"),
AIMessage(content="回复"),
]
}
result = middleware._generate_title_result(state, _make_title_config(max_chars=50))
assert result["title"].endswith("...")
assert result["title"].startswith("这是一个非常长的问题描述")