refactor(config): simplify runtime context access in middlewares and tools

- title_middleware: drop the _resolve_title_config() try/except wrapper and the
  optional title_config=None fallback on every helper. after_model/aafter_model
  read runtime.context.app_config.title directly; helpers take TitleConfig as a
  required parameter. Matches the typed Runtime[DeerFlowContext] signature.
- memory_middleware: drop resolve_context() call; use runtime.context directly
  since the type is already declared.
- sandbox/tools.py: drop three layers of try/except Exception around
  resolve_context(runtime).app_config.sandbox. If the config can't be resolved
  that's a real bug that should surface, not be swallowed with a default.
- task_tool.py: same — drop the try/except around resolve_context().
- client.py: drop the set_override() call in __init__ and _reload_config().
  It was leaking overrides across test boundaries and the leak-free path
  (init() alone) is enough for the single-Client case.
- conftest: autouse fixture that initializes a minimal AppConfig for every
  test so AppConfig.current() doesn't try to auto-load config.yaml.
- test_title_middleware_core_logic: pass TitleConfig explicitly to helpers
  instead of patching AppConfig.current globally.
This commit is contained in:
greatmengqi 2026-04-16 13:30:10 +08:00
parent faec3bf9f2
commit a934a822df
7 changed files with 152 additions and 187 deletions

View File

@ -9,7 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.memory.queue import get_memory_queue from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -203,12 +203,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
Returns: Returns:
None (no state changes needed from this middleware). None (no state changes needed from this middleware).
""" """
ctx = resolve_context(runtime) memory_config = runtime.context.app_config.memory
memory_config = ctx.app_config.memory
if not memory_config.enabled: if not memory_config.enabled:
return None return None
thread_id = ctx.thread_id thread_id = runtime.context.thread_id
if not thread_id: if not thread_id:
logger.debug("No thread_id in context, skipping memory update") logger.debug("No thread_id in context, skipping memory update")
return None return None

View File

