mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
refactor(config): simplify runtime context access in middlewares and tools
- title_middleware: drop the _resolve_title_config() try/except wrapper and the optional title_config=None fallback on every helper. after_model/aafter_model read runtime.context.app_config.title directly; helpers take TitleConfig as a required parameter. Matches the typed Runtime[DeerFlowContext] signature. - memory_middleware: drop resolve_context() call; use runtime.context directly since the type is already declared. - sandbox/tools.py: drop three layers of try/except Exception around resolve_context(runtime).app_config.sandbox. If the config can't be resolved that's a real bug that should surface, not be swallowed with a default. - task_tool.py: same — drop the try/except around resolve_context(). - client.py: drop the set_override() call in __init__ and _reload_config(). It was leaking overrides across test boundaries and the leak-free path (init() alone) is enough for the single-Client case. - conftest: autouse fixture that initializes a minimal AppConfig for every test so AppConfig.current() doesn't try to auto-load config.yaml. - test_title_middleware_core_logic: pass TitleConfig explicitly to helpers instead of patching AppConfig.current globally.
This commit is contained in:
parent
faec3bf9f2
commit
a934a822df
@ -9,7 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
|
|||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from deerflow.agents.memory.queue import get_memory_queue
|
from deerflow.agents.memory.queue import get_memory_queue
|
||||||
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
|
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -203,12 +203,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
|||||||
Returns:
|
Returns:
|
||||||
None (no state changes needed from this middleware).
|
None (no state changes needed from this middleware).
|
||||||
"""
|
"""
|
||||||
ctx = resolve_context(runtime)
|
memory_config = runtime.context.app_config.memory
|
||||||
memory_config = ctx.app_config.memory
|
|
||||||
if not memory_config.enabled:
|
if not memory_config.enabled:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
thread_id = ctx.thread_id
|
thread_id = runtime.context.thread_id
|
||||||
if not thread_id:
|
if not thread_id:
|
||||||
logger.debug("No thread_id in context, skipping memory update")
|
logger.debug("No thread_id in context, skipping memory update")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -8,8 +8,7 @@ from langchain.agents.middleware import AgentMiddleware
|
|||||||
from langgraph.config import get_config
|
from langgraph.config import get_config
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from deerflow.config.app_config import AppConfig
|
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||||
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
|
|
||||||
from deerflow.config.title_config import TitleConfig
|
from deerflow.config.title_config import TitleConfig
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
@ -46,10 +45,8 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> bool:
|
def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig) -> bool:
|
||||||
"""Check if we should generate a title for this thread."""
|
"""Check if we should generate a title for this thread."""
|
||||||
if title_config is None:
|
|
||||||
title_config = AppConfig.current().title
|
|
||||||
if not title_config.enabled:
|
if not title_config.enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -69,13 +66,11 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
# Generate title after first complete exchange
|
# Generate title after first complete exchange
|
||||||
return len(user_messages) == 1 and len(assistant_messages) >= 1
|
return len(user_messages) == 1 and len(assistant_messages) >= 1
|
||||||
|
|
||||||
def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> tuple[str, str]:
|
def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig) -> tuple[str, str]:
|
||||||
"""Extract user/assistant messages and build the title prompt.
|
"""Extract user/assistant messages and build the title prompt.
|
||||||
|
|
||||||
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
|
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
|
||||||
"""
|
"""
|
||||||
if title_config is None:
|
|
||||||
title_config = AppConfig.current().title
|
|
||||||
messages = state.get("messages", [])
|
messages = state.get("messages", [])
|
||||||
|
|
||||||
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
||||||
@ -91,17 +86,13 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
)
|
)
|
||||||
return prompt, user_msg
|
return prompt, user_msg
|
||||||
|
|
||||||
def _parse_title(self, content: object, title_config: TitleConfig | None = None) -> str:
|
def _parse_title(self, content: object, title_config: TitleConfig) -> str:
|
||||||
"""Normalize model output into a clean title string."""
|
"""Normalize model output into a clean title string."""
|
||||||
if title_config is None:
|
|
||||||
title_config = AppConfig.current().title
|
|
||||||
title_content = self._normalize_content(content)
|
title_content = self._normalize_content(content)
|
||||||
title = title_content.strip().strip('"').strip("'")
|
title = title_content.strip().strip('"').strip("'")
|
||||||
return title[: title_config.max_chars] if len(title) > title_config.max_chars else title
|
return title[: title_config.max_chars] if len(title) > title_config.max_chars else title
|
||||||
|
|
||||||
def _fallback_title(self, user_msg: str, title_config: TitleConfig | None = None) -> str:
|
def _fallback_title(self, user_msg: str, title_config: TitleConfig) -> str:
|
||||||
if title_config is None:
|
|
||||||
title_config = AppConfig.current().title
|
|
||||||
fallback_chars = min(title_config.max_chars, 50)
|
fallback_chars = min(title_config.max_chars, 50)
|
||||||
if len(user_msg) > fallback_chars:
|
if len(user_msg) > fallback_chars:
|
||||||
return user_msg[:fallback_chars].rstrip() + "..."
|
return user_msg[:fallback_chars].rstrip() + "..."
|
||||||
@ -121,20 +112,16 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
|
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> dict | None:
|
def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
|
||||||
"""Generate a local fallback title without blocking on an LLM call."""
|
"""Generate a local fallback title without blocking on an LLM call."""
|
||||||
if title_config is None:
|
|
||||||
title_config = AppConfig.current().title
|
|
||||||
if not self._should_generate_title(state, title_config):
|
if not self._should_generate_title(state, title_config):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
_, user_msg = self._build_title_prompt(state, title_config)
|
_, user_msg = self._build_title_prompt(state, title_config)
|
||||||
return {"title": self._fallback_title(user_msg, title_config)}
|
return {"title": self._fallback_title(user_msg, title_config)}
|
||||||
|
|
||||||
async def _agenerate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig | None = None) -> dict | None:
|
async def _agenerate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
|
||||||
"""Generate a title asynchronously and fall back locally on failure."""
|
"""Generate a title asynchronously and fall back locally on failure."""
|
||||||
if title_config is None:
|
|
||||||
title_config = AppConfig.current().title
|
|
||||||
if not self._should_generate_title(state, title_config):
|
if not self._should_generate_title(state, title_config):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -153,21 +140,10 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
|||||||
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
|
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
|
||||||
return {"title": self._fallback_title(user_msg, title_config)}
|
return {"title": self._fallback_title(user_msg, title_config)}
|
||||||
|
|
||||||
def _resolve_title_config(self, runtime: Runtime[DeerFlowContext]) -> TitleConfig | None:
|
|
||||||
"""Resolve TitleConfig from the runtime context when possible.
|
|
||||||
|
|
||||||
Returns None on any failure so callers can fall back to
|
|
||||||
``AppConfig.current()`` without raising.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return resolve_context(runtime).app_config.title
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def after_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
def after_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||||
return self._generate_title_result(state, self._resolve_title_config(runtime))
|
return self._generate_title_result(state, runtime.context.app_config.title)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||||
return await self._agenerate_title_result(state, self._resolve_title_config(runtime))
|
return await self._agenerate_title_result(state, runtime.context.app_config.title)
|
||||||
|
|||||||
@ -143,12 +143,8 @@ class DeerFlowClient:
|
|||||||
middlewares: Optional list of custom middlewares to inject into the agent.
|
middlewares: Optional list of custom middlewares to inject into the agent.
|
||||||
"""
|
"""
|
||||||
if config_path is not None:
|
if config_path is not None:
|
||||||
config = AppConfig.from_file(config_path)
|
AppConfig.init(AppConfig.from_file(config_path))
|
||||||
AppConfig.init(config)
|
|
||||||
self._app_config = AppConfig.current()
|
self._app_config = AppConfig.current()
|
||||||
# Scope this client's config to the current context so it doesn't
|
|
||||||
# leak into unrelated async tasks when multiple clients coexist.
|
|
||||||
self._config_token = AppConfig.set_override(self._app_config)
|
|
||||||
|
|
||||||
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
|
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
|
||||||
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
|
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
|
||||||
@ -177,11 +173,9 @@ class DeerFlowClient:
|
|||||||
self._agent_config_key = None
|
self._agent_config_key = None
|
||||||
|
|
||||||
def _reload_config(self) -> None:
|
def _reload_config(self) -> None:
|
||||||
"""Reload config from file, update both process-global and context override."""
|
"""Reload config from file and refresh the cached reference."""
|
||||||
config = AppConfig.from_file()
|
AppConfig.init(AppConfig.from_file())
|
||||||
AppConfig.init(config)
|
self._app_config = AppConfig.current()
|
||||||
self._app_config = config
|
|
||||||
self._config_token = AppConfig.set_override(config)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
|
|||||||
@ -987,12 +987,11 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
sandbox = ensure_sandbox_initialized(runtime)
|
sandbox = ensure_sandbox_initialized(runtime)
|
||||||
|
app_config = resolve_context(runtime).app_config
|
||||||
|
sandbox_cfg = app_config.sandbox
|
||||||
|
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
||||||
if is_local_sandbox(runtime):
|
if is_local_sandbox(runtime):
|
||||||
try:
|
if not is_host_bash_allowed(app_config):
|
||||||
host_bash_config = resolve_context(runtime).app_config
|
|
||||||
except Exception:
|
|
||||||
host_bash_config = None
|
|
||||||
if not is_host_bash_allowed(host_bash_config):
|
|
||||||
return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}"
|
return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}"
|
||||||
ensure_thread_directories_exist(runtime)
|
ensure_thread_directories_exist(runtime)
|
||||||
thread_data = get_thread_data(runtime)
|
thread_data = get_thread_data(runtime)
|
||||||
@ -1000,18 +999,8 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
|
|||||||
command = replace_virtual_paths_in_command(command, thread_data)
|
command = replace_virtual_paths_in_command(command, thread_data)
|
||||||
command = _apply_cwd_prefix(command, thread_data)
|
command = _apply_cwd_prefix(command, thread_data)
|
||||||
output = sandbox.execute_command(command)
|
output = sandbox.execute_command(command)
|
||||||
try:
|
|
||||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
|
||||||
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
|
||||||
except Exception:
|
|
||||||
max_chars = 20000
|
|
||||||
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
|
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
|
||||||
ensure_thread_directories_exist(runtime)
|
ensure_thread_directories_exist(runtime)
|
||||||
try:
|
|
||||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
|
||||||
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
|
||||||
except Exception:
|
|
||||||
max_chars = 20000
|
|
||||||
return _truncate_bash_output(sandbox.execute_command(command), max_chars)
|
return _truncate_bash_output(sandbox.execute_command(command), max_chars)
|
||||||
except SandboxError as e:
|
except SandboxError as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
@ -1047,11 +1036,8 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
|||||||
if not children:
|
if not children:
|
||||||
return "(empty)"
|
return "(empty)"
|
||||||
output = "\n".join(children)
|
output = "\n".join(children)
|
||||||
try:
|
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
|
||||||
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
|
|
||||||
except Exception:
|
|
||||||
max_chars = 20000
|
|
||||||
return _truncate_ls_output(output, max_chars)
|
return _truncate_ls_output(output, max_chars)
|
||||||
except SandboxError as e:
|
except SandboxError as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
@ -1218,11 +1204,8 @@ def read_file_tool(
|
|||||||
return "(empty)"
|
return "(empty)"
|
||||||
if start_line is not None and end_line is not None:
|
if start_line is not None and end_line is not None:
|
||||||
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
|
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
|
||||||
try:
|
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
||||||
sandbox_cfg = resolve_context(runtime).app_config.sandbox
|
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
|
||||||
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
|
|
||||||
except Exception:
|
|
||||||
max_chars = 50000
|
|
||||||
return _truncate_read_file_output(content, max_chars)
|
return _truncate_read_file_output(content, max_chars)
|
||||||
except SandboxError as e:
|
except SandboxError as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
|||||||
@ -67,13 +67,8 @@ async def task_tool(
|
|||||||
if config is None:
|
if config is None:
|
||||||
available = ", ".join(available_subagent_names)
|
available = ", ".join(available_subagent_names)
|
||||||
return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
|
return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
|
||||||
if subagent_type == "bash":
|
if subagent_type == "bash" and not is_host_bash_allowed(resolve_context(runtime).app_config):
|
||||||
try:
|
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
|
||||||
host_bash_config = resolve_context(runtime).app_config
|
|
||||||
except Exception:
|
|
||||||
host_bash_config = None
|
|
||||||
if not is_host_bash_allowed(host_bash_config):
|
|
||||||
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
|
|
||||||
|
|
||||||
# Build config overrides
|
# Build config overrides
|
||||||
overrides: dict = {}
|
overrides: dict = {}
|
||||||
|
|||||||
@ -68,6 +68,28 @@ def provisioner_module():
|
|||||||
# context should mark themselves ``@pytest.mark.no_auto_user``.
|
# context should mark themselves ``@pytest.mark.no_auto_user``.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _auto_app_config():
|
||||||
|
"""Initialize a minimal AppConfig for tests so ``AppConfig.current()`` never tries to auto-load config.yaml.
|
||||||
|
|
||||||
|
Individual tests can still override via ``patch.object(AppConfig, "current", ...)``
|
||||||
|
or by calling ``AppConfig.init()`` with a different config.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from deerflow.config.app_config import AppConfig
|
||||||
|
from deerflow.config.sandbox_config import SandboxConfig
|
||||||
|
except ImportError:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
previous_global = AppConfig._global
|
||||||
|
AppConfig._global = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
AppConfig._global = previous_global
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def _auto_user_context(request):
|
def _auto_user_context(request):
|
||||||
"""Inject a default ``test-user-autouse`` into the contextvar.
|
"""Inject a default ``test-user-autouse`` into the contextvar.
|
||||||
|
|||||||
@ -1,173 +1,169 @@
|
|||||||
"""Core behavior tests for TitleMiddleware."""
|
"""Core behavior tests for TitleMiddleware."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
from deerflow.agents.middlewares import title_middleware as title_middleware_module
|
from deerflow.agents.middlewares import title_middleware as title_middleware_module
|
||||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||||
from deerflow.config.app_config import AppConfig
|
from deerflow.config.app_config import AppConfig
|
||||||
|
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||||
from deerflow.config.sandbox_config import SandboxConfig
|
from deerflow.config.sandbox_config import SandboxConfig
|
||||||
from deerflow.config.title_config import TitleConfig
|
from deerflow.config.title_config import TitleConfig
|
||||||
|
|
||||||
|
|
||||||
def _make_config(**title_overrides) -> AppConfig:
|
def _make_title_config(**overrides) -> TitleConfig:
|
||||||
return AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides))
|
return TitleConfig(**overrides)
|
||||||
|
|
||||||
|
|
||||||
def _patch_app_config(**title_overrides):
|
def _make_runtime(**title_overrides) -> SimpleNamespace:
|
||||||
return patch.object(AppConfig, "current", return_value=_make_config(**title_overrides))
|
"""Build a runtime whose context carries a DeerFlowContext with the given TitleConfig."""
|
||||||
|
app_config = AppConfig(sandbox=SandboxConfig(use="test"), title=TitleConfig(**title_overrides))
|
||||||
|
ctx = DeerFlowContext(app_config=app_config, thread_id="t1")
|
||||||
|
return SimpleNamespace(context=ctx)
|
||||||
|
|
||||||
|
|
||||||
class TestTitleMiddlewareCoreLogic:
|
class TestTitleMiddlewareCoreLogic:
|
||||||
def test_should_generate_title_for_first_complete_exchange(self):
|
def test_should_generate_title_for_first_complete_exchange(self):
|
||||||
with _patch_app_config(enabled=True):
|
middleware = TitleMiddleware()
|
||||||
middleware = TitleMiddleware()
|
state = {
|
||||||
state = {
|
"messages": [
|
||||||
"messages": [
|
HumanMessage(content="帮我总结这段代码"),
|
||||||
HumanMessage(content="帮我总结这段代码"),
|
AIMessage(content="好的,我先看结构"),
|
||||||
AIMessage(content="好的,我先看结构"),
|
]
|
||||||
]
|
}
|
||||||
}
|
|
||||||
|
|
||||||
assert middleware._should_generate_title(state) is True
|
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is True
|
||||||
|
|
||||||
def test_should_not_generate_title_when_disabled_or_already_set(self):
|
def test_should_not_generate_title_when_disabled_or_already_set(self):
|
||||||
middleware = TitleMiddleware()
|
middleware = TitleMiddleware()
|
||||||
|
|
||||||
with _patch_app_config(enabled=False):
|
disabled_state = {
|
||||||
disabled_state = {
|
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
"title": None,
|
||||||
"title": None,
|
}
|
||||||
}
|
assert middleware._should_generate_title(disabled_state, _make_title_config(enabled=False)) is False
|
||||||
assert middleware._should_generate_title(disabled_state) is False
|
|
||||||
|
|
||||||
with _patch_app_config(enabled=True):
|
titled_state = {
|
||||||
titled_state = {
|
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
"title": "Existing Title",
|
||||||
"title": "Existing Title",
|
}
|
||||||
}
|
assert middleware._should_generate_title(titled_state, _make_title_config(enabled=True)) is False
|
||||||
assert middleware._should_generate_title(titled_state) is False
|
|
||||||
|
|
||||||
def test_should_not_generate_title_after_second_user_turn(self):
|
def test_should_not_generate_title_after_second_user_turn(self):
|
||||||
with _patch_app_config(enabled=True):
|
middleware = TitleMiddleware()
|
||||||
middleware = TitleMiddleware()
|
state = {
|
||||||
state = {
|
"messages": [
|
||||||
"messages": [
|
HumanMessage(content="第一问"),
|
||||||
HumanMessage(content="第一问"),
|
AIMessage(content="第一答"),
|
||||||
AIMessage(content="第一答"),
|
HumanMessage(content="第二问"),
|
||||||
HumanMessage(content="第二问"),
|
AIMessage(content="第二答"),
|
||||||
AIMessage(content="第二答"),
|
]
|
||||||
]
|
}
|
||||||
}
|
|
||||||
|
|
||||||
assert middleware._should_generate_title(state) is False
|
assert middleware._should_generate_title(state, _make_title_config(enabled=True)) is False
|
||||||
|
|
||||||
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
|
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
|
||||||
with _patch_app_config(max_chars=12):
|
middleware = TitleMiddleware()
|
||||||
middleware = TitleMiddleware()
|
model = MagicMock()
|
||||||
model = MagicMock()
|
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
|
||||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
|
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
"messages": [
|
"messages": [
|
||||||
HumanMessage(content="请帮我写一个很长很长的脚本标题"),
|
HumanMessage(content="请帮我写一个很长很长的脚本标题"),
|
||||||
AIMessage(content="好的,先确认需求"),
|
AIMessage(content="好的,先确认需求"),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=12)))
|
||||||
title = result["title"]
|
title = result["title"]
|
||||||
|
|
||||||
assert title == "短标题"
|
assert title == "短标题"
|
||||||
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
|
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
|
||||||
model.ainvoke.assert_awaited_once()
|
model.ainvoke.assert_awaited_once()
|
||||||
|
|
||||||
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
|
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
|
||||||
with _patch_app_config(max_chars=20):
|
middleware = TitleMiddleware()
|
||||||
middleware = TitleMiddleware()
|
model = MagicMock()
|
||||||
model = MagicMock()
|
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
|
||||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
|
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
"messages": [
|
"messages": [
|
||||||
HumanMessage(content=[{"type": "text", "text": "请帮我总结这段代码"}]),
|
HumanMessage(content=[{"type": "text", "text": "请帮我总结这段代码"}]),
|
||||||
AIMessage(content=[{"type": "text", "text": "好的,先看结构"}]),
|
AIMessage(content=[{"type": "text", "text": "好的,先看结构"}]),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
|
||||||
title = result["title"]
|
title = result["title"]
|
||||||
|
|
||||||
assert title == "请帮我总结这段代码"
|
assert title == "请帮我总结这段代码"
|
||||||
|
|
||||||
def test_generate_title_fallback_for_long_message(self, monkeypatch):
|
def test_generate_title_fallback_for_long_message(self, monkeypatch):
|
||||||
with _patch_app_config(max_chars=20):
|
middleware = TitleMiddleware()
|
||||||
middleware = TitleMiddleware()
|
model = MagicMock()
|
||||||
model = MagicMock()
|
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
|
||||||
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
|
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
"messages": [
|
"messages": [
|
||||||
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题"),
|
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题"),
|
||||||
AIMessage(content="收到"),
|
AIMessage(content="收到"),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
result = asyncio.run(middleware._agenerate_title_result(state, _make_title_config(max_chars=20)))
|
||||||
title = result["title"]
|
title = result["title"]
|
||||||
|
|
||||||
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
|
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
|
||||||
assert title.endswith("...")
|
assert title.endswith("...")
|
||||||
assert title.startswith("这是一个非常长的问题描述")
|
assert title.startswith("这是一个非常长的问题描述")
|
||||||
|
|
||||||
def test_aafter_model_delegates_to_async_helper(self, monkeypatch):
|
def test_aafter_model_delegates_to_async_helper(self, monkeypatch):
|
||||||
middleware = TitleMiddleware()
|
middleware = TitleMiddleware()
|
||||||
|
|
||||||
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"}))
|
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"}))
|
||||||
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock()))
|
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime()))
|
||||||
assert result == {"title": "异步标题"}
|
assert result == {"title": "异步标题"}
|
||||||
|
|
||||||
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None))
|
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None))
|
||||||
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None
|
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=_make_runtime())) is None
|
||||||
|
|
||||||
def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch):
|
def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch):
|
||||||
middleware = TitleMiddleware()
|
middleware = TitleMiddleware()
|
||||||
|
|
||||||
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"}))
|
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"}))
|
||||||
result = middleware.after_model({"messages": []}, runtime=MagicMock())
|
result = middleware.after_model({"messages": []}, runtime=_make_runtime())
|
||||||
assert result == {"title": "同步标题"}
|
assert result == {"title": "同步标题"}
|
||||||
|
|
||||||
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
|
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
|
||||||
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
|
assert middleware.after_model({"messages": []}, runtime=_make_runtime()) is None
|
||||||
|
|
||||||
def test_sync_generate_title_uses_fallback_without_model(self):
|
def test_sync_generate_title_uses_fallback_without_model(self):
|
||||||
"""Sync path avoids LLM calls and derives a local fallback title."""
|
"""Sync path avoids LLM calls and derives a local fallback title."""
|
||||||
with _patch_app_config(max_chars=20):
|
middleware = TitleMiddleware()
|
||||||
middleware = TitleMiddleware()
|
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
"messages": [
|
"messages": [
|
||||||
HumanMessage(content="请帮我写测试"),
|
HumanMessage(content="请帮我写测试"),
|
||||||
AIMessage(content="好的"),
|
AIMessage(content="好的"),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
result = middleware._generate_title_result(state)
|
result = middleware._generate_title_result(state, _make_title_config(max_chars=20))
|
||||||
assert result == {"title": "请帮我写测试"}
|
assert result == {"title": "请帮我写测试"}
|
||||||
|
|
||||||
def test_sync_generate_title_respects_fallback_truncation(self):
|
def test_sync_generate_title_respects_fallback_truncation(self):
|
||||||
"""Sync fallback path still respects max_chars truncation rules."""
|
"""Sync fallback path still respects max_chars truncation rules."""
|
||||||
with _patch_app_config(max_chars=50):
|
middleware = TitleMiddleware()
|
||||||
middleware = TitleMiddleware()
|
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
"messages": [
|
"messages": [
|
||||||
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题,而且这里继续补充更多上下文,确保超过本地fallback截断阈值"),
|
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题,而且这里继续补充更多上下文,确保超过本地fallback截断阈值"),
|
||||||
AIMessage(content="回复"),
|
AIMessage(content="回复"),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
result = middleware._generate_title_result(state)
|
result = middleware._generate_title_result(state, _make_title_config(max_chars=50))
|
||||||
assert result["title"].endswith("...")
|
assert result["title"].endswith("...")
|
||||||
assert result["title"].startswith("这是一个非常长的问题描述")
|
assert result["title"].startswith("这是一个非常长的问题描述")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user