mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
refactor(config): simplify runtime context access in middlewares and tools
- title_middleware: drop the _resolve_title_config() try/except wrapper and the optional title_config=None fallback on every helper. after_model/aafter_model read runtime.context.app_config.title directly; helpers take TitleConfig as a required parameter. Matches the typed Runtime[DeerFlowContext] signature. - memory_middleware: drop resolve_context() call; use runtime.context directly since the type is already declared. - sandbox/tools.py: drop three layers of try/except Exception around resolve_context(runtime).app_config.sandbox. If the config can't be resolved that's a real bug that should surface, not be swallowed with a default. - task_tool.py: same — drop the try/except around resolve_context(). - client.py: drop the set_override() call in __init__ and _reload_config(). It was leaking overrides across test boundaries and the leak-free path (init() alone) is enough for the single-Client case. - conftest: autouse fixture that initializes a minimal AppConfig for every test so AppConfig.current() doesn't try to auto-load config.yaml. - test_title_middleware_core_logic: pass TitleConfig explicitly to helpers instead of patching AppConfig.current globally.
This commit is contained in:
parent
faec3bf9f2
commit
a934a822df
@ -9,7 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -203,12 +203,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
Returns:
|
||||
None (no state changes needed from this middleware).
|
||||
"""
|
||||
ctx = resolve_context(runtime)
|
||||
memory_config = ctx.app_config.memory
|
||||
memory_config = runtime.context.app_config.memory
|
||||
if not memory_config.enabled:
|
||||
return None
|
||||
|
||||
thread_id = ctx.thread_id
|
||||
thread_id = runtime.context.thread_id
|
||||
if not thread_id:
|
||||
logger.debug("No thread_id in context, skipping memory update")
|
||||
return None
|
||||
|
||||
@ -8,8 +8,7 @@ from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
@ -46,10 +45,8 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
|
||||
return ""
|
||||
|
||||
def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> bool:
|
||||
def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig) -> bool:
|
||||
"""Check if we should generate a title for this thread."""
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
if not title_config.enabled:
|
||||
return False
|
||||
|
||||
@ -69,13 +66,11 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
# Generate title after first complete exchange
|
||||
return len(user_messages) == 1 and len(assistant_messages) >= 1
|
||||
|
||||
def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> tuple[str, str]:
|
||||
def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig) -> tuple[str, str]:
|
||||
"""Extract user/assistant messages and build the title prompt.
|
||||
|
||||
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
|
||||
"""
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
messages = state.get("messages", [])
|
||||
|
||||
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
||||
@ -91,17 +86,13 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
)
|
||||
return prompt, user_msg
|
||||
|
||||
def _parse_title(self, content: object, title_config: TitleConfig | None = None) -> str:
|
||||
def _parse_title(self, content: object, title_config: TitleConfig) -> str:
|
||||
"""Normalize model output into a clean title string."""
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
title_content = self._normalize_content(content)
|
||||
title = title_content.strip().strip('"').strip("'")
|
||||
return title[: title_config.max_chars] if len(title) > title_config.max_chars else title
|
||||
|
||||
def _fallback_title(self, user_msg: str, title_config: TitleConfig | None = None) -> str:
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
def _fallback_title(self, user_msg: str, title_config: TitleConfig) -> str:
|
||||
fallback_chars = min(title_config.max_chars, 50)
|
||||
if len(user_msg) > fallback_chars:
|
||||
return user_msg[:fallback_chars].rstrip() + "..."
|
||||
@ -121,20 +112,16 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
|
||||
return config
|
||||
|
||||
def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> dict | None:
|
||||
def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
|
||||
"""Generate a local fallback title without blocking on an LLM call."""
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
if not self._should_generate_title(state, title_config):
|
||||
return None
|
||||
|
||||
_, user_msg = self._build_title_prompt(state, title_config)
|
||||
return {"title": self._fallback_title(user_msg, title_config)}
|
||||
|
||||
async def _agenerate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> dict | None:
|
||||
async def _agenerate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
|
||||
"""Generate a title asynchronously and fall back locally on failure."""
|
||||
if title_config is None:
|
||||
title_config = AppConfig.current().title
|
||||
if not self._should_generate_title(state, title_config):
|
||||
return None
|
||||
|
||||
@ -153,21 +140,10 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
|
||||
return {"title": self._fallback_title(user_msg, title_config)}
|
||||
|
||||
def _resolve_title_config(self, runtime: Runtime[DeerFlowContext]) -> TitleConfig | None:
|
||||
"""Resolve TitleConfig from the runtime context when possible.
|
||||
|
||||
Returns None on any failure so callers can fall back to
|
||||
``AppConfig.current()`` without raising.
|
||||
"""
|
||||
try:
|
||||
return resolve_context(runtime).app_config.title
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@override
|
||||
def after_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||
return self._generate_title_result(state, self._resolve_title_config(runtime))
|
||||
return self._generate_title_result(state, runtime.context.app_config.title)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||
return await self._agenerate_title_result(state, self._resolve_title_config(runtime))
|
||||
return await self._agenerate_title_result(state, runtime.context.app_config.title)
|
||||
|
||||
@ -143,12 +143,8 @@ class DeerFlowClient:
|
||||
middlewares: Optional list of custom middlewares to inject into the agent.
|
||||
"""
|
||||
if config_path is not None:
|
||||
config = AppConfig.from_file(config_path)
|
||||
AppConfig.init(config)
|
||||
AppConfig.init(AppConfig.from_file(config_path))
|
||||
self._app_config = AppConfig.current()
|
||||
# Scope this client's config to the current context so it doesn't
|
||||
# leak into unrelated async tasks when multiple clients coexist.
|
||||
self._config_token = AppConfig.set_override(self._app_config)
|
||||
|
||||
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
|
||||
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
|
||||
@ -177,11 +173,9 @@ class DeerFlowClient:
|
||||
self._agent_config_key = None
|
||||
|
||||
def _reload_config(self) -> None:
|
||||
"""Reload config from file, update both process-global and context override."""
|
||||
config = AppConfig.from_file()
|
||||
AppConfig.init(config)
|
||||
self._app_config = config
|
||||
self._config_token = AppConfig.set_override(config)
|
||||
"""Reload config from file and refresh the cached reference."""
|
||||
AppConfig.init(AppConfig.from_file())
|
||||
self._app_config = AppConfig.current()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
|
||||
@ -987,12 +987,11 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
app_config = resolve_context(runtime).app_config
|
||||
sandbox_cfg = app_config.sandbox
|
||||
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
||||
if is_local_sandbox(runtime):
|
||||
try:
|
||||
host_bash_config = resolve_context(runtime).app_config
|
||||
except Exception:
|
||||
host_bash_config = None
|
||||
if not is_host_bash_allowed(host_bash_config):
|
||||
if not is_host_bash_allowed(app_config):
|
||||
return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}"
|
||||
ensure_thread_directories_exist(runtime)
|
||||
thread_data = get_thread_data(runtime)
|
||||
@ -1000,18 +999,8 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
|
||||
command = replace_virtual_paths_in_command(command, thread_data)
|
||||
command = _apply_cwd_prefix(command, thread_data)
|
||||
output = sandbox.execute_command(command)
|
||||
try:
|
||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
||||
except Exception:
|
||||
max_chars = 20000
|
||||
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
try:
|
||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
||||
except Exception:
|
||||
max_chars = 20000
|
||||
return _truncate_bash_output(sandbox.execute_command(command), max_chars)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
@ -1047,11 +1036,8 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
if not children:
|
||||
return "(empty)"
|
||||
output = "\n".join(children)
|
||||
try:
|
||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
|
||||
except Exception:
|
||||
max_chars = 20000
|
||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
|
||||
return _truncate_ls_output(output, max_chars)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
@ -1218,11 +1204,8 @@ def read_file_tool(
|
||||
return "(empty)"
|
||||
if start_line is not None and end_line is not None:
|
||||
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
|
||||
try:
|
||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
|
||||
except Exception:
|
||||
max_chars = 50000
|
||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
|
||||
return _truncate_read_file_output(content, max_chars)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
@ -67,13 +67,8 @@ async def task_tool(
|
||||
if config is None:
|
||||
available = ", ".join(available_subagent_names)
|
||||
return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
|
||||
if subagent_type == "bash":
|
||||
try:
|
||||
host_bash_config = resolve_context(runtime).app_config
|
||||
except Exception:
|
||||
host_bash_config = None
|
||||
if not is_host_bash_allowed(host_bash_config):
|
||||
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
|
||||
if subagent_type == "bash" and not is_host_bash_allowed(resolve_context(runtime).app_config):
|
||||
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
|
||||
|
||||
# Build config overrides
|
||||
overrides: dict = {}
|
||||
|
||||
@ -68,6 +68,28 @@ def provisioner_module():
|
||||
# context should mark themselves ``@pytest.mark.no_auto_user``.
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _auto_app_config():
|
||||
"""Initialize a minimal AppConfig for tests so ``AppConfig.current()`` never tries to auto-load config.yaml.
|
||||
|
||||
Individual tests can still override via ``patch.object(AppConfig, "current", ...)``
|
||||
or by calling ``AppConfig.init()`` with a different config.
|
||||
"""
|
||||
try:
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
except ImportError:
|
||||
yield
|
||||
return
|
||||
|
||||
previous_global = AppConfig._global
|
||||
AppConfig._global = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
AppConfig._global = previous_global
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _auto_user_context(request):
|
||||
"""Inject a default ``test-user-autouse`` into the contextvar.
|
||||
|
||||
@ -1,173 +1,169 @@
|
||||
"""Core behavior tests for TitleMiddleware."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.agents.middlewares import title_middleware as title_middleware_module
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
|
||||
|
||||
def _make_config(**title_overrides) -> AppConfig:
|
||||
return AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides))
|
||||
def _make_title_config(**overrides) -> TitleConfig:
|
||||
return TitleConfig(**overrides)
|
||||
|
||||
|
||||
def _patch_app_config(**title_overrides):
|
||||
return patch.object(AppConfig, "current", return_value=_make_config(**title_overrides))
|
||||
def _make_runtime(**title_overrides) -> SimpleNamespace:
|
||||
"""Build a runtime whose context carries a DeerFlowContext with the given TitleConfig."""
|
||||
app_config = AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides))
|
||||
ctx = DeerFlowContext(app_config=app_config, thread_id="t1")
|
||||
return SimpleNamespace(context=ctx)
|
||||
|
||||
|
||||
class TestTitleMiddlewareCoreLogic:
|
||||
def test_should_generate_title_for_first_complete_exchange(self):
|
||||
with _patch_app_config(enabled=True):
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="帮我总结这段代码"),
|
||||
AIMessage(content="好的,我先看结构"),
|
||||
]
|
||||
}
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="帮我总结这段代码"),
|
||||
AIMessage(content="好的,我先看结构"),
|
||||
]
|
||||
}
|
||||
|
||||
assert middleware._should_generate_title(state) is True
|
||||
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is True
|
||||
|
||||
def test_should_not_generate_title_when_disabled_or_already_set(self):
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
with _patch_app_config(enabled=False):
|
||||
disabled_state = {
|
||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||
"title": None,
|
||||
}
|
||||
assert middleware._should_generate_title(disabled_state) is False
|
||||
disabled_state = {
|
||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||
"title": None,
|
||||
}
|
||||
assert middleware._should_generate_title(disabled_state, _make_title_config(enabled=False)) is False
|
||||
|
||||
with _patch_app_config(enabled=True):
|
||||
titled_state = {
|
||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||
"title": "Existing Title",
|
||||
}
|
||||
assert middleware._should_generate_title(titled_state) is False
|
||||
titled_state = {
|
||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||
"title": "Existing Title",
|
||||
}
|
||||
assert middleware._should_generate_title(titled_state, _make_title_config(enabled=True)) is False
|
||||
|
||||
def test_should_not_generate_title_after_second_user_turn(self):
|
||||
with _patch_app_config(enabled=True):
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="第一问"),
|
||||
AIMessage(content="第一答"),
|
||||
HumanMessage(content="第二问"),
|
||||
AIMessage(content="第二答"),
|
||||
]
|
||||
}
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="第一问"),
|
||||
AIMessage(content="第一答"),
|
||||
HumanMessage(content="第二问"),
|
||||
AIMessage(content="第二答"),
|
||||
]
|
||||
}
|
||||
|
||||
assert middleware._should_generate_title(state) is False
|
||||
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is False
|
||||
|
||||
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
|
||||
with _patch_app_config(max_chars=12):
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
|
||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
|
||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="请帮我写一个很长很长的脚本标题"),
|
||||
AIMessage(content="好的,先确认需求"),
|
||||
]
|
||||
}
|
||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
||||
title = result["title"]
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="请帮我写一个很长很长的脚本标题"),
|
||||
AIMessage(content="好的,先确认需求"),
|
||||
]
|
||||
}
|
||||
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=12)))
|
||||
title = result["title"]
|
||||
|
||||
assert title == "短标题"
|
||||
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
|
||||
model.ainvoke.assert_awaited_once()
|
||||
assert title == "短标题"
|
||||
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
|
||||
model.ainvoke.assert_awaited_once()
|
||||
|
||||
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
|
||||
with _patch_app_config(max_chars=20):
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
|
||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
|
||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content=[{"type": "text", "text": "请帮我总结这段代码"}]),
|
||||
AIMessage(content=[{"type": "text", "text": "好的,先看结构"}]),
|
||||
]
|
||||
}
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content=[{"type": "text", "text": "请帮我总结这段代码"}]),
|
||||
AIMessage(content=[{"type": "text", "text": "好的,先看结构"}]),
|
||||
]
|
||||
}
|
||||
|
||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
||||
title = result["title"]
|
||||
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
|
||||
title = result["title"]
|
||||
|
||||
assert title == "请帮我总结这段代码"
|
||||
assert title == "请帮我总结这段代码"
|
||||
|
||||
def test_generate_title_fallback_for_long_message(self, monkeypatch):
|
||||
with _patch_app_config(max_chars=20):
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
|
||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
|
||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题"),
|
||||
AIMessage(content="收到"),
|
||||
]
|
||||
}
|
||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
||||
title = result["title"]
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题"),
|
||||
AIMessage(content="收到"),
|
||||
]
|
||||
}
|
||||
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
|
||||
title = result["title"]
|
||||
|
||||
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
|
||||
assert title.endswith("...")
|
||||
assert title.startswith("这是一个非常长的问题描述")
|
||||
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
|
||||
assert title.endswith("...")
|
||||
assert title.startswith("这是一个非常长的问题描述")
|
||||
|
||||
def test_aafter_model_delegates_to_async_helper(self, monkeypatch):
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"}))
|
||||
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock()))
|
||||
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime()))
|
||||
assert result == {"title": "异步标题"}
|
||||
|
||||
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None))
|
||||
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None
|
||||
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime())) is None
|
||||
|
||||
def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch):
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"}))
|
||||
result = middleware.after_model({"messages": []}, runtime=MagicMock())
|
||||
result = middleware.after_model({"messages": []}, runtime=_make_runtime())
|
||||
assert result == {"title": "同步标题"}
|
||||
|
||||
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
|
||||
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
|
||||
assert middleware.after_model({"messages": []}, runtime=_make_runtime()) is None
|
||||
|
||||
def test_sync_generate_title_uses_fallback_without_model(self):
|
||||
"""Sync path avoids LLM calls and derives a local fallback title."""
|
||||
with _patch_app_config(max_chars=20):
|
||||
middleware = TitleMiddleware()
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="请帮我写测试"),
|
||||
AIMessage(content="好的"),
|
||||
]
|
||||
}
|
||||
result = middleware._generate_title_result(state)
|
||||
assert result == {"title": "请帮我写测试"}
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="请帮我写测试"),
|
||||
AIMessage(content="好的"),
|
||||
]
|
||||
}
|
||||
result = middleware._generate_title_result(state, _make_title_config(max_chars=20))
|
||||
assert result == {"title": "请帮我写测试"}
|
||||
|
||||
def test_sync_generate_title_respects_fallback_truncation(self):
|
||||
"""Sync fallback path still respects max_chars truncation rules."""
|
||||
with _patch_app_config(max_chars=50):
|
||||
middleware = TitleMiddleware()
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题,而且这里继续补充更多上下文,确保超过本地fallback截断阈值"),
|
||||
AIMessage(content="回复"),
|
||||
]
|
||||
}
|
||||
result = middleware._generate_title_result(state)
|
||||
assert result["title"].endswith("...")
|
||||
assert result["title"].startswith("这是一个非常长的问题描述")
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题,而且这里继续补充更多上下文,确保超过本地fallback截断阈值"),
|
||||
AIMessage(content="回复"),
|
||||
]
|
||||
}
|
||||
result = middleware._generate_title_result(state, _make_title_config(max_chars=50))
|
||||
assert result["title"].endswith("...")
|
||||
assert result["title"].startswith("这是一个非常长的问题描述")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user