@ -8,8 +8,7 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.app_config import AppConfig from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
from deerflow.config.title_config import TitleConfig from deerflow.config.title_config import TitleConfig
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
@ -46,10 +45,8 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return "" return ""
def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> bool: def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig) -> bool:
"""Check if we should generate a title for this thread.""" """Check if we should generate a title for this thread."""
if title_config is None:
title_config = AppConfig.current().title
if not title_config.enabled: if not title_config.enabled:
return False return False
@ -69,13 +66,11 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
# Generate title after first complete exchange # Generate title after first complete exchange
return len(user_messages) == 1 and len(assistant_messages) >= 1 return len(user_messages) == 1 and len(assistant_messages) >= 1
def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> tuple[str, str]: def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig) -> tuple[str, str]:
"""Extract user/assistant messages and build the title prompt. """Extract user/assistant messages and build the title prompt.
Returns (prompt_string, user_msg) so callers can use user_msg as fallback. Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
""" """
if title_config is None:
title_config = AppConfig.current().title
messages = state.get("messages", []) messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "") user_msg_content = next((m.content for m in messages if m.type == "human"), "")
@ -91,17 +86,13 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
) )
return prompt, user_msg return prompt, user_msg
def _parse_title(self, content: object, title_config: TitleConfig | None = None) -> str: def _parse_title(self, content: object, title_config: TitleConfig) -> str:
"""Normalize model output into a clean title string.""" """Normalize model output into a clean title string."""
if title_config is None:
title_config = AppConfig.current().title
title_content = self._normalize_content(content) title_content = self._normalize_content(content)
title = title_content.strip().strip('"').strip("'") title = title_content.strip().strip('"').strip("'")
return title[: title_config.max_chars] if len(title) > title_config.max_chars else title return title[: title_config.max_chars] if len(title) > title_config.max_chars else title
def _fallback_title(self, user_msg: str, title_config: TitleConfig | None = None) -> str: def _fallback_title(self, user_msg: str, title_config: TitleConfig) -> str:
if title_config is None:
title_config = AppConfig.current().title
fallback_chars = min(title_config.max_chars, 50) fallback_chars = min(title_config.max_chars, 50)
if len(user_msg) > fallback_chars: if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..." return user_msg[:fallback_chars].rstrip() + "..."
@ -121,20 +112,16 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
config["tags"] = [*(config.get("tags") or []), "middleware:title"] config["tags"] = [*(config.get("tags") or []), "middleware:title"]
return config return config
def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> dict | None: def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
"""Generate a local fallback title without blocking on an LLM call.""" """Generate a local fallback title without blocking on an LLM call."""
if title_config is None:
title_config = AppConfig.current().title
if not self._should_generate_title(state, title_config): if not self._should_generate_title(state, title_config):
return None return None
_, user_msg = self._build_title_prompt(state, title_config) _, user_msg = self._build_title_prompt(state, title_config)
return {"title": self._fallback_title(user_msg, title_config)} return {"title": self._fallback_title(user_msg, title_config)}
async def _agenerate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> dict | None: async def _agenerate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
"""Generate a title asynchronously and fall back locally on failure.""" """Generate a title asynchronously and fall back locally on failure."""
if title_config is None:
title_config = AppConfig.current().title
if not self._should_generate_title(state, title_config): if not self._should_generate_title(state, title_config):
return None return None
@ -153,21 +140,10 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
logger.debug("Failed to generate async title; falling back to local title", exc_info=True) logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
return {"title": self._fallback_title(user_msg, title_config)} return {"title": self._fallback_title(user_msg, title_config)}
def _resolve_title_config(self, runtime: Runtime[DeerFlowContext]) -> TitleConfig | None:
"""Resolve TitleConfig from the runtime context when possible.
Returns None on any failure so callers can fall back to
``AppConfig.current()`` without raising.
"""
try:
return resolve_context(runtime).app_config.title
except Exception:
return None
@override @override
def after_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None: def after_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._generate_title_result(state, self._resolve_title_config(runtime)) return self._generate_title_result(state, runtime.context.app_config.title)
@override @override
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None: async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return await self._agenerate_title_result(state, self._resolve_title_config(runtime)) return await self._agenerate_title_result(state, runtime.context.app_config.title)

View File

@ -143,12 +143,8 @@ class DeerFlowClient:
middlewares: Optional list of custom middlewares to inject into the agent. middlewares: Optional list of custom middlewares to inject into the agent.
""" """
if config_path is not None: if config_path is not None:
config = AppConfig.from_file(config_path) AppConfig.init(AppConfig.from_file(config_path))
AppConfig.init(config)
self._app_config = AppConfig.current() self._app_config = AppConfig.current()
# Scope this client's config to the current context so it doesn't
# leak into unrelated async tasks when multiple clients coexist.
self._config_token = AppConfig.set_override(self._app_config)
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name): if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}") raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
@ -177,11 +173,9 @@ class DeerFlowClient:
self._agent_config_key = None self._agent_config_key = None
def _reload_config(self) -> None: def _reload_config(self) -> None:
"""Reload config from file, update both process-global and context override.""" """Reload config from file and refresh the cached reference."""
config = AppConfig.from_file() AppConfig.init(AppConfig.from_file())
AppConfig.init(config) self._app_config = AppConfig.current()
self._app_config = config
self._config_token = AppConfig.set_override(config)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Internal helpers # Internal helpers

View File

