Hello world
" + + mock_config = MagicMock() + mock_config.get_tool_config.return_value = None + monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config) + monkeypatch.setattr(JinaClient, "crawl", mock_crawl) + result = await web_fetch_tool.ainvoke("https://example.com") + assert "Hello world" in result + assert not result.startswith("Error:") diff --git a/backend/tests/test_lead_agent_skills.py b/backend/tests/test_lead_agent_skills.py new file mode 100644 index 000000000..37a6dbff8 --- /dev/null +++ b/backend/tests/test_lead_agent_skills.py @@ -0,0 +1,96 @@ +from pathlib import Path + +from deerflow.agents.lead_agent.prompt import get_skills_prompt_section +from deerflow.config.agents_config import AgentConfig +from deerflow.skills.types import Skill + + +def _make_skill(name: str) -> Skill: + return Skill( + name=name, + description=f"Description for {name}", + license="MIT", + skill_dir=Path(f"/tmp/{name}"), + skill_file=Path(f"/tmp/{name}/SKILL.md"), + relative_path=Path(name), + category="public", + enabled=True, + ) + + +def test_get_skills_prompt_section_returns_empty_when_no_skills_match(monkeypatch): + skills = [_make_skill("skill1"), _make_skill("skill2")] + monkeypatch.setattr("deerflow.agents.lead_agent.prompt.load_skills", lambda enabled_only: skills) + + result = get_skills_prompt_section(available_skills={"non_existent_skill"}) + assert result == "" + + +def test_get_skills_prompt_section_returns_empty_when_available_skills_empty(monkeypatch): + skills = [_make_skill("skill1"), _make_skill("skill2")] + monkeypatch.setattr("deerflow.agents.lead_agent.prompt.load_skills", lambda enabled_only: skills) + + result = get_skills_prompt_section(available_skills=set()) + assert result == "" + + +def test_get_skills_prompt_section_returns_skills(monkeypatch): + skills = [_make_skill("skill1"), _make_skill("skill2")] + monkeypatch.setattr("deerflow.agents.lead_agent.prompt.load_skills", lambda enabled_only: skills) + + result = get_skills_prompt_section(available_skills={"skill1"}) + assert "skill1" in result + assert "skill2" not in result + + +def test_get_skills_prompt_section_returns_all_when_available_skills_is_none(monkeypatch): + skills = [_make_skill("skill1"), _make_skill("skill2")] + monkeypatch.setattr("deerflow.agents.lead_agent.prompt.load_skills", lambda enabled_only: skills) + + result = get_skills_prompt_section(available_skills=None) + assert "skill1" in result + assert "skill2" in result + + +def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch): + from unittest.mock import MagicMock + + from deerflow.agents.lead_agent import agent as lead_agent_module + + # Mock dependencies + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock()) + monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model") + monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model") + monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) + monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: []) + monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs) + + class MockModelConfig: + supports_thinking = False + + mock_app_config = MagicMock() + mock_app_config.get_model_config.return_value = MockModelConfig() + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config) + + captured_skills = [] + + def mock_apply_prompt_template(**kwargs): + captured_skills.append(kwargs.get("available_skills")) + return "mock_prompt" + + monkeypatch.setattr(lead_agent_module, "apply_prompt_template", mock_apply_prompt_template) + + # Case 1: Empty skills list + monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=[])) + lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}) + assert captured_skills[-1] == set() + + # Case 2: None skills list + monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None)) + lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}) + assert captured_skills[-1] is None + + # Case 3: Some skills list + monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"])) + lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}}) + assert captured_skills[-1] == {"skill1"} diff --git a/backend/tests/test_llm_error_handling_middleware.py b/backend/tests/test_llm_error_handling_middleware.py new file mode 100644 index 000000000..9c3077e31 --- /dev/null +++ b/backend/tests/test_llm_error_handling_middleware.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace + +import pytest +from langchain_core.messages import AIMessage +from langgraph.errors import GraphBubbleUp + +from deerflow.agents.middlewares.llm_error_handling_middleware import ( + LLMErrorHandlingMiddleware, +) + + +class FakeError(Exception): + def __init__( + self, + message: str, + *, + status_code: int | None = None, + code: str | None = None, + headers: dict[str, str] | None = None, + body: dict | None = None, + ) -> None: + super().__init__(message) + self.status_code = status_code + self.code = code + self.body = body + self.response = SimpleNamespace(status_code=status_code, headers=headers or {}) if status_code is not None or headers else None + + +def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware: + middleware = LLMErrorHandlingMiddleware() + for key, value in attrs.items(): + setattr(middleware, key, value) + return middleware + + +def test_async_model_call_retries_busy_provider_then_succeeds( + monkeypatch: pytest.MonkeyPatch, +) -> None: + middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=25, retry_cap_delay_ms=25) + attempts = 0 + waits: list[float] = [] + events: list[dict] = [] + + async def fake_sleep(delay: float) -> None: + waits.append(delay) + + def fake_writer(): + return events.append + + async def handler(_request) -> AIMessage: + nonlocal attempts + attempts += 1 + if attempts < 3: + raise FakeError("当前服务集群负载较高,请稍后重试,感谢您的耐心等待。 (2064)") + return AIMessage(content="ok") + + monkeypatch.setattr("asyncio.sleep", fake_sleep) + monkeypatch.setattr( + "langgraph.config.get_stream_writer", + fake_writer, + ) + + result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler)) + + assert isinstance(result, AIMessage) + assert result.content == "ok" + assert attempts == 3 + assert waits == [0.025, 0.025] + assert [event["type"] for event in events] == ["llm_retry", "llm_retry"] + + +def test_async_model_call_returns_user_message_for_quota_errors() -> None: + middleware = _build_middleware(retry_max_attempts=3) + + async def handler(_request) -> AIMessage: + raise FakeError( + "insufficient_quota: account balance is empty", + status_code=429, + code="insufficient_quota", + ) + + result = asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler)) + + assert isinstance(result, AIMessage) + assert "out of quota" in str(result.content) + + +def test_sync_model_call_uses_retry_after_header(monkeypatch: pytest.MonkeyPatch) -> None: + middleware = _build_middleware(retry_max_attempts=2, retry_base_delay_ms=10, retry_cap_delay_ms=10) + waits: list[float] = [] + attempts = 0 + + def fake_sleep(delay: float) -> None: + waits.append(delay) + + def handler(_request) -> AIMessage: + nonlocal attempts + attempts += 1 + if attempts == 1: + raise FakeError( + "server busy", + status_code=503, + headers={"Retry-After": "2"}, + ) + return AIMessage(content="ok") + + monkeypatch.setattr("time.sleep", fake_sleep) + + result = middleware.wrap_model_call(SimpleNamespace(), handler) + + assert isinstance(result, AIMessage) + assert result.content == "ok" + assert waits == [2.0] + + +def test_sync_model_call_propagates_graph_bubble_up() -> None: + middleware = _build_middleware() + + def handler(_request) -> AIMessage: + raise GraphBubbleUp() + + with pytest.raises(GraphBubbleUp): + middleware.wrap_model_call(SimpleNamespace(), handler) + + +def test_async_model_call_propagates_graph_bubble_up() -> None: + middleware = _build_middleware() + + async def handler(_request) -> AIMessage: + raise GraphBubbleUp() + + with pytest.raises(GraphBubbleUp): + asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler)) diff --git a/backend/tests/test_local_sandbox_provider_mounts.py b/backend/tests/test_local_sandbox_provider_mounts.py new file mode 100644 index 000000000..0eb6d4654 --- /dev/null +++ b/backend/tests/test_local_sandbox_provider_mounts.py @@ -0,0 +1,388 @@ +import errno +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping +from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider + + +class TestPathMapping: + def test_path_mapping_dataclass(self): + mapping = PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True) + assert mapping.container_path == "/mnt/skills" + assert mapping.local_path == "/home/user/skills" + assert mapping.read_only is True + + def test_path_mapping_defaults_to_false(self): + mapping = PathMapping(container_path="/mnt/data", local_path="/home/user/data") + assert mapping.read_only is False + + +class TestLocalSandboxPathResolution: + def test_resolve_path_exact_match(self): + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"), + ], + ) + resolved = sandbox._resolve_path("/mnt/skills") + assert resolved == "/home/user/skills" + + def test_resolve_path_nested_path(self): + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"), + ], + ) + resolved = sandbox._resolve_path("/mnt/skills/agent/prompt.py") + assert resolved == "/home/user/skills/agent/prompt.py" + + def test_resolve_path_no_mapping(self): + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"), + ], + ) + resolved = sandbox._resolve_path("/mnt/other/file.txt") + assert resolved == "/mnt/other/file.txt" + + def test_resolve_path_longest_prefix_first(self): + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"), + PathMapping(container_path="/mnt", local_path="/var/mnt"), + ], + ) + resolved = sandbox._resolve_path("/mnt/skills/file.py") + # Should match /mnt/skills first (longer prefix) + assert resolved == "/home/user/skills/file.py" + + def test_reverse_resolve_path_exact_match(self, tmp_path): + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)), + ], + ) + resolved = sandbox._reverse_resolve_path(str(skills_dir)) + assert resolved == "/mnt/skills" + + def test_reverse_resolve_path_nested(self, tmp_path): + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + file_path = skills_dir / "agent" / "prompt.py" + file_path.parent.mkdir() + file_path.write_text("test") + + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)), + ], + ) + resolved = sandbox._reverse_resolve_path(str(file_path)) + assert resolved == "/mnt/skills/agent/prompt.py" + + +class TestReadOnlyPath: + def test_is_read_only_true(self): + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True), + ], + ) + assert sandbox._is_read_only_path("/home/user/skills/file.py") is True + + def test_is_read_only_false_for_writable(self): + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/data", local_path="/home/user/data", read_only=False), + ], + ) + assert sandbox._is_read_only_path("/home/user/data/file.txt") is False + + def test_is_read_only_false_for_unmapped_path(self): + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True), + ], + ) + # Path not under any mapping + assert sandbox._is_read_only_path("/tmp/other/file.txt") is False + + def test_is_read_only_true_for_exact_match(self): + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True), + ], + ) + assert sandbox._is_read_only_path("/home/user/skills") is True + + def test_write_file_blocked_on_read_only(self, tmp_path): + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True), + ], + ) + # Skills dir is read-only, write should be blocked + with pytest.raises(OSError) as exc_info: + sandbox.write_file("/mnt/skills/new_file.py", "content") + assert exc_info.value.errno == errno.EROFS + + def test_write_file_allowed_on_writable_mount(self, tmp_path): + data_dir = tmp_path / "data" + data_dir.mkdir() + + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False), + ], + ) + sandbox.write_file("/mnt/data/file.txt", "content") + assert (data_dir / "file.txt").read_text() == "content" + + def test_update_file_blocked_on_read_only(self, tmp_path): + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + existing_file = skills_dir / "existing.py" + existing_file.write_bytes(b"original") + + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True), + ], + ) + with pytest.raises(OSError) as exc_info: + sandbox.update_file("/mnt/skills/existing.py", b"updated") + assert exc_info.value.errno == errno.EROFS + + +class TestMultipleMounts: + def test_multiple_read_write_mounts(self, tmp_path): + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + external_dir = tmp_path / "external" + external_dir.mkdir() + + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True), + PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False), + PathMapping(container_path="/mnt/external", local_path=str(external_dir), read_only=True), + ], + ) + + # Skills is read-only + with pytest.raises(OSError): + sandbox.write_file("/mnt/skills/file.py", "content") + + # Data is writable + sandbox.write_file("/mnt/data/file.txt", "data content") + assert (data_dir / "file.txt").read_text() == "data content" + + # External is read-only + with pytest.raises(OSError): + sandbox.write_file("/mnt/external/file.txt", "content") + + def test_nested_mounts_writable_under_readonly(self, tmp_path): + """A writable mount nested under a read-only mount should allow writes.""" + ro_dir = tmp_path / "ro" + ro_dir.mkdir() + rw_dir = ro_dir / "writable" + rw_dir.mkdir() + + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/repo", local_path=str(ro_dir), read_only=True), + PathMapping(container_path="/mnt/repo/writable", local_path=str(rw_dir), read_only=False), + ], + ) + + # Parent mount is read-only + with pytest.raises(OSError): + sandbox.write_file("/mnt/repo/file.txt", "content") + + # Nested writable mount should allow writes + sandbox.write_file("/mnt/repo/writable/file.txt", "content") + assert (rw_dir / "file.txt").read_text() == "content" + + def test_execute_command_path_replacement(self, tmp_path, monkeypatch): + data_dir = tmp_path / "data" + data_dir.mkdir() + test_file = data_dir / "test.txt" + test_file.write_text("hello") + + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/data", local_path=str(data_dir)), + ], + ) + + # Mock subprocess to capture the resolved command + captured = {} + original_run = __import__("subprocess").run + + def mock_run(*args, **kwargs): + if len(args) > 0: + captured["command"] = args[0] + return original_run(*args, **kwargs) + + monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.subprocess.run", mock_run) + monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.LocalSandbox._get_shell", lambda self: "/bin/sh") + + sandbox.execute_command("cat /mnt/data/test.txt") + # Verify the command received the resolved local path + assert str(data_dir) in captured.get("command", "") + + def test_reverse_resolve_path_does_not_match_partial_prefix(self, tmp_path): + foo_dir = tmp_path / "foo" + foo_dir.mkdir() + foobar_dir = tmp_path / "foobar" + foobar_dir.mkdir() + target = foobar_dir / "file.txt" + target.write_text("test") + + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/foo", local_path=str(foo_dir)), + ], + ) + + resolved = sandbox._reverse_resolve_path(str(target)) + assert resolved == str(target.resolve()) + + def test_reverse_resolve_paths_in_output_supports_backslash_separator(self, tmp_path): + mount_dir = tmp_path / "mount" + mount_dir.mkdir() + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/data", local_path=str(mount_dir)), + ], + ) + + output = f"Copied: {mount_dir}\\file.txt" + masked = sandbox._reverse_resolve_paths_in_output(output) + + assert "/mnt/data/file.txt" in masked + assert str(mount_dir) not in masked + + +class TestLocalSandboxProviderMounts: + def test_setup_path_mappings_uses_configured_skills_container_path_as_reserved_prefix(self, tmp_path): + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + custom_dir = tmp_path / "custom" + custom_dir.mkdir() + + from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig + + sandbox_config = SandboxConfig( + use="deerflow.sandbox.local:LocalSandboxProvider", + mounts=[ + VolumeMountConfig(host_path=str(custom_dir), container_path="/custom-skills/nested", read_only=False), + ], + ) + config = SimpleNamespace( + skills=SimpleNamespace(container_path="/custom-skills", get_skills_path=lambda: skills_dir), + sandbox=sandbox_config, + ) + + with patch("deerflow.config.get_app_config", return_value=config): + provider = LocalSandboxProvider() + + assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"] + + def test_setup_path_mappings_skips_relative_host_path(self, tmp_path): + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + + from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig + + sandbox_config = SandboxConfig( + use="deerflow.sandbox.local:LocalSandboxProvider", + mounts=[ + VolumeMountConfig(host_path="relative/path", container_path="/mnt/data", read_only=False), + ], + ) + config = SimpleNamespace( + skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir), + sandbox=sandbox_config, + ) + + with patch("deerflow.config.get_app_config", return_value=config): + provider = LocalSandboxProvider() + + assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"] + + def test_setup_path_mappings_skips_non_absolute_container_path(self, tmp_path): + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + custom_dir = tmp_path / "custom" + custom_dir.mkdir() + + from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig + + sandbox_config = SandboxConfig( + use="deerflow.sandbox.local:LocalSandboxProvider", + mounts=[ + VolumeMountConfig(host_path=str(custom_dir), container_path="mnt/data", read_only=False), + ], + ) + config = SimpleNamespace( + skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir), + sandbox=sandbox_config, + ) + + with patch("deerflow.config.get_app_config", return_value=config): + provider = LocalSandboxProvider() + + assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"] + + def test_setup_path_mappings_normalizes_container_path_trailing_slash(self, tmp_path): + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + custom_dir = tmp_path / "custom" + custom_dir.mkdir() + + from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig + + sandbox_config = SandboxConfig( + use="deerflow.sandbox.local:LocalSandboxProvider", + mounts=[ + VolumeMountConfig(host_path=str(custom_dir), container_path="/mnt/data/", read_only=False), + ], + ) + config = SimpleNamespace( + skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir), + sandbox=sandbox_config, + ) + + with patch("deerflow.config.get_app_config", return_value=config): + provider = LocalSandboxProvider() + + assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"] diff --git a/backend/tests/test_loop_detection_middleware.py b/backend/tests/test_loop_detection_middleware.py index 3bd0c3665..e037b8492 100644 --- a/backend/tests/test_loop_detection_middleware.py +++ b/backend/tests/test_loop_detection_middleware.py @@ -1,5 +1,6 @@ """Tests for LoopDetectionMiddleware.""" +import copy from unittest.mock import MagicMock from langchain_core.messages import AIMessage, HumanMessage, SystemMessage @@ -19,8 +20,13 @@ def _make_runtime(thread_id="test-thread"): def _make_state(tool_calls=None, content=""): - """Build a minimal AgentState dict with an AIMessage.""" - msg = AIMessage(content=content, tool_calls=tool_calls or []) + """Build a minimal AgentState dict with an AIMessage. + + Deep-copies *content* when it is mutable (e.g. list) so that + successive calls never share the same object reference. + """ + safe_content = copy.deepcopy(content) if isinstance(content, list) else content + msg = AIMessage(content=safe_content, tool_calls=tool_calls or []) return {"messages": [msg]} @@ -229,3 +235,114 @@ class TestLoopDetection: mw._apply(_make_state(tool_calls=call), runtime) assert "default" in mw._history + + +class TestAppendText: + """Unit tests for LoopDetectionMiddleware._append_text.""" + + def test_none_content_returns_text(self): + result = LoopDetectionMiddleware._append_text(None, "hello") + assert result == "hello" + + def test_str_content_concatenates(self): + result = LoopDetectionMiddleware._append_text("existing", "appended") + assert result == "existing\n\nappended" + + def test_empty_str_content_concatenates(self): + result = LoopDetectionMiddleware._append_text("", "appended") + assert result == "\n\nappended" + + def test_list_content_appends_text_block(self): + """List content (e.g. Anthropic thinking mode) should get a new text block.""" + content = [ + {"type": "thinking", "text": "Let me think..."}, + {"type": "text", "text": "Here is my answer"}, + ] + result = LoopDetectionMiddleware._append_text(content, "stop msg") + assert isinstance(result, list) + assert len(result) == 3 + assert result[0] == content[0] + assert result[1] == content[1] + assert result[2] == {"type": "text", "text": "\n\nstop msg"} + + def test_empty_list_content_appends_text_block(self): + result = LoopDetectionMiddleware._append_text([], "stop msg") + assert isinstance(result, list) + assert len(result) == 1 + assert result[0] == {"type": "text", "text": "\n\nstop msg"} + + def test_unexpected_type_coerced_to_str(self): + """Unexpected content types should be coerced to str as a fallback.""" + result = LoopDetectionMiddleware._append_text(42, "stop msg") + assert isinstance(result, str) + assert result == "42\n\nstop msg" + + def test_list_content_not_mutated_in_place(self): + """_append_text must not modify the original list.""" + original = [{"type": "text", "text": "hello"}] + result = LoopDetectionMiddleware._append_text(original, "appended") + assert len(original) == 1 # original unchanged + assert len(result) == 2 # new list has the appended block + + +class TestHardStopWithListContent: + """Regression tests: hard stop must not crash when AIMessage.content is a list.""" + + def test_hard_stop_with_list_content(self): + """Hard stop on list content should not raise TypeError (regression).""" + mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4) + runtime = _make_runtime() + call = [_bash_call("ls")] + + # Build state with list content (e.g. Anthropic thinking mode) + list_content = [ + {"type": "thinking", "text": "Let me think..."}, + {"type": "text", "text": "I'll run ls"}, + ] + + for _ in range(3): + mw._apply(_make_state(tool_calls=call, content=list_content), runtime) + + # Fourth call triggers hard stop — must not raise TypeError + result = mw._apply(_make_state(tool_calls=call, content=list_content), runtime) + assert result is not None + msg = result["messages"][0] + assert isinstance(msg, AIMessage) + assert msg.tool_calls == [] + # Content should remain a list with the stop message appended + assert isinstance(msg.content, list) + assert len(msg.content) == 3 + assert msg.content[2]["type"] == "text" + assert _HARD_STOP_MSG in msg.content[2]["text"] + + def test_hard_stop_with_none_content(self): + """Hard stop on None content should produce a plain string.""" + mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4) + runtime = _make_runtime() + call = [_bash_call("ls")] + + for _ in range(3): + mw._apply(_make_state(tool_calls=call), runtime) + + # Fourth call with default empty-string content + result = mw._apply(_make_state(tool_calls=call), runtime) + assert result is not None + msg = result["messages"][0] + assert isinstance(msg.content, str) + assert _HARD_STOP_MSG in msg.content + + def test_hard_stop_with_str_content(self): + """Hard stop on str content should concatenate the stop message.""" + mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4) + runtime = _make_runtime() + call = [_bash_call("ls")] + + for _ in range(3): + mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime) + + result = mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime) + assert result is not None + msg = result["messages"][0] + assert isinstance(msg.content, str) + assert msg.content.startswith("thinking...") + assert _HARD_STOP_MSG in msg.content diff --git a/backend/tests/test_memory_prompt_injection.py b/backend/tests/test_memory_prompt_injection.py index ab1f0a783..7c3ad85c4 100644 --- a/backend/tests/test_memory_prompt_injection.py +++ b/backend/tests/test_memory_prompt_injection.py @@ -119,3 +119,57 @@ def test_format_memory_skips_non_string_content_facts() -> None: # The formatted line for a list content would be "- [knowledge | 0.85] ['list']". assert "| 0.85]" not in result assert "Valid fact" in result + + +def test_format_memory_renders_correction_source_error() -> None: + memory_data = { + "facts": [ + { + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + "sourceError": "The agent previously suggested npm start.", + } + ] + } + + result = format_memory_for_injection(memory_data, max_tokens=2000) + + assert "Use make dev for local development." in result + assert "avoid: The agent previously suggested npm start." in result + + +def test_format_memory_renders_correction_without_source_error_normally() -> None: + memory_data = { + "facts": [ + { + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + } + ] + } + + result = format_memory_for_injection(memory_data, max_tokens=2000) + + assert "Use make dev for local development." in result + assert "avoid:" not in result + + +def test_format_memory_includes_long_term_background() -> None: + """longTermBackground in history must be injected into the prompt.""" + memory_data = { + "user": {}, + "history": { + "recentMonths": {"summary": "Recent activity summary"}, + "earlierContext": {"summary": "Earlier context summary"}, + "longTermBackground": {"summary": "Core expertise in distributed systems"}, + }, + "facts": [], + } + + result = format_memory_for_injection(memory_data, max_tokens=2000) + + assert "Background: Core expertise in distributed systems" in result + assert "Recent: Recent activity summary" in result + assert "Earlier: Earlier context summary" in result diff --git a/backend/tests/test_memory_queue.py b/backend/tests/test_memory_queue.py new file mode 100644 index 000000000..6ef91a142 --- /dev/null +++ b/backend/tests/test_memory_queue.py @@ -0,0 +1,50 @@ +from unittest.mock import MagicMock, patch + +from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue +from deerflow.config.memory_config import MemoryConfig + + +def _memory_config(**overrides: object) -> MemoryConfig: + config = MemoryConfig() + for key, value in overrides.items(): + setattr(config, key, value) + return config + + +def test_queue_add_preserves_existing_correction_flag_for_same_thread() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add(thread_id="thread-1", messages=["first"], correction_detected=True) + queue.add(thread_id="thread-1", messages=["second"], correction_detected=False) + + assert len(queue._queue) == 1 + assert queue._queue[0].messages == ["second"] + assert queue._queue[0].correction_detected is True + + +def test_process_queue_forwards_correction_flag_to_updater() -> None: + queue = MemoryUpdateQueue() + queue._queue = [ + ConversationContext( + thread_id="thread-1", + messages=["conversation"], + agent_name="lead_agent", + correction_detected=True, + ) + ] + mock_updater = MagicMock() + mock_updater.update_memory.return_value = True + + with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater): + queue._process_queue() + + mock_updater.update_memory.assert_called_once_with( + messages=["conversation"], + thread_id="thread-1", + agent_name="lead_agent", + correction_detected=True, + ) diff --git a/backend/tests/test_memory_router.py b/backend/tests/test_memory_router.py index 39134c61d..23a4f30fe 100644 --- a/backend/tests/test_memory_router.py +++ b/backend/tests/test_memory_router.py @@ -72,6 +72,56 @@ def test_import_memory_route_returns_imported_memory() -> None: assert response.json()["facts"] == imported_memory["facts"] +def test_export_memory_route_preserves_source_error() -> None: + app = FastAPI() + app.include_router(memory.router) + exported_memory = _sample_memory( + facts=[ + { + "id": "fact_correction", + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + "createdAt": "2026-03-20T00:00:00Z", + "source": "thread-1", + "sourceError": "The agent previously suggested npm start.", + } + ] + ) + + with patch("app.gateway.routers.memory.get_memory_data", return_value=exported_memory): + with TestClient(app) as client: + response = client.get("/api/memory/export") + + assert response.status_code == 200 + assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start." + + +def test_import_memory_route_preserves_source_error() -> None: + app = FastAPI() + app.include_router(memory.router) + imported_memory = _sample_memory( + facts=[ + { + "id": "fact_correction", + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + "createdAt": "2026-03-20T00:00:00Z", + "source": "thread-1", + "sourceError": "The agent previously suggested npm start.", + } + ] + ) + + with patch("app.gateway.routers.memory.import_memory_data", return_value=imported_memory): + with TestClient(app) as client: + response = client.post("/api/memory/import", json=imported_memory) + + assert response.status_code == 200 + assert response.json()["facts"][0]["sourceError"] == "The agent previously suggested npm start." + + def test_clear_memory_route_returns_cleared_memory() -> None: app = FastAPI() app.include_router(memory.router) diff --git a/backend/tests/test_memory_updater.py b/backend/tests/test_memory_updater.py index f7b48228a..6309cf9f6 100644 --- a/backend/tests/test_memory_updater.py +++ b/backend/tests/test_memory_updater.py @@ -146,6 +146,53 @@ def test_apply_updates_preserves_threshold_and_max_facts_trimming() -> None: assert result["facts"][1]["source"] == "thread-9" +def test_apply_updates_preserves_source_error() -> None: + updater = MemoryUpdater() + current_memory = _make_memory() + update_data = { + "newFacts": [ + { + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + "sourceError": "The agent previously suggested npm start.", + } + ] + } + + with patch( + "deerflow.agents.memory.updater.get_memory_config", + return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7), + ): + result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction") + + assert result["facts"][0]["sourceError"] == "The agent previously suggested npm start." + assert result["facts"][0]["category"] == "correction" + + +def test_apply_updates_ignores_empty_source_error() -> None: + updater = MemoryUpdater() + current_memory = _make_memory() + update_data = { + "newFacts": [ + { + "content": "Use make dev for local development.", + "category": "correction", + "confidence": 0.95, + "sourceError": " ", + } + ] + } + + with patch( + "deerflow.agents.memory.updater.get_memory_config", + return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7), + ): + result = updater._apply_updates(current_memory, update_data, thread_id="thread-correction") + + assert "sourceError" not in result["facts"][0] + + def test_clear_memory_data_resets_all_sections() -> None: with patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True): result = clear_memory_data() @@ -522,3 +569,53 @@ class TestUpdateMemoryStructuredResponse: result = updater.update_memory([msg, ai_msg]) assert result is True + + def test_correction_hint_injected_when_detected(self): + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) + + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "No, that's wrong." + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Understood" + ai_msg.tool_calls = [] + + result = updater.update_memory([msg, ai_msg], correction_detected=True) + + assert result is True + prompt = model.invoke.call_args[0][0] + assert "Explicit correction signals were detected" in prompt + + def test_correction_hint_empty_when_not_detected(self): + updater = MemoryUpdater() + valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' + model = self._make_mock_model(valid_json) + + with ( + patch.object(updater, "_get_model", return_value=model), + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "Let's talk about memory." + ai_msg = MagicMock() + ai_msg.type = "ai" + ai_msg.content = "Sure" + ai_msg.tool_calls = [] + + result = updater.update_memory([msg, ai_msg], correction_detected=False) + + assert result is True + prompt = model.invoke.call_args[0][0] + assert "Explicit correction signals were detected" not in prompt diff --git a/backend/tests/test_memory_upload_filtering.py b/backend/tests/test_memory_upload_filtering.py index 45d0dbf4e..1ff0aa3b6 100644 --- a/backend/tests/test_memory_upload_filtering.py +++ b/backend/tests/test_memory_upload_filtering.py @@ -10,7 +10,7 @@ persisting in long-term memory: from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory -from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory +from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction # --------------------------------------------------------------------------- # Helpers @@ -134,6 +134,64 @@ class TestFilterMessagesForMemory: assert "{nameError}
- )} + ) : null}