@ -987,12 +987,11 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
""" """
try: try:
sandbox = ensure_sandbox_initialized(runtime) sandbox = ensure_sandbox_initialized(runtime)
app_config = resolve_context(runtime).app_config
sandbox_cfg = app_config.sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
if is_local_sandbox(runtime): if is_local_sandbox(runtime):
try: if not is_host_bash_allowed(app_config):
host_bash_config = resolve_context(runtime).app_config
except Exception:
host_bash_config = None
if not is_host_bash_allowed(host_bash_config):
return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}" return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}"
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
thread_data = get_thread_data(runtime) thread_data = get_thread_data(runtime)
@ -1000,18 +999,8 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
command = replace_virtual_paths_in_command(command, thread_data) command = replace_virtual_paths_in_command(command, thread_data)
command = _apply_cwd_prefix(command, thread_data) command = _apply_cwd_prefix(command, thread_data)
output = sandbox.execute_command(command) output = sandbox.execute_command(command)
try:
sandbox_cfg = resolve_context(runtime).app_config.sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars) return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
try:
sandbox_cfg = resolve_context(runtime).app_config.sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(sandbox.execute_command(command), max_chars) return _truncate_bash_output(sandbox.execute_command(command), max_chars)
except SandboxError as e: except SandboxError as e:
return f"Error: {e}" return f"Error: {e}"
@ -1047,11 +1036,8 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
if not children: if not children:
return "(empty)" return "(empty)"
output = "\n".join(children) output = "\n".join(children)
try: sandbox_cfg = resolve_context(runtime).app_config.sandbox
sandbox_cfg = resolve_context(runtime).app_config.sandbox max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_ls_output(output, max_chars) return _truncate_ls_output(output, max_chars)
except SandboxError as e: except SandboxError as e:
return f"Error: {e}" return f"Error: {e}"
@ -1218,11 +1204,8 @@ def read_file_tool(
return "(empty)" return "(empty)"
if start_line is not None and end_line is not None: if start_line is not None and end_line is not None:
content = "\n".join(content.splitlines()[start_line - 1 : end_line]) content = "\n".join(content.splitlines()[start_line - 1 : end_line])
try: sandbox_cfg = resolve_context(runtime).app_config.sandbox
sandbox_cfg = resolve_context(runtime).app_config.sandbox max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
except Exception:
max_chars = 50000
return _truncate_read_file_output(content, max_chars) return _truncate_read_file_output(content, max_chars)
except SandboxError as e: except SandboxError as e:
return f"Error: {e}" return f"Error: {e}"

View File

@ -67,13 +67,8 @@ async def task_tool(
if config is None: if config is None:
available = ", ".join(available_subagent_names) available = ", ".join(available_subagent_names)
return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}" return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
if subagent_type == "bash": if subagent_type == "bash" and not is_host_bash_allowed(resolve_context(runtime).app_config):
try: return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
host_bash_config = resolve_context(runtime).app_config
except Exception:
host_bash_config = None
if not is_host_bash_allowed(host_bash_config):
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
# Build config overrides # Build config overrides
overrides: dict = {} overrides: dict = {}

View File

@ -68,6 +68,28 @@ def provisioner_module():
# context should mark themselves ``@pytest.mark.no_auto_user``. # context should mark themselves ``@pytest.mark.no_auto_user``.
@pytest.fixture(autouse=True)
def _auto_app_config():
"""Initialize a minimal AppConfig for tests so ``AppConfig.current()`` never tries to auto-load config.yaml.
Individual tests can still override via ``patch.object(AppConfig, "current", ...)``
or by calling ``AppConfig.init()`` with a different config.
"""
try:
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
except ImportError:
yield
return
previous_global = AppConfig._global
AppConfig._global = AppConfig(sandbox=SandboxConfig(use="test"))
try:
yield
finally:
AppConfig._global = previous_global
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _auto_user_context(request): def _auto_user_context(request):
"""Inject a default ``test-user-autouse`` into the contextvar. """Inject a default ``test-user-autouse`` into the contextvar.

View File

@ -1,173 +1,169 @@
"""Core behavior tests for TitleMiddleware.""" """Core behavior tests for TitleMiddleware."""
import asyncio import asyncio
from unittest.mock import AsyncMock, MagicMock, patch from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares import title_middleware as title_middleware_module from deerflow.agents.middlewares import title_middleware as title_middleware_module
from deerflow.agents.middlewares.title_middleware import TitleMiddleware from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.title_config import TitleConfig from deerflow.config.title_config import TitleConfig
def _make_config(**title_overrides) -> AppConfig: def _make_title_config(**overrides) -> TitleConfig:
return AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides)) return TitleConfig(**overrides)
def _patch_app_config(**title_overrides): def _make_runtime(**title_overrides) -> SimpleNamespace:
return patch.object(AppConfig, "current", return_value=_make_config(**title_overrides)) """Build a runtime whose context carries a DeerFlowContext with the given TitleConfig."""
app_config = AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides))
ctx = DeerFlowContext(app_config=app_config, thread_id="t1")
return SimpleNamespace(context=ctx)
class TestTitleMiddlewareCoreLogic: class TestTitleMiddlewareCoreLogic:
def test_should_generate_title_for_first_complete_exchange(self): def test_should_generate_title_for_first_complete_exchange(self):
with _patch_app_config(enabled=True): middleware = TitleMiddleware()
middleware = TitleMiddleware() state = {
state = { "messages": [
"messages": [ HumanMessage(content="帮我总结这段代码"),
HumanMessage(content="帮我总结这段代码"), AIMessage(content="好的,我先看结构"),
AIMessage(content="好的,我先看结构"), ]
] }
}
assert middleware._should_generate_title(state) is True assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is True
def test_should_not_generate_title_when_disabled_or_already_set(self): def test_should_not_generate_title_when_disabled_or_already_set(self):
middleware = TitleMiddleware() middleware = TitleMiddleware()
with _patch_app_config(enabled=False): disabled_state = {
disabled_state = { "messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"messages": [HumanMessage(content="Q"), AIMessage(content="A")], "title": None,
"title": None, }
} assert middleware._should_generate_title(disabled_state, _make_title_config(enabled=False)) is False
assert middleware._should_generate_title(disabled_state) is False
with _patch_app_config(enabled=True): titled_state = {
titled_state = { "messages": [HumanMessage(content="Q"), AIMessage(content="A")],
"messages": [HumanMessage(content="Q"), AIMessage(content="A")], "title": "Existing Title",
"title": "Existing Title", }
} assert middleware._should_generate_title(titled_state, _make_title_config(enabled=True)) is False
assert middleware._should_generate_title(titled_state) is False
def test_should_not_generate_title_after_second_user_turn(self): def test_should_not_generate_title_after_second_user_turn(self):
with _patch_app_config(enabled=True): middleware = TitleMiddleware()
middleware = TitleMiddleware() state = {
state = { "messages": [
"messages": [ HumanMessage(content="第一问"),
HumanMessage(content="第一问"), AIMessage(content="第一答"),
AIMessage(content="第一答"), HumanMessage(content="第二问"),
HumanMessage(content="第二问"), AIMessage(content="第二答"),
AIMessage(content="第二答"), ]
] }
}
assert middleware._should_generate_title(state) is False assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is False
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch): def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
with _patch_app_config(max_chars=12): middleware = TitleMiddleware()
middleware = TitleMiddleware() model = MagicMock()
model = MagicMock() model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题")) monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = { state = {
"messages": [ "messages": [
HumanMessage(content="请帮我写一个很长很长的脚本标题"), HumanMessage(content="请帮我写一个很长很长的脚本标题"),
AIMessage(content="好的,先确认需求"), AIMessage(content="好的,先确认需求"),
] ]
} }
result = asyncio.run(middleware._agenerate_title_result(state)) result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=12)))
title = result["title"] title = result["title"]
assert title == "短标题" assert title == "短标题"
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False) title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
model.ainvoke.assert_awaited_once() model.ainvoke.assert_awaited_once()
def test_generate_title_normalizes_structured_message_content(self, monkeypatch): def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
with _patch_app_config(max_chars=20): middleware = TitleMiddleware()
middleware = TitleMiddleware() model = MagicMock()
model = MagicMock() model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码")) monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = { state = {
"messages": [ "messages": [
HumanMessage(content=[{"type": "text", "text": "请帮我总结这段代码"}]), HumanMessage(content=[{"type": "text", "text": "请帮我总结这段代码"}]),
AIMessage(content=[{"type": "text", "text": "好的,先看结构"}]), AIMessage(content=[{"type": "text", "text": "好的,先看结构"}]),
] ]
} }
result = asyncio.run(middleware._agenerate_title_result(state)) result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
title = result["title"] title = result["title"]
assert title == "请帮我总结这段代码" assert title == "请帮我总结这段代码"
def test_generate_title_fallback_for_long_message(self, monkeypatch): def test_generate_title_fallback_for_long_message(self, monkeypatch):
with _patch_app_config(max_chars=20): middleware = TitleMiddleware()
middleware = TitleMiddleware() model = MagicMock()
model = MagicMock() model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable")) monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = { state = {
"messages": [ "messages": [
HumanMessage(content="这是一个非常长的问题描述需要被截断以形成fallback标题"), HumanMessage(content="这是一个非常长的问题描述需要被截断以形成fallback标题"),
AIMessage(content="收到"), AIMessage(content="收到"),
] ]
} }
result = asyncio.run(middleware._agenerate_title_result(state)) result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
title = result["title"] title = result["title"]
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text. # Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
assert title.endswith("...") assert title.endswith("...")
assert title.startswith("这是一个非常长的问题描述") assert title.startswith("这是一个非常长的问题描述")
def test_aafter_model_delegates_to_async_helper(self, monkeypatch): def test_aafter_model_delegates_to_async_helper(self, monkeypatch):
middleware = TitleMiddleware() middleware = TitleMiddleware()
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"})) monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"}))
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime()))
assert result == {"title": "异步标题"} assert result == {"title": "异步标题"}
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None)) monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None))
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime())) is None
def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch): def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch):
middleware = TitleMiddleware() middleware = TitleMiddleware()
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"})) monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"}))
result = middleware.after_model({"messages": []}, runtime=MagicMock()) result = middleware.after_model({"messages": []}, runtime=_make_runtime())
assert result == {"title": "同步标题"} assert result == {"title": "同步标题"}
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None)) monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None assert middleware.after_model({"messages": []}, runtime=_make_runtime()) is None
def test_sync_generate_title_uses_fallback_without_model(self): def test_sync_generate_title_uses_fallback_without_model(self):
"""Sync path avoids LLM calls and derives a local fallback title.""" """Sync path avoids LLM calls and derives a local fallback title."""
with _patch_app_config(max_chars=20): middleware = TitleMiddleware()
middleware = TitleMiddleware()
state = { state = {
"messages": [ "messages": [
HumanMessage(content="请帮我写测试"), HumanMessage(content="请帮我写测试"),
AIMessage(content="好的"), AIMessage(content="好的"),
] ]
} }
result = middleware._generate_title_result(state) result = middleware._generate_title_result(state, _make_title_config(max_chars=20))
assert result == {"title": "请帮我写测试"} assert result == {"title": "请帮我写测试"}
def test_sync_generate_title_respects_fallback_truncation(self): def test_sync_generate_title_respects_fallback_truncation(self):
"""Sync fallback path still respects max_chars truncation rules.""" """Sync fallback path still respects max_chars truncation rules."""
with _patch_app_config(max_chars=50): middleware = TitleMiddleware()
middleware = TitleMiddleware()
state = { state = {
"messages": [ "messages": [
HumanMessage(content="这是一个非常长的问题描述需要被截断以形成fallback标题而且这里继续补充更多上下文确保超过本地fallback截断阈值"), HumanMessage(content="这是一个非常长的问题描述需要被截断以形成fallback标题而且这里继续补充更多上下文确保超过本地fallback截断阈值"),
AIMessage(content="回复"), AIMessage(content="回复"),
] ]
} }
result = middleware._generate_title_result(state) result = middleware._generate_title_result(state, _make_title_config(max_chars=50))
assert result["title"].endswith("...") assert result["title"].endswith("...")
assert result["title"].startswith("这是一个非常长的问题描述") assert result["title"].startswith("这是一个非常长的问题描述")