diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index d48630f37..f66d6f2e5 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,4 +1,4 @@ -"""Test configuration for the backend test suite. +"""Test configuration shared by unit and end-to-end tests. Sets up sys.path and pre-mocks modules that would cause circular import issues when unit-testing lightweight config/registry code in isolation. @@ -6,8 +6,8 @@ issues when unit-testing lightweight config/registry code in isolation. import importlib.util import sys +import warnings from pathlib import Path -from types import SimpleNamespace from unittest.mock import MagicMock import pytest @@ -38,14 +38,51 @@ _executor_mock.get_background_task_result = MagicMock() sys.modules["deerflow.subagents.executor"] = _executor_mock +def pytest_configure(config): + config.addinivalue_line( + "markers", + "no_auto_user: disable the conftest autouse contextvar fixture for this test", + ) + warnings.filterwarnings( + "ignore", + message=r"Pydantic serializer warnings:.*field_name='context'.*", + category=UserWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"pkg_resources is deprecated as an API.*", + category=UserWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"Deprecated call to `pkg_resources\.declare_namespace\(.*", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"datetime\.datetime\.utcfromtimestamp\(\) is deprecated.*", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"websockets\.InvalidStatusCode is deprecated", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"websockets\.legacy is deprecated.*", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"websockets\.client\.WebSocketClientProtocol is deprecated", + category=DeprecationWarning, + ) + + @pytest.fixture() def provisioner_module(): - """Load docker/provisioner/app.py as an importable test module. - - Shared by test_provisioner_kubeconfig and test_provisioner_pvc_volumes so - that any change to the provisioner entry-point path or module name only - needs to be updated in one place. - """ + """Load docker/provisioner/app.py as an importable test module.""" repo_root = Path(__file__).resolve().parents[2] module_path = repo_root / "docker" / "provisioner" / "app.py" spec = importlib.util.spec_from_file_location("provisioner_app_test", module_path) @@ -56,42 +93,25 @@ def provisioner_module(): return module -# --------------------------------------------------------------------------- -# Auto-set user context for every test unless marked no_auto_user -# --------------------------------------------------------------------------- -# -# Repository methods read ``user_id`` from a contextvar by default -# (see ``deerflow.runtime.user_context``). Without this fixture, every -# pre-existing persistence test would raise RuntimeError because the -# contextvar is unset. The fixture sets a default test user on every -# test; tests that explicitly want to verify behaviour *without* a user -# context should mark themselves ``@pytest.mark.no_auto_user``. - - @pytest.fixture(autouse=True) def _auto_user_context(request): - """Inject a default ``test-user-autouse`` into the contextvar. - - Opt-out via ``@pytest.mark.no_auto_user``. Uses lazy import so that - tests which don't touch the persistence layer never pay the cost - of importing runtime.user_context. - """ + """Inject a default ``test-user-autouse`` into the contextvar.""" if request.node.get_closest_marker("no_auto_user"): yield return try: - from deerflow.runtime.user_context import ( - reset_current_user, - set_current_user, + from deerflow.runtime.actor_context import ( + ActorContext, + bind_actor_context, + reset_actor_context, ) except ImportError: yield return - user = SimpleNamespace(id="test-user-autouse", email="test@local") - token = set_current_user(user) + token = bind_actor_context(ActorContext(user_id="test-user-autouse")) try: yield finally: - reset_current_user(token) + reset_actor_context(token) diff --git a/backend/tests/e2e/conftest.py b/backend/tests/e2e/conftest.py new file mode 100644 index 000000000..f85767c9f --- /dev/null +++ b/backend/tests/e2e/conftest.py @@ -0,0 +1,33 @@ +"""Shared fixtures for end-to-end API tests.""" + +import pytest +from fastapi.testclient import TestClient + +from app.plugins.auth.api.schemas import _login_attempts +from app.plugins.auth.domain.config import AuthConfig +from app.plugins.auth.runtime.config_state import reset_auth_config, set_auth_config +from store.config.app_config import AppConfig, reset_app_config, set_app_config +from store.config.storage_config import StorageConfig + +_TEST_SECRET = "test-secret-key-e2e-auth-minimum-32" + + +@pytest.fixture() +def client(tmp_path): + """Create a full app client backed by an isolated SQLite directory.""" + from app.gateway.app import create_app + + _login_attempts.clear() + reset_auth_config() + reset_app_config() + set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET)) + set_app_config(AppConfig(storage=StorageConfig(driver="sqlite", sqlite_dir=str(tmp_path)))) + + app = create_app() + + with TestClient(app) as test_client: + yield test_client + + _login_attempts.clear() + reset_auth_config() + reset_app_config() diff --git a/backend/tests/e2e/test_auth_initialize_me.py b/backend/tests/e2e/test_auth_initialize_me.py new file mode 100644 index 000000000..cda72edc1 --- /dev/null +++ b/backend/tests/e2e/test_auth_initialize_me.py @@ -0,0 +1,163 @@ +"""End-to-end auth API tests for the main auth user journeys.""" + +from app.plugins.auth.security.csrf import CSRF_HEADER_NAME + + +def _initialize_payload(**overrides): + return { + "email": "admin@example.com", + "password": "Str0ng!Pass99", + **overrides, + } + + +def _register_payload(**overrides): + return { + "email": "user@example.com", + "password": "Str0ng!Pass99", + **overrides, + } + + +def _login(client, *, email="user@example.com", password="Str0ng!Pass99"): + return client.post( + "/api/v1/auth/login/local", + data={"username": email, "password": password}, + ) + + +def _csrf_headers(client) -> dict[str, str]: + token = client.cookies.get("csrf_token") + assert token, "csrf_token cookie is required before calling protected POST endpoints" + return {CSRF_HEADER_NAME: token} + + +def test_initialize_returns_admin_and_sets_session_cookie(client): + response = client.post("/api/v1/auth/initialize", json=_initialize_payload()) + + assert response.status_code == 201 + assert response.json()["email"] == "admin@example.com" + assert response.json()["system_role"] == "admin" + assert "access_token" in response.cookies + assert "access_token" in client.cookies + + +def test_me_returns_initialized_admin_identity(client): + initialize = client.post("/api/v1/auth/initialize", json=_initialize_payload()) + assert initialize.status_code == 201 + + response = client.get("/api/v1/auth/me") + + assert response.status_code == 200 + assert response.json() == { + "id": response.json()["id"], + "email": "admin@example.com", + "system_role": "admin", + "needs_setup": False, + } + + +def test_setup_status_flips_after_initialize(client): + before = client.get("/api/v1/auth/setup-status") + assert before.status_code == 200 + assert before.json() == {"needs_setup": True} + + initialize = client.post("/api/v1/auth/initialize", json=_initialize_payload()) + assert initialize.status_code == 201 + + after = client.get("/api/v1/auth/setup-status") + assert after.status_code == 200 + assert after.json() == {"needs_setup": False} + + +def test_register_logs_in_user_and_me_returns_identity(client): + response = client.post("/api/v1/auth/register", json=_register_payload()) + + assert response.status_code == 201 + assert response.json()["email"] == "user@example.com" + assert response.json()["system_role"] == "user" + assert "access_token" in client.cookies + assert "csrf_token" in client.cookies + + me = client.get("/api/v1/auth/me") + assert me.status_code == 200 + assert me.json()["email"] == "user@example.com" + assert me.json()["system_role"] == "user" + assert me.json()["needs_setup"] is False + + +def test_me_requires_authentication(client): + response = client.get("/api/v1/auth/me") + + assert response.status_code == 401 + assert response.json()["detail"]["code"] == "not_authenticated" + + +def test_logout_clears_session_and_me_is_denied(client): + register = client.post("/api/v1/auth/register", json=_register_payload()) + assert register.status_code == 201 + + logout = client.post("/api/v1/auth/logout") + assert logout.status_code == 200 + assert logout.json() == {"message": "Successfully logged out"} + + me = client.get("/api/v1/auth/me") + assert me.status_code == 401 + assert me.json()["detail"]["code"] == "not_authenticated" + + +def test_login_local_restores_session_after_logout(client): + register = client.post("/api/v1/auth/register", json=_register_payload()) + assert register.status_code == 201 + assert client.post("/api/v1/auth/logout").status_code == 200 + + login = _login(client) + assert login.status_code == 200 + assert login.json()["needs_setup"] is False + assert "access_token" in client.cookies + assert "csrf_token" in client.cookies + + me = client.get("/api/v1/auth/me") + assert me.status_code == 200 + assert me.json()["email"] == "user@example.com" + + +def test_change_password_updates_credentials_and_rotates_login(client): + register = client.post("/api/v1/auth/register", json=_register_payload()) + assert register.status_code == 201 + + change = client.post( + "/api/v1/auth/change-password", + json={ + "current_password": "Str0ng!Pass99", + "new_password": "An0ther!Pass88", + "new_email": "renamed@example.com", + }, + headers=_csrf_headers(client), + ) + assert change.status_code == 200 + assert change.json() == {"message": "Password changed successfully"} + + assert client.post("/api/v1/auth/logout").status_code == 200 + + old_login = _login(client) + assert old_login.status_code == 401 + assert old_login.json()["detail"]["code"] == "invalid_credentials" + + new_login = _login(client, email="renamed@example.com", password="An0ther!Pass88") + assert new_login.status_code == 200 + + me = client.get("/api/v1/auth/me") + assert me.status_code == 200 + assert me.json()["email"] == "renamed@example.com" + + +def test_oauth_endpoints_expose_current_placeholder_behavior(client): + unsupported = client.get("/api/v1/auth/oauth/not-a-provider") + assert unsupported.status_code == 400 + + github = client.get("/api/v1/auth/oauth/github") + assert github.status_code == 501 + + callback = client.get("/api/v1/auth/callback/github", params={"code": "abc", "state": "xyz"}) + assert callback.status_code == 501 diff --git a/backend/tests/test_checkpointer.py b/backend/tests/test_checkpointer.py deleted file mode 100644 index 58f57237e..000000000 --- a/backend/tests/test_checkpointer.py +++ /dev/null @@ -1,304 +0,0 @@ -"""Unit tests for checkpointer config and singleton factory.""" - -import sys -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -import deerflow.config.app_config as app_config_module -from deerflow.config.checkpointer_config import ( - CheckpointerConfig, - get_checkpointer_config, - load_checkpointer_config_from_dict, - set_checkpointer_config, -) -from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer - - -@pytest.fixture(autouse=True) -def reset_state(): - """Reset singleton state before each test.""" - app_config_module._app_config = None - set_checkpointer_config(None) - reset_checkpointer() - yield - app_config_module._app_config = None - set_checkpointer_config(None) - reset_checkpointer() - - -# --------------------------------------------------------------------------- -# Config tests -# --------------------------------------------------------------------------- - - -class TestCheckpointerConfig: - def test_load_memory_config(self): - load_checkpointer_config_from_dict({"type": "memory"}) - config = get_checkpointer_config() - assert config is not None - assert config.type == "memory" - assert config.connection_string is None - - def test_load_sqlite_config(self): - load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"}) - config = get_checkpointer_config() - assert config is not None - assert config.type == "sqlite" - assert config.connection_string == "/tmp/test.db" - - def test_load_postgres_config(self): - load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"}) - config = get_checkpointer_config() - assert config is not None - assert config.type == "postgres" - assert config.connection_string == "postgresql://localhost/db" - - def test_default_connection_string_is_none(self): - config = CheckpointerConfig(type="memory") - assert config.connection_string is None - - def test_set_config_to_none(self): - load_checkpointer_config_from_dict({"type": "memory"}) - set_checkpointer_config(None) - assert get_checkpointer_config() is None - - def test_invalid_type_raises(self): - with pytest.raises(Exception): - load_checkpointer_config_from_dict({"type": "unknown"}) - - -# --------------------------------------------------------------------------- -# Factory tests -# --------------------------------------------------------------------------- - - -class TestGetCheckpointer: - def test_returns_in_memory_saver_when_not_configured(self): - """get_checkpointer should return InMemorySaver when not configured.""" - from langgraph.checkpoint.memory import InMemorySaver - - with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError): - cp = get_checkpointer() - assert cp is not None - assert isinstance(cp, InMemorySaver) - - def test_memory_returns_in_memory_saver(self): - load_checkpointer_config_from_dict({"type": "memory"}) - from langgraph.checkpoint.memory import InMemorySaver - - cp = get_checkpointer() - assert isinstance(cp, InMemorySaver) - - def test_memory_singleton(self): - load_checkpointer_config_from_dict({"type": "memory"}) - cp1 = get_checkpointer() - cp2 = get_checkpointer() - assert cp1 is cp2 - - def test_reset_clears_singleton(self): - load_checkpointer_config_from_dict({"type": "memory"}) - cp1 = get_checkpointer() - reset_checkpointer() - cp2 = get_checkpointer() - assert cp1 is not cp2 - - def test_sqlite_raises_when_package_missing(self): - load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"}) - with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}): - reset_checkpointer() - with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"): - get_checkpointer() - - def test_postgres_raises_when_package_missing(self): - load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"}) - with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}): - reset_checkpointer() - with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"): - get_checkpointer() - - def test_postgres_raises_when_connection_string_missing(self): - load_checkpointer_config_from_dict({"type": "postgres"}) - mock_saver = MagicMock() - mock_module = MagicMock() - mock_module.PostgresSaver = mock_saver - with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}): - reset_checkpointer() - with pytest.raises(ValueError, match="connection_string is required"): - get_checkpointer() - - def test_sqlite_creates_saver(self): - """SQLite checkpointer is created when package is available.""" - load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"}) - - mock_saver_instance = MagicMock() - mock_cm = MagicMock() - mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance) - mock_cm.__exit__ = MagicMock(return_value=False) - - mock_saver_cls = MagicMock() - mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm) - - mock_module = MagicMock() - mock_module.SqliteSaver = mock_saver_cls - - with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}): - reset_checkpointer() - cp = get_checkpointer() - - assert cp is mock_saver_instance - mock_saver_cls.from_conn_string.assert_called_once() - mock_saver_instance.setup.assert_called_once() - - def test_postgres_creates_saver(self): - """Postgres checkpointer is created when packages are available.""" - load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"}) - - mock_saver_instance = MagicMock() - mock_cm = MagicMock() - mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance) - mock_cm.__exit__ = MagicMock(return_value=False) - - mock_saver_cls = MagicMock() - mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm) - - mock_pg_module = MagicMock() - mock_pg_module.PostgresSaver = mock_saver_cls - - with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}): - reset_checkpointer() - cp = get_checkpointer() - - assert cp is mock_saver_instance - mock_saver_cls.from_conn_string.assert_called_once_with("postgresql://localhost/db") - mock_saver_instance.setup.assert_called_once() - - -class TestAsyncCheckpointer: - @pytest.mark.anyio - async def test_sqlite_creates_parent_dir_via_to_thread(self): - """Async SQLite setup should move mkdir off the event loop.""" - from deerflow.runtime.checkpointer.async_provider import make_checkpointer - - mock_config = MagicMock() - mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db") - - mock_saver = AsyncMock() - mock_cm = AsyncMock() - mock_cm.__aenter__.return_value = mock_saver - mock_cm.__aexit__.return_value = False - - mock_saver_cls = MagicMock() - mock_saver_cls.from_conn_string.return_value = mock_cm - - mock_module = MagicMock() - mock_module.AsyncSqliteSaver = mock_saver_cls - - with ( - patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config), - patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}), - patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread, - patch( - "deerflow.runtime.checkpointer.async_provider.resolve_sqlite_conn_str", - return_value="/tmp/resolved/test.db", - ), - ): - async with make_checkpointer() as saver: - assert saver is mock_saver - - mock_to_thread.assert_awaited_once() - called_fn, called_path = mock_to_thread.await_args.args - assert called_fn.__name__ == "ensure_sqlite_parent_dir" - assert called_path == "/tmp/resolved/test.db" - mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db") - mock_saver.setup.assert_awaited_once() - - -# --------------------------------------------------------------------------- -# app_config.py integration -# --------------------------------------------------------------------------- - - -class TestAppConfigLoadsCheckpointer: - def test_load_checkpointer_section(self): - """load_checkpointer_config_from_dict populates the global config.""" - set_checkpointer_config(None) - load_checkpointer_config_from_dict({"type": "memory"}) - cfg = get_checkpointer_config() - assert cfg is not None - assert cfg.type == "memory" - - -# --------------------------------------------------------------------------- -# DeerFlowClient falls back to config checkpointer -# --------------------------------------------------------------------------- - - -class TestClientCheckpointerFallback: - def test_client_uses_config_checkpointer_when_none_provided(self): - """DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None.""" - from langgraph.checkpoint.memory import InMemorySaver - - from deerflow.client import DeerFlowClient - - load_checkpointer_config_from_dict({"type": "memory"}) - - captured_kwargs = {} - - def fake_create_agent(**kwargs): - captured_kwargs.update(kwargs) - return MagicMock() - - model_mock = MagicMock() - config_mock = MagicMock() - config_mock.models = [model_mock] - config_mock.get_model_config.return_value = MagicMock(supports_vision=False) - config_mock.checkpointer = None - - with ( - patch("deerflow.client.get_app_config", return_value=config_mock), - patch("deerflow.client.create_agent", side_effect=fake_create_agent), - patch("deerflow.client.create_chat_model", return_value=MagicMock()), - patch("deerflow.client._build_middlewares", return_value=[]), - patch("deerflow.client.apply_prompt_template", return_value=""), - patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]), - ): - client = DeerFlowClient(checkpointer=None) - config = client._get_runnable_config("test-thread") - client._ensure_agent(config) - - assert "checkpointer" in captured_kwargs - assert isinstance(captured_kwargs["checkpointer"], InMemorySaver) - - def test_client_explicit_checkpointer_takes_precedence(self): - """An explicitly provided checkpointer is used even when config checkpointer is set.""" - from deerflow.client import DeerFlowClient - - load_checkpointer_config_from_dict({"type": "memory"}) - - explicit_cp = MagicMock() - captured_kwargs = {} - - def fake_create_agent(**kwargs): - captured_kwargs.update(kwargs) - return MagicMock() - - model_mock = MagicMock() - config_mock = MagicMock() - config_mock.models = [model_mock] - config_mock.get_model_config.return_value = MagicMock(supports_vision=False) - config_mock.checkpointer = None - - with ( - patch("deerflow.client.get_app_config", return_value=config_mock), - patch("deerflow.client.create_agent", side_effect=fake_create_agent), - patch("deerflow.client.create_chat_model", return_value=MagicMock()), - patch("deerflow.client._build_middlewares", return_value=[]), - patch("deerflow.client.apply_prompt_template", return_value=""), - patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]), - ): - client = DeerFlowClient(checkpointer=explicit_cp) - config = client._get_runnable_config("test-thread") - client._ensure_agent(config) - - assert captured_kwargs["checkpointer"] is explicit_cp diff --git a/backend/tests/test_checkpointer_none_fix.py b/backend/tests/test_checkpointer_none_fix.py deleted file mode 100644 index 3c7a25fa1..000000000 --- a/backend/tests/test_checkpointer_none_fix.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Test for issue #1016: checkpointer should not return None.""" - -from unittest.mock import MagicMock, patch - -import pytest -from langgraph.checkpoint.memory import InMemorySaver - - -class TestCheckpointerNoneFix: - """Tests that checkpointer context managers return InMemorySaver instead of None.""" - - @pytest.mark.anyio - async def test_async_make_checkpointer_returns_in_memory_saver_when_not_configured(self): - """make_checkpointer should return InMemorySaver when config.checkpointer is None.""" - from deerflow.runtime.checkpointer.async_provider import make_checkpointer - - # Mock get_app_config to return a config with checkpointer=None and database=None - mock_config = MagicMock() - mock_config.checkpointer = None - mock_config.database = None - - with patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config): - async with make_checkpointer() as checkpointer: - # Should return InMemorySaver, not None - assert checkpointer is not None - assert isinstance(checkpointer, InMemorySaver) - - # Should be able to call alist() without AttributeError - # This is what LangGraph does and what was failing in issue #1016 - result = [] - async for item in checkpointer.alist(config={"configurable": {"thread_id": "test"}}): - result.append(item) - - # Empty list is expected for a fresh checkpointer - assert result == [] - - def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self): - """checkpointer_context should return InMemorySaver when config.checkpointer is None.""" - from deerflow.runtime.checkpointer.provider import checkpointer_context - - # Mock get_app_config to return a config with checkpointer=None - mock_config = MagicMock() - mock_config.checkpointer = None - - with patch("deerflow.runtime.checkpointer.provider.get_app_config", return_value=mock_config): - with checkpointer_context() as checkpointer: - # Should return InMemorySaver, not None - assert checkpointer is not None - assert isinstance(checkpointer, InMemorySaver) - - # Should be able to call list() without AttributeError - result = list(checkpointer.list(config={"configurable": {"thread_id": "test"}})) - - # Empty list is expected for a fresh checkpointer - assert result == [] diff --git a/backend/tests/test_ensure_admin.py b/backend/tests/test_ensure_admin.py deleted file mode 100644 index 9930b047f..000000000 --- a/backend/tests/test_ensure_admin.py +++ /dev/null @@ -1,296 +0,0 @@ -"""Tests for _ensure_admin_user() in app.py. - -Covers: first-boot no-op (admin creation removed), orphan migration -when admin exists, no-op on no admin found, and edge cases. -""" - -import asyncio -import os -from types import SimpleNamespace -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32") - -from app.gateway.auth.config import AuthConfig, set_auth_config - -_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32" - - -@pytest.fixture(autouse=True) -def _setup_auth_config(): - set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET)) - yield - set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET)) - - -def _make_app_stub(store=None): - """Minimal app-like object with state.store.""" - app = SimpleNamespace() - app.state = SimpleNamespace() - app.state.store = store - return app - - -def _make_provider(admin_count=0): - p = AsyncMock() - p.count_users = AsyncMock(return_value=admin_count) - p.count_admin_users = AsyncMock(return_value=admin_count) - p.create_user = AsyncMock() - p.update_user = AsyncMock(side_effect=lambda u: u) - return p - - -def _make_session_factory(admin_row=None): - """Build a mock async session factory that returns a row from execute().""" - row_result = MagicMock() - row_result.scalar_one_or_none.return_value = admin_row - - execute_result = MagicMock() - execute_result.scalar_one_or_none.return_value = admin_row - - session = AsyncMock() - session.execute = AsyncMock(return_value=execute_result) - - # Async context manager - session_cm = AsyncMock() - session_cm.__aenter__ = AsyncMock(return_value=session) - session_cm.__aexit__ = AsyncMock(return_value=False) - - sf = MagicMock() - sf.return_value = session_cm - return sf - - -# ── First boot: no admin → return early ────────────────────────────────── - - -def test_first_boot_does_not_create_admin(): - """admin_count==0 → do NOT create admin automatically.""" - provider = _make_provider(admin_count=0) - app = _make_app_stub() - - with patch("app.gateway.deps.get_local_provider", return_value=provider): - from app.gateway.app import _ensure_admin_user - - asyncio.run(_ensure_admin_user(app)) - - provider.create_user.assert_not_called() - - -def test_first_boot_skips_migration(): - """No admin → return early before any migration attempt.""" - provider = _make_provider(admin_count=0) - store = AsyncMock() - store.asearch = AsyncMock(return_value=[]) - app = _make_app_stub(store=store) - - with patch("app.gateway.deps.get_local_provider", return_value=provider): - from app.gateway.app import _ensure_admin_user - - asyncio.run(_ensure_admin_user(app)) - - store.asearch.assert_not_called() - - -# ── Admin exists: migration runs when admin row found ──────────────────── - - -def test_admin_exists_triggers_migration(): - """Admin exists and admin row found → _migrate_orphaned_threads called.""" - from uuid import uuid4 - - admin_row = MagicMock() - admin_row.id = uuid4() - - provider = _make_provider(admin_count=1) - sf = _make_session_factory(admin_row=admin_row) - store = AsyncMock() - store.asearch = AsyncMock(return_value=[]) - app = _make_app_stub(store=store) - - with patch("app.gateway.deps.get_local_provider", return_value=provider): - with patch("deerflow.persistence.engine.get_session_factory", return_value=sf): - from app.gateway.app import _ensure_admin_user - - asyncio.run(_ensure_admin_user(app)) - - store.asearch.assert_called_once() - - -def test_admin_exists_no_admin_row_skips_migration(): - """Admin count > 0 but DB row missing (edge case) → skip migration gracefully.""" - provider = _make_provider(admin_count=2) - sf = _make_session_factory(admin_row=None) - store = AsyncMock() - app = _make_app_stub(store=store) - - with patch("app.gateway.deps.get_local_provider", return_value=provider): - with patch("deerflow.persistence.engine.get_session_factory", return_value=sf): - from app.gateway.app import _ensure_admin_user - - asyncio.run(_ensure_admin_user(app)) - - store.asearch.assert_not_called() - - -def test_admin_exists_no_store_skips_migration(): - """Admin exists, row found, but no store → no crash, no migration.""" - from uuid import uuid4 - - admin_row = MagicMock() - admin_row.id = uuid4() - - provider = _make_provider(admin_count=1) - sf = _make_session_factory(admin_row=admin_row) - app = _make_app_stub(store=None) - - with patch("app.gateway.deps.get_local_provider", return_value=provider): - with patch("deerflow.persistence.engine.get_session_factory", return_value=sf): - from app.gateway.app import _ensure_admin_user - - asyncio.run(_ensure_admin_user(app)) - - # No assertion needed — just verify no crash - - -def test_admin_exists_session_factory_none_skips_migration(): - """get_session_factory() returns None → return early, no crash.""" - provider = _make_provider(admin_count=1) - store = AsyncMock() - app = _make_app_stub(store=store) - - with patch("app.gateway.deps.get_local_provider", return_value=provider): - with patch("deerflow.persistence.engine.get_session_factory", return_value=None): - from app.gateway.app import _ensure_admin_user - - asyncio.run(_ensure_admin_user(app)) - - store.asearch.assert_not_called() - - -def test_migration_failure_is_non_fatal(): - """_migrate_orphaned_threads exception is caught and logged.""" - from uuid import uuid4 - - admin_row = MagicMock() - admin_row.id = uuid4() - - provider = _make_provider(admin_count=1) - sf = _make_session_factory(admin_row=admin_row) - store = AsyncMock() - store.asearch = AsyncMock(side_effect=RuntimeError("store crashed")) - app = _make_app_stub(store=store) - - with patch("app.gateway.deps.get_local_provider", return_value=provider): - with patch("deerflow.persistence.engine.get_session_factory", return_value=sf): - from app.gateway.app import _ensure_admin_user - - # Should not raise - asyncio.run(_ensure_admin_user(app)) - - -# ── Section 5.1-5.6 upgrade path: orphan thread migration ──────────────── - - -def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows(): - """First boot finds Store-only legacy threads → stamps admin's id. - - Validates the **TC-UPG-02 upgrade story**: an operator running main - (no auth) accumulates threads in the LangGraph Store namespace - ``("threads",)`` with no ``metadata.user_id``. After upgrading to - feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should - rewrite each unowned item with the freshly created admin's id. - """ - from app.gateway.app import _migrate_orphaned_threads - - # Three orphan items + one already-owned item that should be left alone. - items = [ - SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}), - SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}), - SimpleNamespace(key="t3", value={"metadata": {}}), - SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}}), - ] - store = AsyncMock() - # asearch returns the entire batch on first call, then an empty page - # to terminate _iter_store_items. - store.asearch = AsyncMock(side_effect=[items, []]) - aput_calls: list[tuple[tuple, str, dict]] = [] - - async def _record_aput(namespace, key, value): - aput_calls.append((namespace, key, value)) - - store.aput = AsyncMock(side_effect=_record_aput) - - migrated = asyncio.run(_migrate_orphaned_threads(store, "admin-id-42")) - - # Three orphan rows migrated, one preserved. - assert migrated == 3 - assert len(aput_calls) == 3 - rewritten_keys = {call[1] for call in aput_calls} - assert rewritten_keys == {"t1", "t2", "t3"} - # Each rewrite carries the new user_id; titles preserved where present. - by_key = {call[1]: call[2] for call in aput_calls} - assert by_key["t1"]["metadata"]["user_id"] == "admin-id-42" - assert by_key["t1"]["metadata"]["title"] == "old-thread-1" - assert by_key["t3"]["metadata"]["user_id"] == "admin-id-42" - # The pre-owned item must NOT have been rewritten. - assert "t4" not in rewritten_keys - - -def test_migrate_orphaned_threads_empty_store_is_noop(): - """A store with no threads → migrated == 0, no aput calls.""" - from app.gateway.app import _migrate_orphaned_threads - - store = AsyncMock() - store.asearch = AsyncMock(return_value=[]) - store.aput = AsyncMock() - - migrated = asyncio.run(_migrate_orphaned_threads(store, "admin-id-42")) - - assert migrated == 0 - store.aput.assert_not_called() - - -def test_iter_store_items_walks_multiple_pages(): - """Cursor-style iterator pulls every page until a short page terminates. - - Closes the regression where the old hardcoded ``limit=1000`` could - silently drop orphans on a large pre-upgrade dataset. The migration - code path uses the default ``page_size=500``; this test pins the - iterator with ``page_size=2`` so it stays fast. - """ - from app.gateway.app import _iter_store_items - - page_a = [SimpleNamespace(key=f"t{i}", value={"metadata": {}}) for i in range(2)] - page_b = [SimpleNamespace(key=f"t{i + 2}", value={"metadata": {}}) for i in range(2)] - page_c: list = [] # short page → loop terminates - - store = AsyncMock() - store.asearch = AsyncMock(side_effect=[page_a, page_b, page_c]) - - async def _collect(): - return [item.key async for item in _iter_store_items(store, ("threads",), page_size=2)] - - keys = asyncio.run(_collect()) - assert keys == ["t0", "t1", "t2", "t3"] - # Three asearch calls: full batch, full batch, empty terminator - assert store.asearch.await_count == 3 - - -def test_iter_store_items_terminates_on_short_page(): - """A short page (len < page_size) ends the loop without an extra call.""" - from app.gateway.app import _iter_store_items - - page = [SimpleNamespace(key=f"t{i}", value={}) for i in range(3)] - store = AsyncMock() - store.asearch = AsyncMock(return_value=page) - - async def _collect(): - return [item.key async for item in _iter_store_items(store, ("threads",), page_size=10)] - - keys = asyncio.run(_collect()) - assert keys == ["t0", "t1", "t2"] - # Only one call — no terminator probe needed because len(batch) < page_size - assert store.asearch.await_count == 1 diff --git a/backend/tests/test_feedback.py b/backend/tests/test_feedback.py deleted file mode 100644 index a592bdd22..000000000 --- a/backend/tests/test_feedback.py +++ /dev/null @@ -1,289 +0,0 @@ -"""Tests for FeedbackRepository and follow-up association. - -Uses temp SQLite DB for ORM tests. -""" - -import pytest - -from deerflow.persistence.feedback import FeedbackRepository - - -async def _make_feedback_repo(tmp_path): - from deerflow.persistence.engine import get_session_factory, init_engine - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - return FeedbackRepository(get_session_factory()) - - -async def _cleanup(): - from deerflow.persistence.engine import close_engine - - await close_engine() - - -# -- FeedbackRepository -- - - -class TestFeedbackRepository: - @pytest.mark.anyio - async def test_create_positive(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - record = await repo.create(run_id="r1", thread_id="t1", rating=1) - assert record["feedback_id"] - assert record["rating"] == 1 - assert record["run_id"] == "r1" - assert record["thread_id"] == "t1" - assert "created_at" in record - await _cleanup() - - @pytest.mark.anyio - async def test_create_negative_with_comment(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - record = await repo.create( - run_id="r1", - thread_id="t1", - rating=-1, - comment="Response was inaccurate", - ) - assert record["rating"] == -1 - assert record["comment"] == "Response was inaccurate" - await _cleanup() - - @pytest.mark.anyio - async def test_create_with_message_id(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - record = await repo.create(run_id="r1", thread_id="t1", rating=1, message_id="msg-42") - assert record["message_id"] == "msg-42" - await _cleanup() - - @pytest.mark.anyio - async def test_create_with_owner(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1") - assert record["user_id"] == "user-1" - await _cleanup() - - @pytest.mark.anyio - async def test_create_invalid_rating_zero(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - with pytest.raises(ValueError): - await repo.create(run_id="r1", thread_id="t1", rating=0) - await _cleanup() - - @pytest.mark.anyio - async def test_create_invalid_rating_five(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - with pytest.raises(ValueError): - await repo.create(run_id="r1", thread_id="t1", rating=5) - await _cleanup() - - @pytest.mark.anyio - async def test_get(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - created = await repo.create(run_id="r1", thread_id="t1", rating=1) - fetched = await repo.get(created["feedback_id"]) - assert fetched is not None - assert fetched["feedback_id"] == created["feedback_id"] - assert fetched["rating"] == 1 - await _cleanup() - - @pytest.mark.anyio - async def test_get_nonexistent(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - assert await repo.get("nonexistent") is None - await _cleanup() - - @pytest.mark.anyio - async def test_list_by_run(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1") - await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2") - await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1") - results = await repo.list_by_run("t1", "r1", user_id=None) - assert len(results) == 2 - assert all(r["run_id"] == "r1" for r in results) - await _cleanup() - - @pytest.mark.anyio - async def test_list_by_thread(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - await repo.create(run_id="r1", thread_id="t1", rating=1) - await repo.create(run_id="r2", thread_id="t1", rating=-1) - await repo.create(run_id="r3", thread_id="t2", rating=1) - results = await repo.list_by_thread("t1") - assert len(results) == 2 - assert all(r["thread_id"] == "t1" for r in results) - await _cleanup() - - @pytest.mark.anyio - async def test_delete(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - created = await repo.create(run_id="r1", thread_id="t1", rating=1) - deleted = await repo.delete(created["feedback_id"]) - assert deleted is True - assert await repo.get(created["feedback_id"]) is None - await _cleanup() - - @pytest.mark.anyio - async def test_delete_nonexistent(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - deleted = await repo.delete("nonexistent") - assert deleted is False - await _cleanup() - - @pytest.mark.anyio - async def test_aggregate_by_run(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1") - await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2") - await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3") - stats = await repo.aggregate_by_run("t1", "r1") - assert stats["total"] == 3 - assert stats["positive"] == 2 - assert stats["negative"] == 1 - assert stats["run_id"] == "r1" - await _cleanup() - - @pytest.mark.anyio - async def test_aggregate_empty(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - stats = await repo.aggregate_by_run("t1", "r1") - assert stats["total"] == 0 - assert stats["positive"] == 0 - assert stats["negative"] == 0 - await _cleanup() - - @pytest.mark.anyio - async def test_upsert_creates_new(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") - assert record["rating"] == 1 - assert record["feedback_id"] - assert record["user_id"] == "u1" - await _cleanup() - - @pytest.mark.anyio - async def test_upsert_updates_existing(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") - second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind") - assert second["feedback_id"] == first["feedback_id"] - assert second["rating"] == -1 - assert second["comment"] == "changed my mind" - await _cleanup() - - @pytest.mark.anyio - async def test_upsert_different_users_separate(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") - r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2") - assert r1["feedback_id"] != r2["feedback_id"] - assert r1["rating"] == 1 - assert r2["rating"] == -1 - await _cleanup() - - @pytest.mark.anyio - async def test_upsert_invalid_rating(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - with pytest.raises(ValueError): - await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1") - await _cleanup() - - @pytest.mark.anyio - async def test_delete_by_run(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") - deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1") - assert deleted is True - results = await repo.list_by_run("t1", "r1", user_id="u1") - assert len(results) == 0 - await _cleanup() - - @pytest.mark.anyio - async def test_delete_by_run_nonexistent(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1") - assert deleted is False - await _cleanup() - - @pytest.mark.anyio - async def test_list_by_thread_grouped(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") - await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1") - await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1") - grouped = await repo.list_by_thread_grouped("t1", user_id="u1") - assert "r1" in grouped - assert "r2" in grouped - assert "r3" not in grouped - assert grouped["r1"]["rating"] == 1 - assert grouped["r2"]["rating"] == -1 - await _cleanup() - - @pytest.mark.anyio - async def test_list_by_thread_grouped_empty(self, tmp_path): - repo = await _make_feedback_repo(tmp_path) - grouped = await repo.list_by_thread_grouped("t1", user_id="u1") - assert grouped == {} - await _cleanup() - - -# -- Follow-up association -- - - -class TestFollowUpAssociation: - @pytest.mark.anyio - async def test_run_records_follow_up_via_memory_store(self): - """MemoryRunStore stores follow_up_to_run_id in kwargs.""" - from deerflow.runtime.runs.store.memory import MemoryRunStore - - store = MemoryRunStore() - await store.put("r1", thread_id="t1", status="success") - # MemoryRunStore doesn't have follow_up_to_run_id as a top-level param, - # but it can be passed via metadata - await store.put("r2", thread_id="t1", metadata={"follow_up_to_run_id": "r1"}) - run = await store.get("r2") - assert run["metadata"]["follow_up_to_run_id"] == "r1" - - @pytest.mark.anyio - async def test_human_message_has_follow_up_metadata(self): - """human_message event metadata includes follow_up_to_run_id.""" - from deerflow.runtime.events.store.memory import MemoryRunEventStore - - event_store = MemoryRunEventStore() - await event_store.put( - thread_id="t1", - run_id="r2", - event_type="human_message", - category="message", - content="Tell me more about that", - metadata={"follow_up_to_run_id": "r1"}, - ) - messages = await event_store.list_messages("t1") - assert messages[0]["metadata"]["follow_up_to_run_id"] == "r1" - - @pytest.mark.anyio - async def test_follow_up_auto_detection_logic(self): - """Simulate the auto-detection: latest successful run becomes follow_up_to.""" - from deerflow.runtime.runs.store.memory import MemoryRunStore - - store = MemoryRunStore() - await store.put("r1", thread_id="t1", status="success") - await store.put("r2", thread_id="t1", status="error") - - # Auto-detect: list_by_thread returns newest first - recent = await store.list_by_thread("t1", limit=1) - follow_up = None - if recent and recent[0].get("status") == "success": - follow_up = recent[0]["run_id"] - # r2 (error) is newest, so no follow_up detected - assert follow_up is None - - # Now add a successful run - await store.put("r3", thread_id="t1", status="success") - recent = await store.list_by_thread("t1", limit=1) - follow_up = None - if recent and recent[0].get("status") == "success": - follow_up = recent[0]["run_id"] - assert follow_up == "r3" diff --git a/backend/tests/test_gateway_services.py b/backend/tests/test_gateway_services.py deleted file mode 100644 index 782306e38..000000000 --- a/backend/tests/test_gateway_services.py +++ /dev/null @@ -1,342 +0,0 @@ -"""Tests for app.gateway.services — run lifecycle service layer.""" - -from __future__ import annotations - -import json - - -def test_format_sse_basic(): - from app.gateway.services import format_sse - - frame = format_sse("metadata", {"run_id": "abc"}) - assert frame.startswith("event: metadata\n") - assert "data: " in frame - parsed = json.loads(frame.split("data: ")[1].split("\n")[0]) - assert parsed["run_id"] == "abc" - - -def test_format_sse_with_event_id(): - from app.gateway.services import format_sse - - frame = format_sse("metadata", {"run_id": "abc"}, event_id="123-0") - assert "id: 123-0" in frame - - -def test_format_sse_end_event_null(): - from app.gateway.services import format_sse - - frame = format_sse("end", None) - assert "data: null" in frame - - -def test_format_sse_no_event_id(): - from app.gateway.services import format_sse - - frame = format_sse("values", {"x": 1}) - assert "id:" not in frame - - -def test_normalize_stream_modes_none(): - from app.gateway.services import normalize_stream_modes - - assert normalize_stream_modes(None) == ["values"] - - -def test_normalize_stream_modes_string(): - from app.gateway.services import normalize_stream_modes - - assert normalize_stream_modes("messages-tuple") == ["messages-tuple"] - - -def test_normalize_stream_modes_list(): - from app.gateway.services import normalize_stream_modes - - assert normalize_stream_modes(["values", "messages-tuple"]) == ["values", "messages-tuple"] - - -def test_normalize_stream_modes_empty_list(): - from app.gateway.services import normalize_stream_modes - - assert normalize_stream_modes([]) == ["values"] - - -def test_normalize_input_none(): - from app.gateway.services import normalize_input - - assert normalize_input(None) == {} - - -def test_normalize_input_with_messages(): - from app.gateway.services import normalize_input - - result = normalize_input({"messages": [{"role": "user", "content": "hi"}]}) - assert len(result["messages"]) == 1 - assert result["messages"][0].content == "hi" - - -def test_normalize_input_passthrough(): - from app.gateway.services import normalize_input - - result = normalize_input({"custom_key": "value"}) - assert result == {"custom_key": "value"} - - -def test_build_run_config_basic(): - from app.gateway.services import build_run_config - - config = build_run_config("thread-1", None, None) - assert config["configurable"]["thread_id"] == "thread-1" - assert config["recursion_limit"] == 100 - - -def test_build_run_config_with_overrides(): - from app.gateway.services import build_run_config - - config = build_run_config( - "thread-1", - {"configurable": {"model_name": "gpt-4"}, "tags": ["test"]}, - {"user": "alice"}, - ) - assert config["configurable"]["model_name"] == "gpt-4" - assert config["tags"] == ["test"] - assert config["metadata"]["user"] == "alice" - - -# --------------------------------------------------------------------------- -# Regression tests for issue #1644: -# assistant_id not mapped to agent_name → custom agent SOUL.md never loaded -# --------------------------------------------------------------------------- - - -def test_build_run_config_custom_agent_injects_agent_name(): - """Custom assistant_id must be forwarded as configurable['agent_name'].""" - from app.gateway.services import build_run_config - - config = build_run_config("thread-1", None, None, assistant_id="finalis") - assert config["configurable"]["agent_name"] == "finalis" - - -def test_build_run_config_lead_agent_no_agent_name(): - """'lead_agent' assistant_id must NOT inject configurable['agent_name'].""" - from app.gateway.services import build_run_config - - config = build_run_config("thread-1", None, None, assistant_id="lead_agent") - assert "agent_name" not in config["configurable"] - - -def test_build_run_config_none_assistant_id_no_agent_name(): - """None assistant_id must NOT inject configurable['agent_name'].""" - from app.gateway.services import build_run_config - - config = build_run_config("thread-1", None, None, assistant_id=None) - assert "agent_name" not in config["configurable"] - - -def test_build_run_config_explicit_agent_name_not_overwritten(): - """An explicit configurable['agent_name'] in the request must take precedence.""" - from app.gateway.services import build_run_config - - config = build_run_config( - "thread-1", - {"configurable": {"agent_name": "explicit-agent"}}, - None, - assistant_id="other-agent", - ) - assert config["configurable"]["agent_name"] == "explicit-agent" - - -def test_resolve_agent_factory_returns_make_lead_agent(): - """resolve_agent_factory always returns make_lead_agent regardless of assistant_id.""" - from app.gateway.services import resolve_agent_factory - from deerflow.agents.lead_agent.agent import make_lead_agent - - assert resolve_agent_factory(None) is make_lead_agent - assert resolve_agent_factory("lead_agent") is make_lead_agent - assert resolve_agent_factory("finalis") is make_lead_agent - assert resolve_agent_factory("custom-agent-123") is make_lead_agent - - -# --------------------------------------------------------------------------- - -# --------------------------------------------------------------------------- -# Regression tests for issue #1699: -# context field in langgraph-compat requests not merged into configurable -# --------------------------------------------------------------------------- - - -def test_run_create_request_accepts_context(): - """RunCreateRequest must accept the ``context`` field without dropping it.""" - from app.gateway.routers.thread_runs import RunCreateRequest - - body = RunCreateRequest( - input={"messages": [{"role": "user", "content": "hi"}]}, - context={ - "model_name": "deepseek-v3", - "thinking_enabled": True, - "is_plan_mode": True, - "subagent_enabled": True, - "thread_id": "some-thread-id", - }, - ) - assert body.context is not None - assert body.context["model_name"] == "deepseek-v3" - assert body.context["is_plan_mode"] is True - assert body.context["subagent_enabled"] is True - - -def test_run_create_request_context_defaults_to_none(): - """RunCreateRequest without context should default to None (backward compat).""" - from app.gateway.routers.thread_runs import RunCreateRequest - - body = RunCreateRequest(input=None) - assert body.context is None - - -def test_context_merges_into_configurable(): - """Context values must be merged into config['configurable'] by start_run. - - Since start_run is async and requires many dependencies, we test the - merging logic directly by simulating what start_run does. - """ - from app.gateway.services import build_run_config - - # Simulate the context merging logic from start_run - config = build_run_config("thread-1", None, None) - - context = { - "model_name": "deepseek-v3", - "mode": "ultra", - "reasoning_effort": "high", - "thinking_enabled": True, - "is_plan_mode": True, - "subagent_enabled": True, - "max_concurrent_subagents": 5, - "thread_id": "should-be-ignored", - } - - _CONTEXT_CONFIGURABLE_KEYS = { - "model_name", - "mode", - "thinking_enabled", - "reasoning_effort", - "is_plan_mode", - "subagent_enabled", - "max_concurrent_subagents", - } - configurable = config.setdefault("configurable", {}) - for key in _CONTEXT_CONFIGURABLE_KEYS: - if key in context: - configurable.setdefault(key, context[key]) - - assert config["configurable"]["model_name"] == "deepseek-v3" - assert config["configurable"]["thinking_enabled"] is True - assert config["configurable"]["is_plan_mode"] is True - assert config["configurable"]["subagent_enabled"] is True - assert config["configurable"]["max_concurrent_subagents"] == 5 - assert config["configurable"]["reasoning_effort"] == "high" - assert config["configurable"]["mode"] == "ultra" - # thread_id from context should NOT override the one from build_run_config - assert config["configurable"]["thread_id"] == "thread-1" - # Non-allowlisted keys should not appear - assert "thread_id" not in {k for k in context if k in _CONTEXT_CONFIGURABLE_KEYS} - - -def test_context_does_not_override_existing_configurable(): - """Values already in config.configurable must NOT be overridden by context.""" - from app.gateway.services import build_run_config - - config = build_run_config( - "thread-1", - {"configurable": {"model_name": "gpt-4", "is_plan_mode": False}}, - None, - ) - - context = { - "model_name": "deepseek-v3", - "is_plan_mode": True, - "subagent_enabled": True, - } - - _CONTEXT_CONFIGURABLE_KEYS = { - "model_name", - "mode", - "thinking_enabled", - "reasoning_effort", - "is_plan_mode", - "subagent_enabled", - "max_concurrent_subagents", - } - configurable = config.setdefault("configurable", {}) - for key in _CONTEXT_CONFIGURABLE_KEYS: - if key in context: - configurable.setdefault(key, context[key]) - - # Existing values must NOT be overridden - assert config["configurable"]["model_name"] == "gpt-4" - assert config["configurable"]["is_plan_mode"] is False - # New values should be added - assert config["configurable"]["subagent_enabled"] is True - - -# --------------------------------------------------------------------------- -# build_run_config — context / configurable precedence (LangGraph >= 0.6.0) -# --------------------------------------------------------------------------- - - -def test_build_run_config_with_context(): - """When caller sends 'context', prefer it over 'configurable'.""" - from app.gateway.services import build_run_config - - config = build_run_config( - "thread-1", - {"context": {"user_id": "u-42", "thread_id": "thread-1"}}, - None, - ) - assert "context" in config - assert config["context"]["user_id"] == "u-42" - assert "configurable" not in config - assert config["recursion_limit"] == 100 - - -def test_build_run_config_context_plus_configurable_warns(caplog): - """When caller sends both 'context' and 'configurable', prefer 'context' and log a warning.""" - import logging - - from app.gateway.services import build_run_config - - with caplog.at_level(logging.WARNING, logger="app.gateway.services"): - config = build_run_config( - "thread-1", - { - "context": {"user_id": "u-42"}, - "configurable": {"model_name": "gpt-4"}, - }, - None, - ) - assert "context" in config - assert config["context"]["user_id"] == "u-42" - assert "configurable" not in config - assert any("both 'context' and 'configurable'" in r.message for r in caplog.records) - - -def test_build_run_config_context_passthrough_other_keys(): - """Non-conflicting keys from request_config are still passed through when context is used.""" - from app.gateway.services import build_run_config - - config = build_run_config( - "thread-1", - {"context": {"thread_id": "thread-1"}, "tags": ["prod"]}, - None, - ) - assert config["context"]["thread_id"] == "thread-1" - assert "configurable" not in config - assert config["tags"] == ["prod"] - - -def test_build_run_config_no_request_config(): - """When request_config is None, fall back to basic configurable with thread_id.""" - from app.gateway.services import build_run_config - - config = build_run_config("thread-abc", None, None) - assert config["configurable"] == {"thread_id": "thread-abc"} - assert "context" not in config diff --git a/backend/tests/test_memory_thread_meta_isolation.py b/backend/tests/test_memory_thread_meta_isolation.py deleted file mode 100644 index 25c9298f0..000000000 --- a/backend/tests/test_memory_thread_meta_isolation.py +++ /dev/null @@ -1,156 +0,0 @@ -"""Owner isolation tests for MemoryThreadMetaStore. - -Mirrors the SQL-backed tests in test_owner_isolation.py but exercises -the in-memory LangGraph Store backend used when database.backend=memory. -""" - -from __future__ import annotations - -from types import SimpleNamespace - -import pytest -from langgraph.store.memory import InMemoryStore - -from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore -from deerflow.runtime.user_context import reset_current_user, set_current_user - -USER_A = SimpleNamespace(id="user-a", email="a@test.local") -USER_B = SimpleNamespace(id="user-b", email="b@test.local") - - -def _as_user(user): - class _Ctx: - def __enter__(self): - self._token = set_current_user(user) - return user - - def __exit__(self, *exc): - reset_current_user(self._token) - - return _Ctx() - - -@pytest.fixture -def store(): - return MemoryThreadMetaStore(InMemoryStore()) - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_search_isolation(store): - """search() returns only threads owned by the current user.""" - with _as_user(USER_A): - await store.create("t-alpha", display_name="A's thread") - with _as_user(USER_B): - await store.create("t-beta", display_name="B's thread") - - with _as_user(USER_A): - results = await store.search() - assert [r["thread_id"] for r in results] == ["t-alpha"] - - with _as_user(USER_B): - results = await store.search() - assert [r["thread_id"] for r in results] == ["t-beta"] - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_get_isolation(store): - """get() returns None for threads owned by another user.""" - with _as_user(USER_A): - await store.create("t-alpha", display_name="A's thread") - - with _as_user(USER_B): - assert await store.get("t-alpha") is None - - with _as_user(USER_A): - result = await store.get("t-alpha") - assert result is not None - assert result["display_name"] == "A's thread" - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_update_display_name_denied(store): - """User B cannot rename User A's thread.""" - with _as_user(USER_A): - await store.create("t-alpha", display_name="original") - - with _as_user(USER_B): - await store.update_display_name("t-alpha", "hacked") - - with _as_user(USER_A): - row = await store.get("t-alpha") - assert row is not None - assert row["display_name"] == "original" - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_update_status_denied(store): - """User B cannot change status of User A's thread.""" - with _as_user(USER_A): - await store.create("t-alpha") - - with _as_user(USER_B): - await store.update_status("t-alpha", "error") - - with _as_user(USER_A): - row = await store.get("t-alpha") - assert row is not None - assert row["status"] == "idle" - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_update_metadata_denied(store): - """User B cannot modify metadata of User A's thread.""" - with _as_user(USER_A): - await store.create("t-alpha", metadata={"key": "original"}) - - with _as_user(USER_B): - await store.update_metadata("t-alpha", {"key": "hacked"}) - - with _as_user(USER_A): - row = await store.get("t-alpha") - assert row is not None - assert row["metadata"]["key"] == "original" - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_delete_denied(store): - """User B cannot delete User A's thread.""" - with _as_user(USER_A): - await store.create("t-alpha") - - with _as_user(USER_B): - await store.delete("t-alpha") - - with _as_user(USER_A): - row = await store.get("t-alpha") - assert row is not None - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_no_context_raises(store): - """Calling methods without user context raises RuntimeError.""" - with pytest.raises(RuntimeError, match="no user context is set"): - await store.search() - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_explicit_none_bypasses_filter(store): - """user_id=None bypasses isolation (migration/CLI escape hatch).""" - with _as_user(USER_A): - await store.create("t-alpha") - with _as_user(USER_B): - await store.create("t-beta") - - all_rows = await store.search(user_id=None) - assert {r["thread_id"] for r in all_rows} == {"t-alpha", "t-beta"} - - row = await store.get("t-alpha", user_id=None) - assert row is not None diff --git a/backend/tests/test_owner_isolation.py b/backend/tests/test_owner_isolation.py deleted file mode 100644 index 33d21f3e3..000000000 --- a/backend/tests/test_owner_isolation.py +++ /dev/null @@ -1,465 +0,0 @@ -"""Cross-user isolation tests — non-negotiable safety gate. - -Mirrors TC-API-17..20 from backend/docs/AUTH_TEST_PLAN.md. A failure -here means users can see each other's data; PR must not merge. - -Architecture note ------------------ -These tests bypass the HTTP layer and exercise the storage-layer -owner filter directly by switching the ``user_context`` contextvar -between two users. The safety property under test is: - - After a repository write with user_id=A, a subsequent read with - user_id=B must not return the row, and vice versa. - -The HTTP layer is covered by test_auth_middleware.py, which proves -that a request cookie reaches the ``set_current_user`` call. Together -the two suites prove the full chain: - - cookie → middleware → contextvar → repository → isolation - -Every test in this file opts out of the autouse contextvar fixture -(``@pytest.mark.no_auto_user``) so it can set the contextvar to the -specific users it cares about. -""" - -from __future__ import annotations - -from types import SimpleNamespace - -import pytest - -from deerflow.runtime.user_context import ( - reset_current_user, - set_current_user, -) - -USER_A = SimpleNamespace(id="user-a", email="a@test.local") -USER_B = SimpleNamespace(id="user-b", email="b@test.local") - - -async def _make_engines(tmp_path): - """Initialize the shared engine against a per-test SQLite DB. - - Returns a cleanup coroutine the caller should await at the end. - """ - from deerflow.persistence.engine import close_engine, init_engine - - url = f"sqlite+aiosqlite:///{tmp_path / 'isolation.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - return close_engine - - -def _as_user(user): - """Context manager-like helper that set/reset the contextvar.""" - - class _Ctx: - def __enter__(self): - self._token = set_current_user(user) - return user - - def __exit__(self, *exc): - reset_current_user(self._token) - - return _Ctx() - - -# ── TC-API-17 — threads_meta isolation ──────────────────────────────────── - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_thread_meta_cross_user_isolation(tmp_path): - from deerflow.persistence.engine import get_session_factory - from deerflow.persistence.thread_meta import ThreadMetaRepository - - cleanup = await _make_engines(tmp_path) - try: - repo = ThreadMetaRepository(get_session_factory()) - - # User A creates a thread. - with _as_user(USER_A): - await repo.create("t-alpha", display_name="A's private thread") - - # User B creates a thread. - with _as_user(USER_B): - await repo.create("t-beta", display_name="B's private thread") - - # User A must see only A's thread. - with _as_user(USER_A): - a_view = await repo.get("t-alpha") - assert a_view is not None - assert a_view["display_name"] == "A's private thread" - - # CRITICAL: User A must NOT see B's thread. - leaked = await repo.get("t-beta") - assert leaked is None, f"User A leaked User B's thread: {leaked}" - - # Search should only return A's threads. - results = await repo.search() - assert [r["thread_id"] for r in results] == ["t-alpha"] - - # User B must see only B's thread. - with _as_user(USER_B): - b_view = await repo.get("t-beta") - assert b_view is not None - assert b_view["display_name"] == "B's private thread" - - leaked = await repo.get("t-alpha") - assert leaked is None, f"User B leaked User A's thread: {leaked}" - - results = await repo.search() - assert [r["thread_id"] for r in results] == ["t-beta"] - finally: - await cleanup() - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_thread_meta_cross_user_mutation_denied(tmp_path): - """User B cannot update or delete a thread owned by User A.""" - from deerflow.persistence.engine import get_session_factory - from deerflow.persistence.thread_meta import ThreadMetaRepository - - cleanup = await _make_engines(tmp_path) - try: - repo = ThreadMetaRepository(get_session_factory()) - - with _as_user(USER_A): - await repo.create("t-alpha", display_name="original") - - # User B tries to rename A's thread — must be a no-op. - with _as_user(USER_B): - await repo.update_display_name("t-alpha", "hacked") - - # Verify the row is unchanged from A's perspective. - with _as_user(USER_A): - row = await repo.get("t-alpha") - assert row is not None - assert row["display_name"] == "original" - - # User B tries to delete A's thread — must be a no-op. - with _as_user(USER_B): - await repo.delete("t-alpha") - - # A's thread still exists. - with _as_user(USER_A): - row = await repo.get("t-alpha") - assert row is not None - finally: - await cleanup() - - -# ── TC-API-18 — runs isolation ──────────────────────────────────────────── - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_runs_cross_user_isolation(tmp_path): - from deerflow.persistence.engine import get_session_factory - from deerflow.persistence.run import RunRepository - - cleanup = await _make_engines(tmp_path) - try: - repo = RunRepository(get_session_factory()) - - with _as_user(USER_A): - await repo.put("run-a1", thread_id="t-alpha") - await repo.put("run-a2", thread_id="t-alpha") - - with _as_user(USER_B): - await repo.put("run-b1", thread_id="t-beta") - - # User A must see only A's runs. - with _as_user(USER_A): - r = await repo.get("run-a1") - assert r is not None - assert r["run_id"] == "run-a1" - - leaked = await repo.get("run-b1") - assert leaked is None, "User A leaked User B's run" - - a_runs = await repo.list_by_thread("t-alpha") - assert {r["run_id"] for r in a_runs} == {"run-a1", "run-a2"} - - # Listing B's thread from A's perspective: empty - empty = await repo.list_by_thread("t-beta") - assert empty == [] - - # User B must see only B's runs. - with _as_user(USER_B): - leaked = await repo.get("run-a1") - assert leaked is None, "User B leaked User A's run" - - b_runs = await repo.list_by_thread("t-beta") - assert [r["run_id"] for r in b_runs] == ["run-b1"] - finally: - await cleanup() - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_runs_cross_user_delete_denied(tmp_path): - from deerflow.persistence.engine import get_session_factory - from deerflow.persistence.run import RunRepository - - cleanup = await _make_engines(tmp_path) - try: - repo = RunRepository(get_session_factory()) - - with _as_user(USER_A): - await repo.put("run-a1", thread_id="t-alpha") - - # User B tries to delete A's run — no-op. - with _as_user(USER_B): - await repo.delete("run-a1") - - # A's run still exists. - with _as_user(USER_A): - row = await repo.get("run-a1") - assert row is not None - finally: - await cleanup() - - -# ── TC-API-19 — run_events isolation (CRITICAL: content leak) ───────────── - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_run_events_cross_user_isolation(tmp_path): - """run_events holds raw conversation content — most sensitive leak vector.""" - from deerflow.persistence.engine import get_session_factory - from deerflow.runtime.events.store.db import DbRunEventStore - - cleanup = await _make_engines(tmp_path) - try: - store = DbRunEventStore(get_session_factory()) - - with _as_user(USER_A): - await store.put( - thread_id="t-alpha", - run_id="run-a1", - event_type="human_message", - category="message", - content="User A private question", - ) - await store.put( - thread_id="t-alpha", - run_id="run-a1", - event_type="ai_message", - category="message", - content="User A private answer", - ) - - with _as_user(USER_B): - await store.put( - thread_id="t-beta", - run_id="run-b1", - event_type="human_message", - category="message", - content="User B private question", - ) - - # User A must see only A's events — CRITICAL. - with _as_user(USER_A): - msgs = await store.list_messages("t-alpha") - contents = [m["content"] for m in msgs] - assert "User A private question" in contents - assert "User A private answer" in contents - # CRITICAL: User B's content must not appear. - assert "User B private question" not in contents - - # Attempt to read B's thread by guessing thread_id. - leaked = await store.list_messages("t-beta") - assert leaked == [], f"User A leaked User B's messages: {leaked}" - - leaked_events = await store.list_events("t-beta", "run-b1") - assert leaked_events == [], "User A leaked User B's events" - - # count_messages must also be zero for B's thread from A's view. - count = await store.count_messages("t-beta") - assert count == 0 - - # User B must see only B's events. - with _as_user(USER_B): - msgs = await store.list_messages("t-beta") - contents = [m["content"] for m in msgs] - assert "User B private question" in contents - assert "User A private question" not in contents - assert "User A private answer" not in contents - - count = await store.count_messages("t-alpha") - assert count == 0 - finally: - await cleanup() - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_run_events_cross_user_delete_denied(tmp_path): - """User B cannot delete User A's event stream.""" - from deerflow.persistence.engine import get_session_factory - from deerflow.runtime.events.store.db import DbRunEventStore - - cleanup = await _make_engines(tmp_path) - try: - store = DbRunEventStore(get_session_factory()) - - with _as_user(USER_A): - await store.put( - thread_id="t-alpha", - run_id="run-a1", - event_type="human_message", - category="message", - content="hello", - ) - - # User B tries to wipe A's thread events. - with _as_user(USER_B): - removed = await store.delete_by_thread("t-alpha") - assert removed == 0, f"User B deleted {removed} of User A's events" - - # A's events still exist. - with _as_user(USER_A): - count = await store.count_messages("t-alpha") - assert count == 1 - finally: - await cleanup() - - -# ── TC-API-20 — feedback isolation ──────────────────────────────────────── - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_feedback_cross_user_isolation(tmp_path): - from deerflow.persistence.engine import get_session_factory - from deerflow.persistence.feedback import FeedbackRepository - - cleanup = await _make_engines(tmp_path) - try: - repo = FeedbackRepository(get_session_factory()) - - # User A submits positive feedback. - with _as_user(USER_A): - a_feedback = await repo.create( - run_id="run-a1", - thread_id="t-alpha", - rating=1, - comment="A liked this", - ) - - # User B submits negative feedback. - with _as_user(USER_B): - b_feedback = await repo.create( - run_id="run-b1", - thread_id="t-beta", - rating=-1, - comment="B disliked this", - ) - - # User A must see only A's feedback. - with _as_user(USER_A): - retrieved = await repo.get(a_feedback["feedback_id"]) - assert retrieved is not None - assert retrieved["comment"] == "A liked this" - - # CRITICAL: cannot read B's feedback by id. - leaked = await repo.get(b_feedback["feedback_id"]) - assert leaked is None, "User A leaked User B's feedback" - - # list_by_run for B's run must be empty. - empty = await repo.list_by_run("t-beta", "run-b1") - assert empty == [] - - # User B must see only B's feedback. - with _as_user(USER_B): - leaked = await repo.get(a_feedback["feedback_id"]) - assert leaked is None, "User B leaked User A's feedback" - - b_list = await repo.list_by_run("t-beta", "run-b1") - assert len(b_list) == 1 - assert b_list[0]["comment"] == "B disliked this" - finally: - await cleanup() - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_feedback_cross_user_delete_denied(tmp_path): - from deerflow.persistence.engine import get_session_factory - from deerflow.persistence.feedback import FeedbackRepository - - cleanup = await _make_engines(tmp_path) - try: - repo = FeedbackRepository(get_session_factory()) - - with _as_user(USER_A): - fb = await repo.create(run_id="run-a1", thread_id="t-alpha", rating=1) - - # User B tries to delete A's feedback — must return False (no-op). - with _as_user(USER_B): - deleted = await repo.delete(fb["feedback_id"]) - assert deleted is False, "User B deleted User A's feedback" - - # A's feedback still retrievable. - with _as_user(USER_A): - row = await repo.get(fb["feedback_id"]) - assert row is not None - finally: - await cleanup() - - -# ── Regression: AUTO sentinel without contextvar must raise ─────────────── - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_repository_without_context_raises(tmp_path): - """Defense-in-depth: calling repo methods without a user context errors.""" - from deerflow.persistence.engine import get_session_factory - from deerflow.persistence.thread_meta import ThreadMetaRepository - - cleanup = await _make_engines(tmp_path) - try: - repo = ThreadMetaRepository(get_session_factory()) - # Contextvar is explicitly unset under @pytest.mark.no_auto_user. - with pytest.raises(RuntimeError, match="no user context is set"): - await repo.get("anything") - finally: - await cleanup() - - -# ── Escape hatch: explicit user_id=None bypasses filter (for migration) ── - - -@pytest.mark.anyio -@pytest.mark.no_auto_user -async def test_explicit_none_bypasses_filter(tmp_path): - """Migration scripts pass user_id=None to see all rows regardless of owner.""" - from deerflow.persistence.engine import get_session_factory - from deerflow.persistence.thread_meta import ThreadMetaRepository - - cleanup = await _make_engines(tmp_path) - try: - repo = ThreadMetaRepository(get_session_factory()) - - # Seed data as two different users. - with _as_user(USER_A): - await repo.create("t-alpha") - with _as_user(USER_B): - await repo.create("t-beta") - - # Migration-style read: no contextvar, explicit None bypass. - all_rows = await repo.search(user_id=None) - thread_ids = {r["thread_id"] for r in all_rows} - assert thread_ids == {"t-alpha", "t-beta"} - - # Explicit get with None does not apply the filter either. - row_a = await repo.get("t-alpha", user_id=None) - assert row_a is not None - row_b = await repo.get("t-beta", user_id=None) - assert row_b is not None - finally: - await cleanup() diff --git a/backend/tests/test_persistence_scaffold.py b/backend/tests/test_persistence_scaffold.py deleted file mode 100644 index 178a08e84..000000000 --- a/backend/tests/test_persistence_scaffold.py +++ /dev/null @@ -1,233 +0,0 @@ -"""Tests for the persistence layer scaffolding. - -Tests: -1. DatabaseConfig property derivation (paths, URLs) -2. MemoryRunStore CRUD + user_id filtering -3. Base.to_dict() via inspect mixin -4. Engine init/close lifecycle (memory + SQLite) -5. Postgres missing-dep error message -""" - -from datetime import UTC, datetime - -import pytest - -from deerflow.config.database_config import DatabaseConfig -from deerflow.runtime.runs.store.memory import MemoryRunStore - -# -- DatabaseConfig -- - - -class TestDatabaseConfig: - def test_defaults(self): - c = DatabaseConfig() - assert c.backend == "memory" - assert c.pool_size == 5 - - def test_sqlite_paths_unified(self): - c = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata") - assert c.sqlite_path.endswith("deerflow.db") - assert "mydata" in c.sqlite_path - # Backward-compatible aliases point to the same file - assert c.checkpointer_sqlite_path == c.sqlite_path - assert c.app_sqlite_path == c.sqlite_path - - def test_app_sqlalchemy_url_sqlite(self): - c = DatabaseConfig(backend="sqlite", sqlite_dir="./data") - url = c.app_sqlalchemy_url - assert url.startswith("sqlite+aiosqlite:///") - assert "deerflow.db" in url - - def test_app_sqlalchemy_url_postgres(self): - c = DatabaseConfig( - backend="postgres", - postgres_url="postgresql://u:p@h:5432/db", - ) - url = c.app_sqlalchemy_url - assert url.startswith("postgresql+asyncpg://") - assert "u:p@h:5432/db" in url - - def test_app_sqlalchemy_url_postgres_already_asyncpg(self): - c = DatabaseConfig( - backend="postgres", - postgres_url="postgresql+asyncpg://u:p@h:5432/db", - ) - url = c.app_sqlalchemy_url - assert url.count("asyncpg") == 1 - - def test_memory_has_no_url(self): - c = DatabaseConfig(backend="memory") - with pytest.raises(ValueError, match="No SQLAlchemy URL"): - _ = c.app_sqlalchemy_url - - -# -- MemoryRunStore -- - - -class TestMemoryRunStore: - @pytest.fixture - def store(self): - return MemoryRunStore() - - @pytest.mark.anyio - async def test_put_and_get(self, store): - await store.put("r1", thread_id="t1", status="pending") - row = await store.get("r1") - assert row is not None - assert row["run_id"] == "r1" - assert row["status"] == "pending" - - @pytest.mark.anyio - async def test_get_missing_returns_none(self, store): - assert await store.get("nope") is None - - @pytest.mark.anyio - async def test_update_status(self, store): - await store.put("r1", thread_id="t1") - await store.update_status("r1", "running") - assert (await store.get("r1"))["status"] == "running" - - @pytest.mark.anyio - async def test_update_status_with_error(self, store): - await store.put("r1", thread_id="t1") - await store.update_status("r1", "error", error="boom") - row = await store.get("r1") - assert row["status"] == "error" - assert row["error"] == "boom" - - @pytest.mark.anyio - async def test_list_by_thread(self, store): - await store.put("r1", thread_id="t1") - await store.put("r2", thread_id="t1") - await store.put("r3", thread_id="t2") - rows = await store.list_by_thread("t1") - assert len(rows) == 2 - assert all(r["thread_id"] == "t1" for r in rows) - - @pytest.mark.anyio - async def test_list_by_thread_owner_filter(self, store): - await store.put("r1", thread_id="t1", user_id="alice") - await store.put("r2", thread_id="t1", user_id="bob") - rows = await store.list_by_thread("t1", user_id="alice") - assert len(rows) == 1 - assert rows[0]["user_id"] == "alice" - - @pytest.mark.anyio - async def test_owner_none_returns_all(self, store): - await store.put("r1", thread_id="t1", user_id="alice") - await store.put("r2", thread_id="t1", user_id="bob") - rows = await store.list_by_thread("t1", user_id=None) - assert len(rows) == 2 - - @pytest.mark.anyio - async def test_delete(self, store): - await store.put("r1", thread_id="t1") - await store.delete("r1") - assert await store.get("r1") is None - - @pytest.mark.anyio - async def test_delete_nonexistent_is_noop(self, store): - await store.delete("nope") # should not raise - - @pytest.mark.anyio - async def test_list_pending(self, store): - await store.put("r1", thread_id="t1", status="pending") - await store.put("r2", thread_id="t1", status="running") - await store.put("r3", thread_id="t2", status="pending") - pending = await store.list_pending() - assert len(pending) == 2 - assert all(r["status"] == "pending" for r in pending) - - @pytest.mark.anyio - async def test_list_pending_respects_before(self, store): - past = "2020-01-01T00:00:00+00:00" - future = "2099-01-01T00:00:00+00:00" - await store.put("r1", thread_id="t1", status="pending", created_at=past) - await store.put("r2", thread_id="t1", status="pending", created_at=future) - pending = await store.list_pending(before=datetime.now(UTC).isoformat()) - assert len(pending) == 1 - assert pending[0]["run_id"] == "r1" - - @pytest.mark.anyio - async def test_list_pending_fifo_order(self, store): - await store.put("r2", thread_id="t1", status="pending", created_at="2024-01-02T00:00:00+00:00") - await store.put("r1", thread_id="t1", status="pending", created_at="2024-01-01T00:00:00+00:00") - pending = await store.list_pending() - assert pending[0]["run_id"] == "r1" - - -# -- Base.to_dict mixin -- - - -class TestBaseToDictMixin: - @pytest.mark.anyio - async def test_to_dict_and_exclude(self, tmp_path): - """Create a temp SQLite DB with a minimal model, verify to_dict.""" - from sqlalchemy import String - from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine - from sqlalchemy.orm import Mapped, mapped_column - - from deerflow.persistence.base import Base - - class _Tmp(Base): - __tablename__ = "_tmp_test" - id: Mapped[str] = mapped_column(String(64), primary_key=True) - name: Mapped[str] = mapped_column(String(128)) - - engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}") - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - sf = async_sessionmaker(engine, expire_on_commit=False) - async with sf() as session: - session.add(_Tmp(id="1", name="hello")) - await session.commit() - obj = await session.get(_Tmp, "1") - - assert obj.to_dict() == {"id": "1", "name": "hello"} - assert obj.to_dict(exclude={"name"}) == {"id": "1"} - assert "_Tmp" in repr(obj) - - await engine.dispose() - - -# -- Engine lifecycle -- - - -class TestEngineLifecycle: - @pytest.mark.anyio - async def test_memory_is_noop(self): - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - - await init_engine("memory") - assert get_session_factory() is None - await close_engine() - - @pytest.mark.anyio - async def test_sqlite_creates_engine(self, tmp_path): - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - sf = get_session_factory() - assert sf is not None - async with sf() as session: - assert session is not None - await close_engine() - assert get_session_factory() is None - - @pytest.mark.anyio - async def test_postgres_without_asyncpg_gives_actionable_error(self): - """If asyncpg is not installed, error message tells user what to do.""" - from deerflow.persistence.engine import init_engine - - try: - import asyncpg # noqa: F401 - - pytest.skip("asyncpg is installed -- cannot test missing-dep path") - except ImportError: - # asyncpg is not installed — this is the expected state for this test. - # We proceed to verify that init_engine raises an actionable ImportError. - pass # noqa: S110 — intentionally ignored - with pytest.raises(ImportError, match="uv sync --extra postgres"): - await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x") diff --git a/backend/tests/test_run_event_store.py b/backend/tests/test_run_event_store.py deleted file mode 100644 index 2b22b2c6f..000000000 --- a/backend/tests/test_run_event_store.py +++ /dev/null @@ -1,500 +0,0 @@ -"""Tests for RunEventStore contract across all backends. - -Uses a helper to create the store for each backend type. -Memory tests run directly; DB and JSONL tests create stores inside each test. -""" - -import pytest - -from deerflow.runtime.events.store.memory import MemoryRunEventStore - - -@pytest.fixture -def store(): - return MemoryRunEventStore() - - -# -- Basic write and query -- - - -class TestPutAndSeq: - @pytest.mark.anyio - async def test_put_returns_dict_with_seq(self, store): - record = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hello") - assert "seq" in record - assert record["seq"] == 1 - assert record["thread_id"] == "t1" - assert record["run_id"] == "r1" - assert record["event_type"] == "human_message" - assert record["category"] == "message" - assert record["content"] == "hello" - assert "created_at" in record - - @pytest.mark.anyio - async def test_seq_strictly_increasing_same_thread(self, store): - r1 = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - r2 = await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message") - r3 = await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace") - assert r1["seq"] == 1 - assert r2["seq"] == 2 - assert r3["seq"] == 3 - - @pytest.mark.anyio - async def test_seq_independent_across_threads(self, store): - r1 = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - r2 = await store.put(thread_id="t2", run_id="r2", event_type="human_message", category="message") - assert r1["seq"] == 1 - assert r2["seq"] == 1 - - @pytest.mark.anyio - async def test_put_respects_provided_created_at(self, store): - ts = "2024-06-01T12:00:00+00:00" - record = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", created_at=ts) - assert record["created_at"] == ts - - @pytest.mark.anyio - async def test_put_metadata_preserved(self, store): - meta = {"model": "gpt-4", "tokens": 100} - record = await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", metadata=meta) - assert record["metadata"] == meta - - -# -- list_messages -- - - -class TestListMessages: - @pytest.mark.anyio - async def test_only_returns_message_category(self, store): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace") - await store.put(thread_id="t1", run_id="r1", event_type="run_start", category="lifecycle") - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["category"] == "message" - - @pytest.mark.anyio - async def test_ascending_seq_order(self, store): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="first") - await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="second") - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="third") - messages = await store.list_messages("t1") - seqs = [m["seq"] for m in messages] - assert seqs == sorted(seqs) - - @pytest.mark.anyio - async def test_before_seq_pagination(self, store): - for i in range(10): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i)) - messages = await store.list_messages("t1", before_seq=6, limit=3) - assert len(messages) == 3 - assert [m["seq"] for m in messages] == [3, 4, 5] - - @pytest.mark.anyio - async def test_after_seq_pagination(self, store): - for i in range(10): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i)) - messages = await store.list_messages("t1", after_seq=7, limit=3) - assert len(messages) == 3 - assert [m["seq"] for m in messages] == [8, 9, 10] - - @pytest.mark.anyio - async def test_limit_restricts_count(self, store): - for _ in range(20): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - messages = await store.list_messages("t1", limit=5) - assert len(messages) == 5 - - @pytest.mark.anyio - async def test_cross_run_unified_ordering(self, store): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message") - await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message") - await store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message") - messages = await store.list_messages("t1") - assert [m["seq"] for m in messages] == [1, 2, 3, 4] - assert messages[0]["run_id"] == "r1" - assert messages[2]["run_id"] == "r2" - - @pytest.mark.anyio - async def test_default_returns_latest(self, store): - for _ in range(10): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - messages = await store.list_messages("t1", limit=3) - assert [m["seq"] for m in messages] == [8, 9, 10] - - -# -- list_events -- - - -class TestListEvents: - @pytest.mark.anyio - async def test_returns_all_categories_for_run(self, store): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace") - await store.put(thread_id="t1", run_id="r1", event_type="run_start", category="lifecycle") - events = await store.list_events("t1", "r1") - assert len(events) == 3 - - @pytest.mark.anyio - async def test_event_types_filter(self, store): - await store.put(thread_id="t1", run_id="r1", event_type="llm_start", category="trace") - await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace") - await store.put(thread_id="t1", run_id="r1", event_type="tool_start", category="trace") - events = await store.list_events("t1", "r1", event_types=["llm_end"]) - assert len(events) == 1 - assert events[0]["event_type"] == "llm_end" - - @pytest.mark.anyio - async def test_only_returns_specified_run(self, store): - await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace") - await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace") - events = await store.list_events("t1", "r1") - assert len(events) == 1 - assert events[0]["run_id"] == "r1" - - -# -- list_messages_by_run -- - - -class TestListMessagesByRun: - @pytest.mark.anyio - async def test_only_messages_for_specified_run(self, store): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace") - await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message") - messages = await store.list_messages_by_run("t1", "r1") - assert len(messages) == 1 - assert messages[0]["run_id"] == "r1" - assert messages[0]["category"] == "message" - - -# -- count_messages -- - - -class TestCountMessages: - @pytest.mark.anyio - async def test_counts_only_message_category(self, store): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message") - await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace") - assert await store.count_messages("t1") == 2 - - -# -- put_batch -- - - -class TestPutBatch: - @pytest.mark.anyio - async def test_batch_assigns_seq(self, store): - events = [ - {"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message", "content": "a"}, - {"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message", "content": "b"}, - {"thread_id": "t1", "run_id": "r1", "event_type": "llm_end", "category": "trace"}, - ] - results = await store.put_batch(events) - assert len(results) == 3 - assert all("seq" in r for r in results) - - @pytest.mark.anyio - async def test_batch_seq_strictly_increasing(self, store): - events = [ - {"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message"}, - {"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message"}, - ] - results = await store.put_batch(events) - assert results[0]["seq"] == 1 - assert results[1]["seq"] == 2 - - -# -- delete -- - - -class TestDelete: - @pytest.mark.anyio - async def test_delete_by_thread(self, store): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message") - await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace") - count = await store.delete_by_thread("t1") - assert count == 3 - assert await store.list_messages("t1") == [] - assert await store.count_messages("t1") == 0 - - @pytest.mark.anyio - async def test_delete_by_run(self, store): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message") - await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace") - count = await store.delete_by_run("t1", "r2") - assert count == 2 - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["run_id"] == "r1" - - @pytest.mark.anyio - async def test_delete_nonexistent_thread_returns_zero(self, store): - assert await store.delete_by_thread("nope") == 0 - - @pytest.mark.anyio - async def test_delete_nonexistent_run_returns_zero(self, store): - await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - assert await store.delete_by_run("t1", "nope") == 0 - - @pytest.mark.anyio - async def test_delete_nonexistent_thread_for_run_returns_zero(self, store): - assert await store.delete_by_run("nope", "r1") == 0 - - -# -- Edge cases -- - - -class TestEdgeCases: - @pytest.mark.anyio - async def test_empty_thread_list_messages(self, store): - assert await store.list_messages("empty") == [] - - @pytest.mark.anyio - async def test_empty_run_list_events(self, store): - assert await store.list_events("empty", "r1") == [] - - @pytest.mark.anyio - async def test_empty_thread_count_messages(self, store): - assert await store.count_messages("empty") == 0 - - -# -- DB-specific tests -- - - -class TestDbRunEventStore: - """Tests for DbRunEventStore with temp SQLite.""" - - @pytest.mark.anyio - async def test_basic_crud(self, tmp_path): - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - from deerflow.runtime.events.store.db import DbRunEventStore - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - s = DbRunEventStore(get_session_factory()) - - r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi") - assert r["seq"] == 1 - r2 = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="hello") - assert r2["seq"] == 2 - - messages = await s.list_messages("t1") - assert len(messages) == 2 - - count = await s.count_messages("t1") - assert count == 2 - - await close_engine() - - @pytest.mark.anyio - async def test_trace_content_truncation(self, tmp_path): - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - from deerflow.runtime.events.store.db import DbRunEventStore - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - s = DbRunEventStore(get_session_factory(), max_trace_content=100) - - long = "x" * 200 - r = await s.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", content=long) - assert len(r["content"]) == 100 - assert r["metadata"].get("content_truncated") is True - - # message content NOT truncated - m = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=long) - assert len(m["content"]) == 200 - - await close_engine() - - @pytest.mark.anyio - async def test_pagination(self, tmp_path): - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - from deerflow.runtime.events.store.db import DbRunEventStore - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - s = DbRunEventStore(get_session_factory()) - - for i in range(10): - await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i)) - - # before_seq - msgs = await s.list_messages("t1", before_seq=6, limit=3) - assert [m["seq"] for m in msgs] == [3, 4, 5] - - # after_seq - msgs = await s.list_messages("t1", after_seq=7, limit=3) - assert [m["seq"] for m in msgs] == [8, 9, 10] - - # default (latest) - msgs = await s.list_messages("t1", limit=3) - assert [m["seq"] for m in msgs] == [8, 9, 10] - - await close_engine() - - @pytest.mark.anyio - async def test_delete(self, tmp_path): - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - from deerflow.runtime.events.store.db import DbRunEventStore - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - s = DbRunEventStore(get_session_factory()) - - await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - await s.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message") - c = await s.delete_by_run("t1", "r2") - assert c == 1 - assert await s.count_messages("t1") == 1 - - c = await s.delete_by_thread("t1") - assert c == 1 - assert await s.count_messages("t1") == 0 - - await close_engine() - - @pytest.mark.anyio - async def test_put_batch_seq_continuity(self, tmp_path): - """Batch write produces continuous seq values with no gaps.""" - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - from deerflow.runtime.events.store.db import DbRunEventStore - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - s = DbRunEventStore(get_session_factory()) - - events = [{"thread_id": "t1", "run_id": "r1", "event_type": "trace", "category": "trace"} for _ in range(50)] - results = await s.put_batch(events) - seqs = [r["seq"] for r in results] - assert seqs == list(range(1, 51)) - await close_engine() - - -# -- Factory tests -- - - -class TestMakeRunEventStore: - """Tests for the make_run_event_store factory function.""" - - @pytest.mark.anyio - async def test_memory_backend_default(self): - from deerflow.runtime.events.store import make_run_event_store - - store = make_run_event_store(None) - assert type(store).__name__ == "MemoryRunEventStore" - - @pytest.mark.anyio - async def test_memory_backend_explicit(self): - from unittest.mock import MagicMock - - from deerflow.runtime.events.store import make_run_event_store - - config = MagicMock() - config.backend = "memory" - store = make_run_event_store(config) - assert type(store).__name__ == "MemoryRunEventStore" - - @pytest.mark.anyio - async def test_db_backend_with_engine(self, tmp_path): - from unittest.mock import MagicMock - - from deerflow.persistence.engine import close_engine, init_engine - from deerflow.runtime.events.store import make_run_event_store - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - - config = MagicMock() - config.backend = "db" - config.max_trace_content = 10240 - store = make_run_event_store(config) - assert type(store).__name__ == "DbRunEventStore" - await close_engine() - - @pytest.mark.anyio - async def test_db_backend_no_engine_falls_back(self): - """db backend without engine falls back to memory.""" - from unittest.mock import MagicMock - - from deerflow.persistence.engine import close_engine, init_engine - from deerflow.runtime.events.store import make_run_event_store - - await init_engine("memory") # no engine created - - config = MagicMock() - config.backend = "db" - store = make_run_event_store(config) - assert type(store).__name__ == "MemoryRunEventStore" - await close_engine() - - @pytest.mark.anyio - async def test_jsonl_backend(self): - from unittest.mock import MagicMock - - from deerflow.runtime.events.store import make_run_event_store - - config = MagicMock() - config.backend = "jsonl" - store = make_run_event_store(config) - assert type(store).__name__ == "JsonlRunEventStore" - - @pytest.mark.anyio - async def test_unknown_backend_raises(self): - from unittest.mock import MagicMock - - from deerflow.runtime.events.store import make_run_event_store - - config = MagicMock() - config.backend = "redis" - with pytest.raises(ValueError, match="Unknown"): - make_run_event_store(config) - - -# -- JSONL-specific tests -- - - -class TestJsonlRunEventStore: - @pytest.mark.anyio - async def test_basic_crud(self, tmp_path): - from deerflow.runtime.events.store.jsonl import JsonlRunEventStore - - s = JsonlRunEventStore(base_dir=tmp_path / "jsonl") - r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi") - assert r["seq"] == 1 - messages = await s.list_messages("t1") - assert len(messages) == 1 - - @pytest.mark.anyio - async def test_file_at_correct_path(self, tmp_path): - from deerflow.runtime.events.store.jsonl import JsonlRunEventStore - - s = JsonlRunEventStore(base_dir=tmp_path / "jsonl") - await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - assert (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r1.jsonl").exists() - - @pytest.mark.anyio - async def test_cross_run_messages(self, tmp_path): - from deerflow.runtime.events.store.jsonl import JsonlRunEventStore - - s = JsonlRunEventStore(base_dir=tmp_path / "jsonl") - await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message") - messages = await s.list_messages("t1") - assert len(messages) == 2 - assert [m["seq"] for m in messages] == [1, 2] - - @pytest.mark.anyio - async def test_delete_by_run(self, tmp_path): - from deerflow.runtime.events.store.jsonl import JsonlRunEventStore - - s = JsonlRunEventStore(base_dir=tmp_path / "jsonl") - await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") - await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message") - c = await s.delete_by_run("t1", "r2") - assert c == 1 - assert not (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r2.jsonl").exists() - assert await s.count_messages("t1") == 1 diff --git a/backend/tests/test_run_event_store_pagination.py b/backend/tests/test_run_event_store_pagination.py deleted file mode 100644 index ac5ba4c2d..000000000 --- a/backend/tests/test_run_event_store_pagination.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Tests for paginated list_messages_by_run across all RunEventStore backends.""" -import pytest - -from deerflow.runtime.events.store.memory import MemoryRunEventStore - - -@pytest.fixture -def base_store(): - return MemoryRunEventStore() - - -@pytest.mark.anyio -async def test_list_messages_by_run_default_returns_all(base_store): - store = base_store - for i in range(7): - await store.put( - thread_id="t1", run_id="run-a", - event_type="human_message" if i % 2 == 0 else "ai_message", - category="message", content=f"msg-a-{i}", - ) - for i in range(3): - await store.put( - thread_id="t1", run_id="run-b", - event_type="human_message", category="message", content=f"msg-b-{i}", - ) - await store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace") - - msgs = await store.list_messages_by_run("t1", "run-a") - assert len(msgs) == 7 - assert all(m["category"] == "message" for m in msgs) - assert all(m["run_id"] == "run-a" for m in msgs) - - -@pytest.mark.anyio -async def test_list_messages_by_run_with_limit(base_store): - store = base_store - for i in range(7): - await store.put( - thread_id="t1", run_id="run-a", - event_type="human_message" if i % 2 == 0 else "ai_message", - category="message", content=f"msg-a-{i}", - ) - - msgs = await store.list_messages_by_run("t1", "run-a", limit=3) - assert len(msgs) == 3 - seqs = [m["seq"] for m in msgs] - assert seqs == sorted(seqs) - - -@pytest.mark.anyio -async def test_list_messages_by_run_after_seq(base_store): - store = base_store - for i in range(7): - await store.put( - thread_id="t1", run_id="run-a", - event_type="human_message" if i % 2 == 0 else "ai_message", - category="message", content=f"msg-a-{i}", - ) - - all_msgs = await store.list_messages_by_run("t1", "run-a") - cursor_seq = all_msgs[2]["seq"] - msgs = await store.list_messages_by_run("t1", "run-a", after_seq=cursor_seq, limit=50) - assert all(m["seq"] > cursor_seq for m in msgs) - assert len(msgs) == 4 - - -@pytest.mark.anyio -async def test_list_messages_by_run_before_seq(base_store): - store = base_store - for i in range(7): - await store.put( - thread_id="t1", run_id="run-a", - event_type="human_message" if i % 2 == 0 else "ai_message", - category="message", content=f"msg-a-{i}", - ) - - all_msgs = await store.list_messages_by_run("t1", "run-a") - cursor_seq = all_msgs[4]["seq"] - msgs = await store.list_messages_by_run("t1", "run-a", before_seq=cursor_seq, limit=50) - assert all(m["seq"] < cursor_seq for m in msgs) - assert len(msgs) == 4 - - -@pytest.mark.anyio -async def test_list_messages_by_run_does_not_include_other_run(base_store): - store = base_store - for i in range(7): - await store.put( - thread_id="t1", run_id="run-a", - event_type="human_message", category="message", content=f"msg-a-{i}", - ) - for i in range(3): - await store.put( - thread_id="t1", run_id="run-b", - event_type="human_message", category="message", content=f"msg-b-{i}", - ) - - msgs = await store.list_messages_by_run("t1", "run-b") - assert len(msgs) == 3 - assert all(m["run_id"] == "run-b" for m in msgs) - - -@pytest.mark.anyio -async def test_list_messages_by_run_empty_run(base_store): - store = base_store - msgs = await store.list_messages_by_run("t1", "nonexistent") - assert msgs == [] diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py deleted file mode 100644 index 0a274df33..000000000 --- a/backend/tests/test_run_journal.py +++ /dev/null @@ -1,385 +0,0 @@ -"""Tests for RunJournal callback handler. - -Uses MemoryRunEventStore as the backend for direct event inspection. -""" - -import asyncio -from unittest.mock import MagicMock -from uuid import uuid4 - -import pytest - -from deerflow.runtime.events.store.memory import MemoryRunEventStore -from deerflow.runtime.journal import RunJournal - - -@pytest.fixture -def journal_setup(): - store = MemoryRunEventStore() - j = RunJournal("r1", "t1", store, flush_threshold=100) - return j, store - - -def _make_llm_response(content="Hello", usage=None, tool_calls=None, additional_kwargs=None): - """Create a mock LLM response with a message. - - model_dump() returns checkpoint-aligned format matching real AIMessage. - """ - msg = MagicMock() - msg.type = "ai" - msg.content = content - msg.id = f"msg-{id(msg)}" - msg.tool_calls = tool_calls or [] - msg.invalid_tool_calls = [] - msg.response_metadata = {"model_name": "test-model"} - msg.usage_metadata = usage - msg.additional_kwargs = additional_kwargs or {} - msg.name = None - # model_dump returns checkpoint-aligned format - msg.model_dump.return_value = { - "content": content, - "additional_kwargs": additional_kwargs or {}, - "response_metadata": {"model_name": "test-model"}, - "type": "ai", - "name": None, - "id": msg.id, - "tool_calls": tool_calls or [], - "invalid_tool_calls": [], - "usage_metadata": usage, - } - - gen = MagicMock() - gen.message = msg - - response = MagicMock() - response.generations = [[gen]] - return response - - -class TestLlmCallbacks: - @pytest.mark.anyio - async def test_on_llm_end_produces_trace_event(self, journal_setup): - j, store = journal_setup - run_id = uuid4() - j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) - j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) - await j.flush() - events = await store.list_events("t1", "r1") - trace_events = [e for e in events if e["event_type"] == "llm.ai.response"] - assert len(trace_events) == 1 - assert trace_events[0]["category"] == "message" - - @pytest.mark.anyio - async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup): - j, store = journal_setup - run_id = uuid4() - j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) - j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["event_type"] == "llm.ai.response" - # Content is checkpoint-aligned model_dump format - assert messages[0]["content"]["type"] == "ai" - assert messages[0]["content"]["content"] == "Answer" - - @pytest.mark.anyio - async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup): - """LLM response with pending tool_calls emits llm.ai.response with tool_calls in content.""" - j, store = journal_setup - run_id = uuid4() - j.on_llm_end( - _make_llm_response("Let me search", tool_calls=[{"id": "call_1", "name": "search", "args": {}}]), - run_id=run_id, - parent_run_id=None, - tags=["lead_agent"], - ) - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["event_type"] == "llm.ai.response" - assert len(messages[0]["content"]["tool_calls"]) == 1 - - @pytest.mark.anyio - async def test_on_llm_end_subagent_no_ai_message(self, journal_setup): - j, store = journal_setup - run_id = uuid4() - j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"]) - j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, parent_run_id=None, tags=["subagent:research"]) - await j.flush() - messages = await store.list_messages("t1") - # subagent responses still emit llm.ai.response with category="message" - assert len(messages) == 1 - - @pytest.mark.anyio - async def test_token_accumulation(self, journal_setup): - j, store = journal_setup - usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} - usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30} - j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) - j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) - assert j._total_input_tokens == 30 - assert j._total_output_tokens == 15 - assert j._total_tokens == 45 - assert j._llm_call_count == 2 - - @pytest.mark.anyio - async def test_total_tokens_computed_from_input_output(self, journal_setup): - """If total_tokens is 0, it should be computed from input + output.""" - j, store = journal_setup - j.on_llm_end( - _make_llm_response("Hi", usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}), - run_id=uuid4(), - parent_run_id=None, - tags=["lead_agent"], - ) - assert j._total_tokens == 150 - - @pytest.mark.anyio - async def test_caller_token_classification(self, journal_setup): - j, store = journal_setup - usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} - j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) - j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"]) - j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:summarization"]) - # token tracking not broken by caller type - assert j._total_tokens == 45 - assert j._llm_call_count == 3 - - @pytest.mark.anyio - async def test_usage_metadata_none_no_crash(self, journal_setup): - j, store = journal_setup - j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) - await j.flush() - - @pytest.mark.anyio - async def test_latency_tracking(self, journal_setup): - j, store = journal_setup - run_id = uuid4() - j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) - j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) - await j.flush() - events = await store.list_events("t1", "r1") - llm_resp = [e for e in events if e["event_type"] == "llm.ai.response"][0] - assert "latency_ms" in llm_resp["metadata"] - assert llm_resp["metadata"]["latency_ms"] is not None - - -class TestLifecycleCallbacks: - @pytest.mark.anyio - async def test_chain_start_end_produce_trace_events(self, journal_setup): - j, store = journal_setup - j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None) - j.on_chain_end({}, run_id=uuid4()) - await asyncio.sleep(0.05) - await j.flush() - events = await store.list_events("t1", "r1") - types = {e["event_type"] for e in events} - assert "run.start" in types - assert "run.end" in types - - @pytest.mark.anyio - async def test_nested_chain_no_run_start(self, journal_setup): - """Nested chains (parent_run_id set) should NOT produce run.start.""" - j, store = journal_setup - parent_id = uuid4() - j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id) - j.on_chain_end({}, run_id=uuid4()) - await j.flush() - events = await store.list_events("t1", "r1") - assert not any(e["event_type"] == "run.start" for e in events) - - -class TestToolCallbacks: - @pytest.mark.anyio - async def test_tool_end_with_tool_message(self, journal_setup): - """on_tool_end with a ToolMessage stores it as llm.tool.result.""" - from langchain_core.messages import ToolMessage - - j, store = journal_setup - tool_msg = ToolMessage(content="results", tool_call_id="call_1", name="web_search") - j.on_tool_end(tool_msg, run_id=uuid4()) - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["event_type"] == "llm.tool.result" - assert messages[0]["content"]["type"] == "tool" - - @pytest.mark.anyio - async def test_tool_end_with_command_unwraps_tool_message(self, journal_setup): - """on_tool_end with Command(update={'messages':[ToolMessage]}) unwraps inner message.""" - from langchain_core.messages import ToolMessage - from langgraph.types import Command - - j, store = journal_setup - inner = ToolMessage(content="file list", tool_call_id="call_2", name="present_files") - cmd = Command(update={"messages": [inner]}) - j.on_tool_end(cmd, run_id=uuid4()) - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 1 - assert messages[0]["event_type"] == "llm.tool.result" - assert messages[0]["content"]["content"] == "file list" - - @pytest.mark.anyio - async def test_on_tool_error_no_crash(self, journal_setup): - """on_tool_error should not crash (no event emitted by default).""" - j, store = journal_setup - j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch") - await j.flush() - # Base implementation does not emit tool_error — just verify no crash - events = await store.list_events("t1", "r1") - assert isinstance(events, list) - - -class TestCustomEvents: - @pytest.mark.anyio - async def test_on_custom_event_not_implemented(self, journal_setup): - """RunJournal does not implement on_custom_event — no crash expected.""" - j, store = journal_setup - # BaseCallbackHandler.on_custom_event is a no-op by default - j.on_custom_event("task_running", {"task_id": "t1"}, run_id=uuid4()) - await j.flush() - events = await store.list_events("t1", "r1") - assert isinstance(events, list) - - -class TestBufferFlush: - @pytest.mark.anyio - async def test_flush_threshold(self, journal_setup): - j, store = journal_setup - j._flush_threshold = 2 - # Each on_llm_end emits 1 event - j.on_llm_end(_make_llm_response("A"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) - assert len(j._buffer) == 1 - j.on_llm_end(_make_llm_response("B"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) - # At threshold the buffer should have been flushed asynchronously - await asyncio.sleep(0.1) - events = await store.list_events("t1", "r1") - assert len(events) >= 2 - - @pytest.mark.anyio - async def test_events_retained_when_no_loop(self, journal_setup): - """Events buffered in a sync (no-loop) context should survive - until the async flush() in the finally block.""" - j, store = journal_setup - j._flush_threshold = 1 - - original = asyncio.get_running_loop - - def no_loop(): - raise RuntimeError("no running event loop") - - asyncio.get_running_loop = no_loop - try: - j._put(event_type="llm.ai.response", category="message", content="test") - finally: - asyncio.get_running_loop = original - - assert len(j._buffer) == 1 - await j.flush() - events = await store.list_events("t1", "r1") - assert any(e["event_type"] == "llm.ai.response" for e in events) - - -class TestIdentifyCaller: - def test_lead_agent_tag(self, journal_setup): - j, _ = journal_setup - assert j._identify_caller(["lead_agent"]) == "lead_agent" - - def test_subagent_tag(self, journal_setup): - j, _ = journal_setup - assert j._identify_caller(["subagent:research"]) == "subagent:research" - - def test_middleware_tag(self, journal_setup): - j, _ = journal_setup - assert j._identify_caller(["middleware:summarization"]) == "middleware:summarization" - - def test_no_tags_returns_lead_agent(self, journal_setup): - j, _ = journal_setup - assert j._identify_caller([]) == "lead_agent" - assert j._identify_caller(None) == "lead_agent" - - -class TestChainErrorCallback: - @pytest.mark.anyio - async def test_on_chain_error_writes_run_error(self, journal_setup): - j, store = journal_setup - j.on_chain_error(ValueError("boom"), run_id=uuid4()) - await asyncio.sleep(0.05) - await j.flush() - events = await store.list_events("t1", "r1") - error_events = [e for e in events if e["event_type"] == "run.error"] - assert len(error_events) == 1 - assert "boom" in error_events[0]["content"] - assert error_events[0]["metadata"]["error_type"] == "ValueError" - - -class TestTokenTrackingDisabled: - @pytest.mark.anyio - async def test_track_token_usage_false(self): - store = MemoryRunEventStore() - j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100) - j.on_llm_end( - _make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}), - run_id=uuid4(), - parent_run_id=None, - tags=["lead_agent"], - ) - data = j.get_completion_data() - assert data["total_tokens"] == 0 - assert data["llm_call_count"] == 0 - - -class TestConvenienceFields: - @pytest.mark.anyio - async def test_first_human_message_via_set(self, journal_setup): - j, _ = journal_setup - j.set_first_human_message("What is AI?") - data = j.get_completion_data() - assert data["first_human_message"] == "What is AI?" - - @pytest.mark.anyio - async def test_get_completion_data(self, journal_setup): - j, _ = journal_setup - j._total_tokens = 100 - j._msg_count = 5 - data = j.get_completion_data() - assert data["total_tokens"] == 100 - assert data["message_count"] == 5 - - -class TestMiddlewareEvents: - @pytest.mark.anyio - async def test_record_middleware_uses_middleware_category(self, journal_setup): - j, store = journal_setup - j.record_middleware( - "title", - name="TitleMiddleware", - hook="after_model", - action="generate_title", - changes={"title": "Test Title", "thread_id": "t1"}, - ) - await j.flush() - events = await store.list_events("t1", "r1") - mw_events = [e for e in events if e["event_type"] == "middleware:title"] - assert len(mw_events) == 1 - assert mw_events[0]["category"] == "middleware" - assert mw_events[0]["content"]["name"] == "TitleMiddleware" - assert mw_events[0]["content"]["hook"] == "after_model" - assert mw_events[0]["content"]["action"] == "generate_title" - assert mw_events[0]["content"]["changes"]["title"] == "Test Title" - - @pytest.mark.anyio - async def test_middleware_tag_variants(self, journal_setup): - """Different middleware tags produce distinct event_types.""" - j, store = journal_setup - j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={}) - j.record_middleware("guardrail", name="GuardrailMiddleware", hook="before_tool", action="deny", changes={}) - await j.flush() - events = await store.list_events("t1", "r1") - event_types = {e["event_type"] for e in events} - assert "middleware:title" in event_types - assert "middleware:guardrail" in event_types - - diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py deleted file mode 100644 index 34ab9b492..000000000 --- a/backend/tests/test_run_repository.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Tests for RunRepository (SQLAlchemy-backed RunStore). - -Uses a temp SQLite DB to test ORM-backed CRUD operations. -""" - -import pytest - -from deerflow.persistence.run import RunRepository - - -async def _make_repo(tmp_path): - from deerflow.persistence.engine import get_session_factory, init_engine - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - return RunRepository(get_session_factory()) - - -async def _cleanup(): - from deerflow.persistence.engine import close_engine - - await close_engine() - - -class TestRunRepository: - @pytest.mark.anyio - async def test_put_and_get(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1", status="pending") - row = await repo.get("r1") - assert row is not None - assert row["run_id"] == "r1" - assert row["thread_id"] == "t1" - assert row["status"] == "pending" - await _cleanup() - - @pytest.mark.anyio - async def test_get_missing_returns_none(self, tmp_path): - repo = await _make_repo(tmp_path) - assert await repo.get("nope") is None - await _cleanup() - - @pytest.mark.anyio - async def test_update_status(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1") - await repo.update_status("r1", "running") - row = await repo.get("r1") - assert row["status"] == "running" - await _cleanup() - - @pytest.mark.anyio - async def test_update_status_with_error(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1") - await repo.update_status("r1", "error", error="boom") - row = await repo.get("r1") - assert row["status"] == "error" - assert row["error"] == "boom" - await _cleanup() - - @pytest.mark.anyio - async def test_list_by_thread(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1") - await repo.put("r2", thread_id="t1") - await repo.put("r3", thread_id="t2") - rows = await repo.list_by_thread("t1") - assert len(rows) == 2 - assert all(r["thread_id"] == "t1" for r in rows) - await _cleanup() - - @pytest.mark.anyio - async def test_list_by_thread_owner_filter(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1", user_id="alice") - await repo.put("r2", thread_id="t1", user_id="bob") - rows = await repo.list_by_thread("t1", user_id="alice") - assert len(rows) == 1 - assert rows[0]["user_id"] == "alice" - await _cleanup() - - @pytest.mark.anyio - async def test_delete(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1") - await repo.delete("r1") - assert await repo.get("r1") is None - await _cleanup() - - @pytest.mark.anyio - async def test_delete_nonexistent_is_noop(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.delete("nope") # should not raise - await _cleanup() - - @pytest.mark.anyio - async def test_list_pending(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1", status="pending") - await repo.put("r2", thread_id="t1", status="running") - await repo.put("r3", thread_id="t2", status="pending") - pending = await repo.list_pending() - assert len(pending) == 2 - assert all(r["status"] == "pending" for r in pending) - await _cleanup() - - @pytest.mark.anyio - async def test_update_run_completion(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1", status="running") - await repo.update_run_completion( - "r1", - status="success", - total_input_tokens=100, - total_output_tokens=50, - total_tokens=150, - llm_call_count=2, - lead_agent_tokens=120, - subagent_tokens=20, - middleware_tokens=10, - message_count=3, - last_ai_message="The answer is 42", - first_human_message="What is the meaning?", - ) - row = await repo.get("r1") - assert row["status"] == "success" - assert row["total_tokens"] == 150 - assert row["llm_call_count"] == 2 - assert row["lead_agent_tokens"] == 120 - assert row["message_count"] == 3 - assert row["last_ai_message"] == "The answer is 42" - assert row["first_human_message"] == "What is the meaning?" - await _cleanup() - - @pytest.mark.anyio - async def test_metadata_preserved(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1", metadata={"key": "value"}) - row = await repo.get("r1") - assert row["metadata"] == {"key": "value"} - await _cleanup() - - @pytest.mark.anyio - async def test_kwargs_with_non_serializable(self, tmp_path): - """kwargs containing non-JSON-serializable objects should be safely handled.""" - repo = await _make_repo(tmp_path) - - class Dummy: - pass - - await repo.put("r1", thread_id="t1", kwargs={"obj": Dummy()}) - row = await repo.get("r1") - assert "obj" in row["kwargs"] - await _cleanup() - - @pytest.mark.anyio - async def test_update_run_completion_preserves_existing_fields(self, tmp_path): - """update_run_completion does not overwrite thread_id or assistant_id.""" - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1", assistant_id="agent1", status="running") - await repo.update_run_completion("r1", status="success", total_tokens=100) - row = await repo.get("r1") - assert row["thread_id"] == "t1" - assert row["assistant_id"] == "agent1" - assert row["total_tokens"] == 100 - await _cleanup() - - @pytest.mark.anyio - async def test_list_by_thread_ordered_desc(self, tmp_path): - """list_by_thread returns newest first.""" - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1", created_at="2024-01-01T00:00:00+00:00") - await repo.put("r2", thread_id="t1", created_at="2024-01-02T00:00:00+00:00") - rows = await repo.list_by_thread("t1") - assert rows[0]["run_id"] == "r2" - assert rows[1]["run_id"] == "r1" - await _cleanup() - - @pytest.mark.anyio - async def test_list_by_thread_limit(self, tmp_path): - repo = await _make_repo(tmp_path) - for i in range(5): - await repo.put(f"r{i}", thread_id="t1") - rows = await repo.list_by_thread("t1", limit=2) - assert len(rows) == 2 - await _cleanup() - - @pytest.mark.anyio - async def test_owner_none_returns_all(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1", user_id="alice") - await repo.put("r2", thread_id="t1", user_id="bob") - rows = await repo.list_by_thread("t1", user_id=None) - assert len(rows) == 2 - await _cleanup() diff --git a/backend/tests/test_run_worker_rollback.py b/backend/tests/test_run_worker_rollback.py deleted file mode 100644 index 714ccdde1..000000000 --- a/backend/tests/test_run_worker_rollback.py +++ /dev/null @@ -1,214 +0,0 @@ -from unittest.mock import AsyncMock, call - -import pytest - -from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint - - -class FakeCheckpointer: - def __init__(self, *, put_result): - self.adelete_thread = AsyncMock() - self.aput = AsyncMock(return_value=put_result) - self.aput_writes = AsyncMock() - - -@pytest.mark.anyio -async def test_rollback_restores_snapshot_without_deleting_thread(): - checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) - - await _rollback_to_pre_run_checkpoint( - checkpointer=checkpointer, - thread_id="thread-1", - run_id="run-1", - pre_run_checkpoint_id="ckpt-1", - pre_run_snapshot={ - "checkpoint_ns": "", - "checkpoint": { - "id": "ckpt-1", - "channel_versions": {"messages": 3}, - "channel_values": {"messages": ["before"]}, - }, - "metadata": {"source": "input"}, - "pending_writes": [ - ("task-a", "messages", {"content": "first"}), - ("task-a", "status", "done"), - ("task-b", "events", {"type": "tool"}), - ], - }, - snapshot_capture_failed=False, - ) - - checkpointer.adelete_thread.assert_not_awaited() - checkpointer.aput.assert_awaited_once_with( - {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}, - { - "id": "ckpt-1", - "channel_versions": {"messages": 3}, - "channel_values": {"messages": ["before"]}, - }, - {"source": "input"}, - {"messages": 3}, - ) - assert checkpointer.aput_writes.await_args_list == [ - call( - {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}, - [("messages", {"content": "first"}), ("status", "done")], - task_id="task-a", - ), - call( - {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}, - [("events", {"type": "tool"})], - task_id="task-b", - ), - ] - - -@pytest.mark.anyio -async def test_rollback_deletes_thread_when_no_snapshot_exists(): - checkpointer = FakeCheckpointer(put_result=None) - - await _rollback_to_pre_run_checkpoint( - checkpointer=checkpointer, - thread_id="thread-1", - run_id="run-1", - pre_run_checkpoint_id=None, - pre_run_snapshot=None, - snapshot_capture_failed=False, - ) - - checkpointer.adelete_thread.assert_awaited_once_with("thread-1") - checkpointer.aput.assert_not_awaited() - checkpointer.aput_writes.assert_not_awaited() - - -@pytest.mark.anyio -async def test_rollback_raises_when_restore_config_has_no_checkpoint_id(): - checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}) - - with pytest.raises(RuntimeError, match="did not return checkpoint_id"): - await _rollback_to_pre_run_checkpoint( - checkpointer=checkpointer, - thread_id="thread-1", - run_id="run-1", - pre_run_checkpoint_id="ckpt-1", - pre_run_snapshot={ - "checkpoint_ns": "", - "checkpoint": {"id": "ckpt-1", "channel_versions": {}}, - "metadata": {}, - "pending_writes": [("task-a", "messages", "value")], - }, - snapshot_capture_failed=False, - ) - - checkpointer.adelete_thread.assert_not_awaited() - checkpointer.aput.assert_awaited_once() - checkpointer.aput_writes.assert_not_awaited() - - -@pytest.mark.anyio -async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace(): - checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) - - await _rollback_to_pre_run_checkpoint( - checkpointer=checkpointer, - thread_id="thread-1", - run_id="run-1", - pre_run_checkpoint_id="ckpt-1", - pre_run_snapshot={ - "checkpoint_ns": None, - "checkpoint": {"id": "ckpt-1", "channel_versions": {}}, - "metadata": {}, - "pending_writes": [], - }, - snapshot_capture_failed=False, - ) - - checkpointer.aput.assert_awaited_once_with( - {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}, - {"id": "ckpt-1", "channel_versions": {}}, - {}, - {}, - ) - - -@pytest.mark.anyio -async def test_rollback_raises_on_malformed_pending_write_not_a_tuple(): - """pending_writes containing a non-3-tuple item should raise RuntimeError.""" - checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) - - with pytest.raises(RuntimeError, match="rollback failed: pending_write is not a 3-tuple"): - await _rollback_to_pre_run_checkpoint( - checkpointer=checkpointer, - thread_id="thread-1", - run_id="run-1", - pre_run_checkpoint_id="ckpt-1", - pre_run_snapshot={ - "checkpoint_ns": "", - "checkpoint": {"id": "ckpt-1", "channel_versions": {}}, - "metadata": {}, - "pending_writes": [ - ("task-a", "messages", "valid"), # valid - ["only", "two"], # malformed: only 2 elements - ], - }, - snapshot_capture_failed=False, - ) - - # aput succeeded but aput_writes should not be called due to malformed data - checkpointer.aput.assert_awaited_once() - checkpointer.aput_writes.assert_not_awaited() - - -@pytest.mark.anyio -async def test_rollback_raises_on_malformed_pending_write_non_string_channel(): - """pending_writes containing a non-string channel should raise RuntimeError.""" - checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) - - with pytest.raises(RuntimeError, match="rollback failed: pending_write has non-string channel"): - await _rollback_to_pre_run_checkpoint( - checkpointer=checkpointer, - thread_id="thread-1", - run_id="run-1", - pre_run_checkpoint_id="ckpt-1", - pre_run_snapshot={ - "checkpoint_ns": "", - "checkpoint": {"id": "ckpt-1", "channel_versions": {}}, - "metadata": {}, - "pending_writes": [ - ("task-a", 123, "value"), # malformed: channel is not a string - ], - }, - snapshot_capture_failed=False, - ) - - checkpointer.aput.assert_awaited_once() - checkpointer.aput_writes.assert_not_awaited() - - -@pytest.mark.anyio -async def test_rollback_propagates_aput_writes_failure(): - """If aput_writes fails, the exception should propagate (not be swallowed).""" - checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) - # Simulate aput_writes failure - checkpointer.aput_writes.side_effect = RuntimeError("Database connection lost") - - with pytest.raises(RuntimeError, match="Database connection lost"): - await _rollback_to_pre_run_checkpoint( - checkpointer=checkpointer, - thread_id="thread-1", - run_id="run-1", - pre_run_checkpoint_id="ckpt-1", - pre_run_snapshot={ - "checkpoint_ns": "", - "checkpoint": {"id": "ckpt-1", "channel_versions": {}}, - "metadata": {}, - "pending_writes": [ - ("task-a", "messages", "value"), - ], - }, - snapshot_capture_failed=False, - ) - - # aput succeeded, aput_writes was called but failed - checkpointer.aput.assert_awaited_once() - checkpointer.aput_writes.assert_awaited_once() diff --git a/backend/tests/test_runs_api_endpoints.py b/backend/tests/test_runs_api_endpoints.py deleted file mode 100644 index e6b73d865..000000000 --- a/backend/tests/test_runs_api_endpoints.py +++ /dev/null @@ -1,243 +0,0 @@ -"""Tests for GET /api/runs/{run_id}/messages and GET /api/runs/{run_id}/feedback endpoints.""" -from __future__ import annotations - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from _router_auth_helpers import make_authed_test_app -from fastapi.testclient import TestClient - -from app.gateway.routers import runs - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_app(run_store=None, event_store=None, feedback_repo=None): - """Build a test FastAPI app with stub auth and mocked state.""" - app = make_authed_test_app() - app.include_router(runs.router) - - if run_store is not None: - app.state.run_store = run_store - if event_store is not None: - app.state.run_event_store = event_store - if feedback_repo is not None: - app.state.feedback_repo = feedback_repo - - return app - - -def _make_run_store(run_record: dict | None): - """Return an AsyncMock run store whose get() returns run_record.""" - store = MagicMock() - store.get = AsyncMock(return_value=run_record) - return store - - -def _make_event_store(rows: list[dict]): - """Return an AsyncMock event store whose list_messages_by_run() returns rows.""" - store = MagicMock() - store.list_messages_by_run = AsyncMock(return_value=rows) - return store - - -def _make_message(seq: int) -> dict: - return {"seq": seq, "event_type": "on_chat_model_stream", "category": "message", "content": f"msg-{seq}"} - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -def test_run_messages_returns_envelope(): - """GET /api/runs/{run_id}/messages returns {data: [...], has_more: bool}.""" - rows = [_make_message(i) for i in range(1, 4)] - run_record = {"run_id": "run-1", "thread_id": "thread-1"} - app = _make_app( - run_store=_make_run_store(run_record), - event_store=_make_event_store(rows), - ) - with TestClient(app) as client: - response = client.get("/api/runs/run-1/messages") - assert response.status_code == 200 - body = response.json() - assert "data" in body - assert "has_more" in body - assert body["has_more"] is False - assert len(body["data"]) == 3 - - -def test_run_messages_404_when_run_not_found(): - """Returns 404 when the run store returns None.""" - app = _make_app( - run_store=_make_run_store(None), - event_store=_make_event_store([]), - ) - with TestClient(app) as client: - response = client.get("/api/runs/missing-run/messages") - assert response.status_code == 404 - assert "missing-run" in response.json()["detail"] - - -def test_run_messages_has_more_true_when_extra_row_returned(): - """has_more=True when event store returns limit+1 rows.""" - # Default limit is 50; provide 51 rows - rows = [_make_message(i) for i in range(1, 52)] # 51 rows - run_record = {"run_id": "run-2", "thread_id": "thread-2"} - app = _make_app( - run_store=_make_run_store(run_record), - event_store=_make_event_store(rows), - ) - with TestClient(app) as client: - response = client.get("/api/runs/run-2/messages") - assert response.status_code == 200 - body = response.json() - assert body["has_more"] is True - assert len(body["data"]) == 50 # trimmed to limit - - -def test_run_messages_passes_after_seq_to_event_store(): - """after_seq query param is forwarded to event_store.list_messages_by_run.""" - rows = [_make_message(10)] - run_record = {"run_id": "run-3", "thread_id": "thread-3"} - event_store = _make_event_store(rows) - app = _make_app( - run_store=_make_run_store(run_record), - event_store=event_store, - ) - with TestClient(app) as client: - response = client.get("/api/runs/run-3/messages?after_seq=5") - assert response.status_code == 200 - event_store.list_messages_by_run.assert_awaited_once_with( - "thread-3", "run-3", - limit=51, # default limit(50) + 1 - before_seq=None, - after_seq=5, - ) - - -def test_run_messages_respects_custom_limit(): - """Custom limit is respected and capped at 200.""" - rows = [_make_message(i) for i in range(1, 6)] - run_record = {"run_id": "run-4", "thread_id": "thread-4"} - event_store = _make_event_store(rows) - app = _make_app( - run_store=_make_run_store(run_record), - event_store=event_store, - ) - with TestClient(app) as client: - response = client.get("/api/runs/run-4/messages?limit=10") - assert response.status_code == 200 - event_store.list_messages_by_run.assert_awaited_once_with( - "thread-4", "run-4", - limit=11, # 10 + 1 - before_seq=None, - after_seq=None, - ) - - -def test_run_messages_passes_before_seq_to_event_store(): - """before_seq query param is forwarded to event_store.list_messages_by_run.""" - rows = [_make_message(3)] - run_record = {"run_id": "run-5", "thread_id": "thread-5"} - event_store = _make_event_store(rows) - app = _make_app( - run_store=_make_run_store(run_record), - event_store=event_store, - ) - with TestClient(app) as client: - response = client.get("/api/runs/run-5/messages?before_seq=10") - assert response.status_code == 200 - event_store.list_messages_by_run.assert_awaited_once_with( - "thread-5", "run-5", - limit=51, - before_seq=10, - after_seq=None, - ) - - -def test_run_messages_empty_data(): - """Returns empty data list when no messages exist.""" - run_record = {"run_id": "run-6", "thread_id": "thread-6"} - app = _make_app( - run_store=_make_run_store(run_record), - event_store=_make_event_store([]), - ) - with TestClient(app) as client: - response = client.get("/api/runs/run-6/messages") - assert response.status_code == 200 - body = response.json() - assert body["data"] == [] - assert body["has_more"] is False - - -def _make_feedback_repo(rows: list[dict]): - """Return an AsyncMock feedback repo whose list_by_run() returns rows.""" - repo = MagicMock() - repo.list_by_run = AsyncMock(return_value=rows) - return repo - - -def _make_feedback(run_id: str, idx: int) -> dict: - return {"id": f"fb-{idx}", "run_id": run_id, "thread_id": "thread-x", "value": "up"} - - -# --------------------------------------------------------------------------- -# TestRunFeedback -# --------------------------------------------------------------------------- - - -class TestRunFeedback: - def test_returns_list_of_feedback_dicts(self): - """GET /api/runs/{run_id}/feedback returns a list of feedback dicts.""" - run_record = {"run_id": "run-fb-1", "thread_id": "thread-fb-1"} - rows = [_make_feedback("run-fb-1", i) for i in range(3)] - app = _make_app( - run_store=_make_run_store(run_record), - feedback_repo=_make_feedback_repo(rows), - ) - with TestClient(app) as client: - response = client.get("/api/runs/run-fb-1/feedback") - assert response.status_code == 200 - body = response.json() - assert isinstance(body, list) - assert len(body) == 3 - - def test_404_when_run_not_found(self): - """Returns 404 when run store returns None.""" - app = _make_app( - run_store=_make_run_store(None), - feedback_repo=_make_feedback_repo([]), - ) - with TestClient(app) as client: - response = client.get("/api/runs/missing-run/feedback") - assert response.status_code == 404 - assert "missing-run" in response.json()["detail"] - - def test_empty_list_when_no_feedback(self): - """Returns empty list when no feedback exists for the run.""" - run_record = {"run_id": "run-fb-2", "thread_id": "thread-fb-2"} - app = _make_app( - run_store=_make_run_store(run_record), - feedback_repo=_make_feedback_repo([]), - ) - with TestClient(app) as client: - response = client.get("/api/runs/run-fb-2/feedback") - assert response.status_code == 200 - assert response.json() == [] - - def test_503_when_feedback_repo_not_configured(self): - """Returns 503 when feedback_repo is None (no DB configured).""" - run_record = {"run_id": "run-fb-3", "thread_id": "thread-fb-3"} - app = _make_app( - run_store=_make_run_store(run_record), - ) - # Explicitly set feedback_repo to None to simulate missing DB - app.state.feedback_repo = None - with TestClient(app) as client: - response = client.get("/api/runs/run-fb-3/feedback") - assert response.status_code == 503 diff --git a/backend/tests/test_thread_meta_repo.py b/backend/tests/test_thread_meta_repo.py deleted file mode 100644 index 3a6532567..000000000 --- a/backend/tests/test_thread_meta_repo.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Tests for ThreadMetaRepository (SQLAlchemy-backed).""" - -import pytest - -from deerflow.persistence.thread_meta import ThreadMetaRepository - - -async def _make_repo(tmp_path): - from deerflow.persistence.engine import get_session_factory, init_engine - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - return ThreadMetaRepository(get_session_factory()) - - -async def _cleanup(): - from deerflow.persistence.engine import close_engine - - await close_engine() - - -class TestThreadMetaRepository: - @pytest.mark.anyio - async def test_create_and_get(self, tmp_path): - repo = await _make_repo(tmp_path) - record = await repo.create("t1") - assert record["thread_id"] == "t1" - assert record["status"] == "idle" - assert "created_at" in record - - fetched = await repo.get("t1") - assert fetched is not None - assert fetched["thread_id"] == "t1" - await _cleanup() - - @pytest.mark.anyio - async def test_create_with_assistant_id(self, tmp_path): - repo = await _make_repo(tmp_path) - record = await repo.create("t1", assistant_id="agent1") - assert record["assistant_id"] == "agent1" - await _cleanup() - - @pytest.mark.anyio - async def test_create_with_owner_and_display_name(self, tmp_path): - repo = await _make_repo(tmp_path) - record = await repo.create("t1", user_id="user1", display_name="My Thread") - assert record["user_id"] == "user1" - assert record["display_name"] == "My Thread" - await _cleanup() - - @pytest.mark.anyio - async def test_create_with_metadata(self, tmp_path): - repo = await _make_repo(tmp_path) - record = await repo.create("t1", metadata={"key": "value"}) - assert record["metadata"] == {"key": "value"} - await _cleanup() - - @pytest.mark.anyio - async def test_get_nonexistent(self, tmp_path): - repo = await _make_repo(tmp_path) - assert await repo.get("nonexistent") is None - await _cleanup() - - @pytest.mark.anyio - async def test_check_access_no_record_allows(self, tmp_path): - repo = await _make_repo(tmp_path) - assert await repo.check_access("unknown", "user1") is True - await _cleanup() - - @pytest.mark.anyio - async def test_check_access_owner_matches(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.create("t1", user_id="user1") - assert await repo.check_access("t1", "user1") is True - await _cleanup() - - @pytest.mark.anyio - async def test_check_access_owner_mismatch(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.create("t1", user_id="user1") - assert await repo.check_access("t1", "user2") is False - await _cleanup() - - @pytest.mark.anyio - async def test_check_access_no_owner_allows_all(self, tmp_path): - repo = await _make_repo(tmp_path) - # Explicit user_id=None to bypass the new AUTO default that - # would otherwise pick up the test user from the autouse fixture. - await repo.create("t1", user_id=None) - assert await repo.check_access("t1", "anyone") is True - await _cleanup() - - @pytest.mark.anyio - async def test_check_access_strict_missing_row_denied(self, tmp_path): - """require_existing=True flips the missing-row case to *denied*. - - Closes the delete-idempotence cross-user gap: after a thread is - deleted, the row is gone, and the permissive default would let any - caller "claim" it as untracked. The strict mode demands a row. - """ - repo = await _make_repo(tmp_path) - assert await repo.check_access("never-existed", "user1", require_existing=True) is False - await _cleanup() - - @pytest.mark.anyio - async def test_check_access_strict_owner_match_allowed(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.create("t1", user_id="user1") - assert await repo.check_access("t1", "user1", require_existing=True) is True - await _cleanup() - - @pytest.mark.anyio - async def test_check_access_strict_owner_mismatch_denied(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.create("t1", user_id="user1") - assert await repo.check_access("t1", "user2", require_existing=True) is False - await _cleanup() - - @pytest.mark.anyio - async def test_check_access_strict_null_owner_still_allowed(self, tmp_path): - """Even in strict mode, a row with NULL user_id stays shared. - - The strict flag tightens the *missing row* case, not the *shared - row* case — legacy pre-auth rows that survived a clean migration - without an owner are still everyone's. - """ - repo = await _make_repo(tmp_path) - await repo.create("t1", user_id=None) - assert await repo.check_access("t1", "anyone", require_existing=True) is True - await _cleanup() - - @pytest.mark.anyio - async def test_update_status(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.create("t1") - await repo.update_status("t1", "busy") - record = await repo.get("t1") - assert record["status"] == "busy" - await _cleanup() - - @pytest.mark.anyio - async def test_delete(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.create("t1") - await repo.delete("t1") - assert await repo.get("t1") is None - await _cleanup() - - @pytest.mark.anyio - async def test_delete_nonexistent_is_noop(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.delete("nonexistent") # should not raise - await _cleanup() - - @pytest.mark.anyio - async def test_update_metadata_merges(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.create("t1", metadata={"a": 1, "b": 2}) - await repo.update_metadata("t1", {"b": 99, "c": 3}) - record = await repo.get("t1") - # Existing key preserved, overlapping key overwritten, new key added - assert record["metadata"] == {"a": 1, "b": 99, "c": 3} - await _cleanup() - - @pytest.mark.anyio - async def test_update_metadata_on_empty(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.create("t1") - await repo.update_metadata("t1", {"k": "v"}) - record = await repo.get("t1") - assert record["metadata"] == {"k": "v"} - await _cleanup() - - @pytest.mark.anyio - async def test_update_metadata_nonexistent_is_noop(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise - await _cleanup() diff --git a/backend/tests/test_user_context.py b/backend/tests/test_user_context.py deleted file mode 100644 index 8c7cbd13c..000000000 --- a/backend/tests/test_user_context.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Tests for runtime.user_context — contextvar three-state semantics. - -These tests opt out of the autouse contextvar fixture (added in -commit 6) because they explicitly test the cases where the contextvar -is set or unset. -""" - -from types import SimpleNamespace - -import pytest - -from deerflow.runtime.user_context import ( - CurrentUser, - DEFAULT_USER_ID, - get_current_user, - get_effective_user_id, - require_current_user, - reset_current_user, - set_current_user, -) - - -@pytest.mark.no_auto_user -def test_default_is_none(): - """Before any set, contextvar returns None.""" - assert get_current_user() is None - - -@pytest.mark.no_auto_user -def test_set_and_reset_roundtrip(): - """set_current_user returns a token that reset restores.""" - user = SimpleNamespace(id="user-1") - token = set_current_user(user) - try: - assert get_current_user() is user - finally: - reset_current_user(token) - assert get_current_user() is None - - -@pytest.mark.no_auto_user -def test_require_current_user_raises_when_unset(): - """require_current_user raises RuntimeError if contextvar is unset.""" - assert get_current_user() is None - with pytest.raises(RuntimeError, match="without user context"): - require_current_user() - - -@pytest.mark.no_auto_user -def test_require_current_user_returns_user_when_set(): - """require_current_user returns the user when contextvar is set.""" - user = SimpleNamespace(id="user-2") - token = set_current_user(user) - try: - assert require_current_user() is user - finally: - reset_current_user(token) - - -@pytest.mark.no_auto_user -def test_protocol_accepts_duck_typed(): - """CurrentUser is a runtime_checkable Protocol matching any .id-bearing object.""" - user = SimpleNamespace(id="user-3") - assert isinstance(user, CurrentUser) - - -@pytest.mark.no_auto_user -def test_protocol_rejects_no_id(): - """Objects without .id do not satisfy CurrentUser Protocol.""" - not_a_user = SimpleNamespace(email="no-id@example.com") - assert not isinstance(not_a_user, CurrentUser) - - -# --------------------------------------------------------------------------- -# get_effective_user_id / DEFAULT_USER_ID tests -# --------------------------------------------------------------------------- - - -def test_default_user_id_is_default(): - assert DEFAULT_USER_ID == "default" - - -@pytest.mark.no_auto_user -def test_effective_user_id_returns_default_when_no_user(): - """No user in context -> fallback to DEFAULT_USER_ID.""" - assert get_effective_user_id() == "default" - - -@pytest.mark.no_auto_user -def test_effective_user_id_returns_user_id_when_set(): - user = SimpleNamespace(id="u-abc-123") - token = set_current_user(user) - try: - assert get_effective_user_id() == "u-abc-123" - finally: - reset_current_user(token) - - -@pytest.mark.no_auto_user -def test_effective_user_id_coerces_to_str(): - """User.id might be a UUID object; must come back as str.""" - import uuid - uid = uuid.uuid4() - - user = SimpleNamespace(id=uid) - token = set_current_user(user) - try: - assert get_effective_user_id() == str(uid) - finally: - reset_current_user(token) diff --git a/backend/tests/_router_auth_helpers.py b/backend/tests/unittest/_router_auth_helpers.py similarity index 65% rename from backend/tests/_router_auth_helpers.py rename to backend/tests/unittest/_router_auth_helpers.py index a7ce60468..99e01ad20 100644 --- a/backend/tests/_router_auth_helpers.py +++ b/backend/tests/unittest/_router_auth_helpers.py @@ -1,12 +1,9 @@ -"""Helpers for router-level tests that need a stubbed auth context. +"""Helpers for router-level tests that need an authenticated request. -The production gateway runs ``AuthMiddleware`` (validates the JWT cookie) -ahead of every router, plus ``@require_permission(owner_check=True)`` -decorators that read ``request.state.auth`` and call -``thread_store.check_access``. Router-level unit tests construct -**bare** FastAPI apps that include only one router — they have neither -the auth middleware nor a real thread_store, so the decorators raise -401 (TestClient path) or ValueError (direct-call path). +The production gateway stamps ``request.user`` / ``request.auth`` in the +auth middleware, then route decorators read that authenticated context. +Router-level unit tests build very small FastAPI apps that include only +one router, so they need a lightweight stand-in for that middleware. This module provides two surfaces: @@ -15,10 +12,9 @@ This module provides two surfaces: request, plus a permissive ``thread_store`` mock on ``app.state``. Use from TestClient-based router tests. -2. :func:`call_unwrapped` — invokes the underlying function bypassing - the ``@require_permission`` decorator chain by walking ``__wrapped__``. - Use from direct-call tests that previously imported the route - function and called it positionally. +2. :func:`call_unwrapped` — invokes the underlying function by walking + ``__wrapped__``. Use from direct-call tests that want to bypass the + route decorators entirely. Both helpers are deliberately permissive: they never deny a request. Tests that want to verify the *auth boundary itself* (e.g. @@ -37,8 +33,8 @@ from fastapi import FastAPI, Request, Response from starlette.middleware.base import BaseHTTPMiddleware from starlette.types import ASGIApp -from app.gateway.auth.models import User -from app.gateway.authz import AuthContext, Permissions +from app.plugins.auth.domain.models import User +from app.plugins.auth.authorization import AuthContext, Permissions # Default permission set granted to the stub user. Mirrors `_ALL_PERMISSIONS` # in authz.py — kept inline so the tests don't import a private symbol. @@ -63,12 +59,7 @@ def _make_stub_user() -> User: class _StubAuthMiddleware(BaseHTTPMiddleware): - """Stamp a fake user / AuthContext onto every request. - - Mirrors what production ``AuthMiddleware`` does after the JWT decode - + DB lookup short-circuit, so ``@require_permission`` finds an - authenticated context and skips its own re-authentication path. - """ + """Stamp a fake user / AuthContext onto every request.""" def __init__(self, app: ASGIApp, user_factory: Callable[[], User]) -> None: super().__init__(app) @@ -76,8 +67,11 @@ class _StubAuthMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next: Callable) -> Response: user = self._user_factory() + auth_context = AuthContext(user=user, permissions=list(_STUB_PERMISSIONS)) + request.scope["user"] = user + request.scope["auth"] = auth_context request.state.user = user - request.state.auth = AuthContext(user=user, permissions=list(_STUB_PERMISSIONS)) + request.state.auth = auth_context return await call_next(request) @@ -93,9 +87,8 @@ def make_authed_test_app( populated :class:`User`. Useful for cross-user isolation tests that need a stable id across requests. owner_check_passes: When True (default), ``thread_store.check_access`` - returns True for every call so ``@require_permission(owner_check=True)`` - never blocks the route under test. Pass False to verify that - permission failures surface correctly. + returns True for every call so owner-gated routes do not block + the handler under test. Pass False to verify denial paths. Returns: A ``FastAPI`` app with the stub middleware installed and @@ -121,12 +114,9 @@ def call_unwrapped(decorated: Callable[_P, _R], /, *args: _P.args, **kwargs: _P. """Invoke the underlying function of a ``@require_permission``-decorated route. ``functools.wraps`` sets ``__wrapped__`` on each layer; we walk all - the way down to the original handler, bypassing every authz + - require_auth wrapper. Use from tests that need to call route - functions directly (without TestClient) and don't want to construct - a fake ``Request`` just to satisfy the decorator. The ``ParamSpec`` - propagates the wrapped route's signature so call sites still get - parameter checking despite the unwrapping. + the way down to the original handler. Use from tests that call route + functions directly and do not want to build a full request/middleware + stack. """ fn: Callable = decorated while hasattr(fn, "__wrapped__"): diff --git a/backend/tests/test_acp_config.py b/backend/tests/unittest/test_acp_config.py similarity index 100% rename from backend/tests/test_acp_config.py rename to backend/tests/unittest/test_acp_config.py diff --git a/backend/tests/test_aio_sandbox.py b/backend/tests/unittest/test_aio_sandbox.py similarity index 100% rename from backend/tests/test_aio_sandbox.py rename to backend/tests/unittest/test_aio_sandbox.py diff --git a/backend/tests/test_aio_sandbox_local_backend.py b/backend/tests/unittest/test_aio_sandbox_local_backend.py similarity index 100% rename from backend/tests/test_aio_sandbox_local_backend.py rename to backend/tests/unittest/test_aio_sandbox_local_backend.py diff --git a/backend/tests/test_aio_sandbox_provider.py b/backend/tests/unittest/test_aio_sandbox_provider.py similarity index 100% rename from backend/tests/test_aio_sandbox_provider.py rename to backend/tests/unittest/test_aio_sandbox_provider.py diff --git a/backend/tests/test_app_config_reload.py b/backend/tests/unittest/test_app_config_reload.py similarity index 100% rename from backend/tests/test_app_config_reload.py rename to backend/tests/unittest/test_app_config_reload.py diff --git a/backend/tests/test_artifacts_router.py b/backend/tests/unittest/test_artifacts_router.py similarity index 100% rename from backend/tests/test_artifacts_router.py rename to backend/tests/unittest/test_artifacts_router.py diff --git a/backend/tests/test_auth.py b/backend/tests/unittest/test_auth.py similarity index 70% rename from backend/tests/test_auth.py rename to backend/tests/unittest/test_auth.py index ea4c5733a..e2f3e271d 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/unittest/test_auth.py @@ -1,22 +1,22 @@ -"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators.""" +"""Tests for authentication module: JWT, password hashing, and auth context behavior.""" from datetime import timedelta from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 import pytest -from fastapi import FastAPI, HTTPException -from fastapi.testclient import TestClient +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password -from app.gateway.auth.models import User -from app.gateway.authz import ( +from app.plugins.auth.authorization import ( AuthContext, Permissions, get_auth_context, - require_auth, - require_permission, ) +from app.plugins.auth.authorization.hooks import build_authz_hooks +from app.plugins.auth.domain import create_access_token, decode_token, hash_password, verify_password +from app.plugins.auth.domain.models import User +from store.persistence import MappedBase # ── Password Hashing ──────────────────────────────────────────────────────── @@ -67,7 +67,7 @@ def test_create_and_decode_token(): def test_decode_token_expired(): """Expired token returns TokenError.EXPIRED.""" - from app.gateway.auth.errors import TokenError + from app.plugins.auth.domain.errors import TokenError user_id = str(uuid4()) # Create token that expires immediately @@ -78,7 +78,7 @@ def test_decode_token_expired(): def test_decode_token_invalid(): """Invalid token returns TokenError.""" - from app.gateway.auth.errors import TokenError + from app.plugins.auth.domain.errors import TokenError assert isinstance(decode_token("not.a.valid.token"), TokenError) assert isinstance(decode_token(""), TokenError) @@ -101,6 +101,8 @@ def test_auth_context_unauthenticated(): """AuthContext with no user.""" ctx = AuthContext(user=None, permissions=[]) assert ctx.is_authenticated is False + assert ctx.principal_id is None + assert ctx.capabilities == () assert ctx.has_permission("threads", "read") is False @@ -109,6 +111,8 @@ def test_auth_context_authenticated_no_perms(): user = User(id=uuid4(), email="test@example.com", password_hash="hash") ctx = AuthContext(user=user, permissions=[]) assert ctx.is_authenticated is True + assert ctx.principal_id == str(user.id) + assert ctx.capabilities == () assert ctx.has_permission("threads", "read") is False @@ -117,6 +121,7 @@ def test_auth_context_has_permission(): user = User(id=uuid4(), email="test@example.com", password_hash="hash") perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE] ctx = AuthContext(user=user, permissions=perms) + assert ctx.capabilities == tuple(perms) assert ctx.has_permission("threads", "read") is True assert ctx.has_permission("threads", "write") is True assert ctx.has_permission("threads", "delete") is False @@ -162,79 +167,12 @@ def test_get_auth_context_set(): assert get_auth_context(mock_request) == ctx -# ── require_auth decorator ────────────────────────────────────────────────── +def test_register_app_sets_default_authz_hooks(): + from app.gateway.registrar import register_app + app = register_app() -def test_require_auth_sets_auth_context(): - """require_auth sets auth context on request from cookie.""" - from fastapi import Request - - app = FastAPI() - - @app.get("/test") - @require_auth - async def endpoint(request: Request): - ctx = get_auth_context(request) - return {"authenticated": ctx.is_authenticated} - - with TestClient(app) as client: - # No cookie → anonymous - response = client.get("/test") - assert response.status_code == 200 - assert response.json()["authenticated"] is False - - -def test_require_auth_requires_request_param(): - """require_auth raises ValueError if request parameter is missing.""" - import asyncio - - @require_auth - async def bad_endpoint(): # Missing `request` parameter - pass - - with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"): - asyncio.run(bad_endpoint()) - - -# ── require_permission decorator ───────────────────────────────────────────── - - -def test_require_permission_requires_auth(): - """require_permission raises 401 when not authenticated.""" - from fastapi import Request - - app = FastAPI() - - @app.get("/test") - @require_permission("threads", "read") - async def endpoint(request: Request): - return {"ok": True} - - with TestClient(app) as client: - response = client.get("/test") - assert response.status_code == 401 - assert "Authentication required" in response.json()["detail"] - - -def test_require_permission_denies_wrong_permission(): - """User without required permission gets 403.""" - from fastapi import Request - - app = FastAPI() - user = User(id=uuid4(), email="test@example.com", password_hash="hash") - - @app.get("/test") - @require_permission("threads", "delete") - async def endpoint(request: Request): - return {"ok": True} - - mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ]) - - with patch("app.gateway.authz._authenticate", return_value=mock_auth): - with TestClient(app) as client: - response = client.get("/test") - assert response.status_code == 403 - assert "Permission denied" in response.json()["detail"] + assert app.state.authz_hooks == build_authz_hooks() # ── Weak JWT secret warning ────────────────────────────────────────────────── @@ -271,45 +209,55 @@ def test_sqlite_round_trip_new_fields(): import asyncio import tempfile - from app.gateway.auth.repositories.sqlite import SQLiteUserRepository + from app.plugins.auth.storage import DbUserRepository, UserCreate async def _run() -> None: - from deerflow.persistence.engine import ( - close_engine, - get_session_factory, - init_engine, - ) - with tempfile.TemporaryDirectory() as tmpdir: - url = f"sqlite+aiosqlite:///{tmpdir}/scratch.db" - await init_engine("sqlite", url=url, sqlite_dir=tmpdir) + engine = create_async_engine(f"sqlite+aiosqlite:///{tmpdir}/scratch.db", future=True) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) try: - repo = SQLiteUserRepository(get_session_factory()) - user = User( - email="setup@test.com", - password_hash="fakehash", - system_role="admin", - needs_setup=True, - token_version=3, - ) - created = await repo.create_user(user) + async with session_factory() as session: + repo = DbUserRepository(session) + created = await repo.create_user( + UserCreate( + email="setup@test.com", + password_hash="fakehash", + system_role="admin", + needs_setup=True, + token_version=3, + ) + ) + await session.commit() assert created.needs_setup is True assert created.token_version == 3 - fetched = await repo.get_user_by_email("setup@test.com") + async with session_factory() as session: + repo = DbUserRepository(session) + fetched = await repo.get_user_by_email("setup@test.com") assert fetched is not None assert fetched.needs_setup is True assert fetched.token_version == 3 - fetched.needs_setup = False - fetched.token_version = 4 - await repo.update_user(fetched) - refetched = await repo.get_user_by_id(str(fetched.id)) + updated = fetched.model_copy(update={"needs_setup": False, "token_version": 4}) + async with session_factory() as session: + repo = DbUserRepository(session) + await repo.update_user(updated) + await session.commit() + async with session_factory() as session: + repo = DbUserRepository(session) + refetched = await repo.get_user_by_id(fetched.id) assert refetched is not None assert refetched.needs_setup is False assert refetched.token_version == 4 finally: - await close_engine() + await engine.dispose() asyncio.run(_run()) @@ -320,49 +268,54 @@ def test_update_user_raises_when_row_concurrently_deleted(tmp_path): Earlier the SQLite repo returned the input unchanged when the row was missing, making a phantom success path that admin password reset callers (`reset_admin`, `_ensure_admin_user`) would happily log as - 'password reset'. The new contract: raise ``UserNotFoundError`` so + 'password reset'. The new contract: raise ``LookupError`` so a vanished row never looks like a successful update. """ import asyncio import tempfile - from app.gateway.auth.repositories.base import UserNotFoundError - from app.gateway.auth.repositories.sqlite import SQLiteUserRepository + from app.plugins.auth.storage import DbUserRepository, UserCreate async def _run() -> None: - from deerflow.persistence.engine import ( - close_engine, - get_session_factory, - init_engine, - ) - from deerflow.persistence.user.model import UserRow + from app.plugins.auth.storage.models import User as UserModel with tempfile.TemporaryDirectory() as d: - url = f"sqlite+aiosqlite:///{d}/scratch.db" - await init_engine("sqlite", url=url, sqlite_dir=d) + engine = create_async_engine(f"sqlite+aiosqlite:///{d}/scratch.db", future=True) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + sf = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) try: - sf = get_session_factory() - repo = SQLiteUserRepository(sf) - user = User( - email="ghost@test.com", - password_hash="fakehash", - system_role="user", - ) - created = await repo.create_user(user) + async with sf() as session: + repo = DbUserRepository(session) + created = await repo.create_user( + UserCreate( + email="ghost@test.com", + password_hash="fakehash", + system_role="user", + ) + ) + await session.commit() # Simulate "row vanished underneath us" by deleting the row # via the raw ORM session, then attempt to update. async with sf() as session: - row = await session.get(UserRow, str(created.id)) + row = await session.get(UserModel, created.id) assert row is not None await session.delete(row) await session.commit() - created.needs_setup = True - with pytest.raises(UserNotFoundError): - await repo.update_user(created) + updated = created.model_copy(update={"needs_setup": True}) + async with sf() as session: + repo = DbUserRepository(session) + with pytest.raises(LookupError): + await repo.update_user(updated) finally: - await close_engine() + await engine.dispose() asyncio.run(_run()) @@ -374,7 +327,7 @@ def test_jwt_encodes_ver(): """JWT payload includes ver field.""" import os - from app.gateway.auth.errors import TokenError + from app.plugins.auth.domain.errors import TokenError os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars" token = create_access_token(str(uuid4()), token_version=3) @@ -387,7 +340,7 @@ def test_jwt_default_ver_zero(): """JWT ver defaults to 0.""" import os - from app.gateway.auth.errors import TokenError + from app.plugins.auth.domain.errors import TokenError os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars" token = create_access_token(str(uuid4())) @@ -398,30 +351,34 @@ def test_jwt_default_ver_zero(): def test_token_version_mismatch_rejects(): """Token with stale ver is rejected by get_current_user_from_request.""" - import asyncio import os + from types import SimpleNamespace + from app.plugins.auth.security.dependencies import get_current_user_from_request os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars" user_id = str(uuid4()) token = create_access_token(user_id, token_version=0) + request = SimpleNamespace( + cookies={"access_token": token}, + state=SimpleNamespace( + _auth_session=MagicMock(), + ), + ) + stale_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1) + request.state._auth_session.__aenter__ = AsyncMock(return_value=request.state._auth_session) + request.state._auth_session.__aexit__ = AsyncMock(return_value=None) - mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1) - - mock_request = MagicMock() - mock_request.cookies = {"access_token": token} - - with patch("app.gateway.deps.get_local_provider") as mock_provider_fn: - mock_provider = MagicMock() - mock_provider.get_user = AsyncMock(return_value=mock_user) - mock_provider_fn.return_value = mock_provider - - from app.gateway.deps import get_current_user_from_request - + with patch( + "app.plugins.auth.security.dependencies.DbUserRepository.get_user_by_id", + new=AsyncMock(return_value=stale_user), + ): with pytest.raises(HTTPException) as exc_info: - asyncio.run(get_current_user_from_request(mock_request)) - assert exc_info.value.status_code == 401 - assert "revoked" in str(exc_info.value.detail).lower() + import asyncio + + asyncio.run(get_current_user_from_request(request)) + assert exc_info.value.status_code == 401 + assert "revoked" in str(exc_info.value.detail).lower() # ── change-password extension ────────────────────────────────────────────── @@ -429,7 +386,7 @@ def test_token_version_mismatch_rejects(): def test_change_password_request_accepts_new_email(): """ChangePasswordRequest model accepts optional new_email.""" - from app.gateway.routers.auth import ChangePasswordRequest + from app.plugins.auth.api.schemas import ChangePasswordRequest req = ChangePasswordRequest( current_password="old", @@ -441,7 +398,7 @@ def test_change_password_request_accepts_new_email(): def test_change_password_request_new_email_optional(): """ChangePasswordRequest model works without new_email.""" - from app.gateway.routers.auth import ChangePasswordRequest + from app.plugins.auth.api.schemas import ChangePasswordRequest req = ChangePasswordRequest(current_password="old", new_password="newpassword") assert req.new_email is None @@ -449,7 +406,7 @@ def test_change_password_request_new_email_optional(): def test_login_response_includes_needs_setup(): """LoginResponse includes needs_setup field.""" - from app.gateway.routers.auth import LoginResponse + from app.plugins.auth.api.schemas import LoginResponse resp = LoginResponse(expires_in=3600, needs_setup=True) assert resp.needs_setup is True @@ -462,7 +419,7 @@ def test_login_response_includes_needs_setup(): def test_rate_limiter_allows_under_limit(): """Requests under the limit are allowed.""" - from app.gateway.routers.auth import _check_rate_limit, _login_attempts + from app.plugins.auth.api.schemas import _check_rate_limit, _login_attempts _login_attempts.clear() _check_rate_limit("192.168.1.1") # Should not raise @@ -470,7 +427,7 @@ def test_rate_limiter_allows_under_limit(): def test_rate_limiter_blocks_after_max_failures(): """IP is blocked after 5 consecutive failures.""" - from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure + from app.plugins.auth.api.schemas import _check_rate_limit, _login_attempts, _record_login_failure _login_attempts.clear() ip = "10.0.0.1" @@ -483,7 +440,7 @@ def test_rate_limiter_blocks_after_max_failures(): def test_rate_limiter_resets_on_success(): """Successful login clears the failure counter.""" - from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success + from app.plugins.auth.api.schemas import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success _login_attempts.clear() ip = "10.0.0.2" @@ -499,7 +456,7 @@ def test_rate_limiter_resets_on_success(): def test_get_client_ip_direct_connection_no_proxy(monkeypatch): """Direct mode (no AUTH_TRUSTED_PROXIES): use TCP peer regardless of X-Real-IP.""" monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False) - from app.gateway.routers.auth import _get_client_ip + from app.plugins.auth.api.schemas import _get_client_ip req = MagicMock() req.client.host = "203.0.113.42" @@ -514,7 +471,7 @@ def test_get_client_ip_x_real_ip_ignored_when_no_trusted_proxy(monkeypatch): request to dodge per-IP rate limits in dev / direct mode. """ monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False) - from app.gateway.routers.auth import _get_client_ip + from app.plugins.auth.api.schemas import _get_client_ip req = MagicMock() req.client.host = "127.0.0.1" @@ -525,7 +482,7 @@ def test_get_client_ip_x_real_ip_ignored_when_no_trusted_proxy(monkeypatch): def test_get_client_ip_x_real_ip_honored_from_trusted_proxy(monkeypatch): """X-Real-IP is honored when the TCP peer matches AUTH_TRUSTED_PROXIES.""" monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8") - from app.gateway.routers.auth import _get_client_ip + from app.plugins.auth.api.schemas import _get_client_ip req = MagicMock() req.client.host = "10.5.6.7" # in trusted CIDR @@ -536,7 +493,7 @@ def test_get_client_ip_x_real_ip_honored_from_trusted_proxy(monkeypatch): def test_get_client_ip_x_real_ip_rejected_from_untrusted_peer(monkeypatch): """X-Real-IP is rejected when the TCP peer is NOT in the trusted list.""" monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8") - from app.gateway.routers.auth import _get_client_ip + from app.plugins.auth.api.schemas import _get_client_ip req = MagicMock() req.client.host = "8.8.8.8" # NOT in trusted CIDR @@ -547,7 +504,7 @@ def test_get_client_ip_x_real_ip_rejected_from_untrusted_peer(monkeypatch): def test_get_client_ip_xff_never_honored(monkeypatch): """X-Forwarded-For is never used; only X-Real-IP from a trusted peer.""" monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8") - from app.gateway.routers.auth import _get_client_ip + from app.plugins.auth.api.schemas import _get_client_ip req = MagicMock() req.client.host = "10.0.0.1" @@ -558,7 +515,7 @@ def test_get_client_ip_xff_never_honored(monkeypatch): def test_get_client_ip_invalid_trusted_proxy_entry_skipped(monkeypatch, caplog): """Garbage entries in AUTH_TRUSTED_PROXIES are warned and skipped.""" monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "not-an-ip,10.0.0.0/8") - from app.gateway.routers.auth import _get_client_ip + from app.plugins.auth.api.schemas import _get_client_ip req = MagicMock() req.client.host = "10.5.6.7" @@ -569,7 +526,7 @@ def test_get_client_ip_invalid_trusted_proxy_entry_skipped(monkeypatch, caplog): def test_get_client_ip_no_client_returns_unknown(monkeypatch): """No request.client → 'unknown' marker (no crash).""" monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False) - from app.gateway.routers.auth import _get_client_ip + from app.plugins.auth.api.schemas import _get_client_ip req = MagicMock() req.client = None @@ -584,7 +541,7 @@ def test_register_rejects_literal_password(): """Pydantic validator rejects 'password' as a registration password.""" from pydantic import ValidationError - from app.gateway.routers.auth import RegisterRequest + from app.plugins.auth.api.schemas import RegisterRequest with pytest.raises(ValidationError) as exc: RegisterRequest(email="x@example.com", password="password") @@ -595,7 +552,7 @@ def test_register_rejects_common_password_case_insensitive(): """Case variants of common passwords are also rejected.""" from pydantic import ValidationError - from app.gateway.routers.auth import RegisterRequest + from app.plugins.auth.api.schemas import RegisterRequest for variant in ["PASSWORD", "Password1", "qwerty123", "letmein1"]: with pytest.raises(ValidationError): @@ -604,7 +561,7 @@ def test_register_rejects_common_password_case_insensitive(): def test_register_accepts_strong_password(): """A non-blocklisted password of length >=8 is accepted.""" - from app.gateway.routers.auth import RegisterRequest + from app.plugins.auth.api.schemas import RegisterRequest req = RegisterRequest(email="x@example.com", password="Tr0ub4dor&3-Horse") assert req.password == "Tr0ub4dor&3-Horse" @@ -614,7 +571,7 @@ def test_change_password_rejects_common_password(): """The same blocklist applies to change-password.""" from pydantic import ValidationError - from app.gateway.routers.auth import ChangePasswordRequest + from app.plugins.auth.api.schemas import ChangePasswordRequest with pytest.raises(ValidationError): ChangePasswordRequest(current_password="anything", new_password="iloveyou") @@ -624,7 +581,7 @@ def test_password_blocklist_keeps_short_passwords_for_length_check(): """Short passwords still fail the min_length check (not the blocklist).""" from pydantic import ValidationError - from app.gateway.routers.auth import RegisterRequest + from app.plugins.auth.api.schemas import RegisterRequest with pytest.raises(ValidationError) as exc: RegisterRequest(email="x@example.com", password="abc") @@ -639,16 +596,17 @@ def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog): """get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset.""" import logging - import app.gateway.auth.config as config_module + import app.plugins.auth.runtime.config_state as config_module + from app.plugins.auth.runtime.config_state import reset_auth_config - config_module._auth_config = None monkeypatch.delenv("AUTH_JWT_SECRET", raising=False) with caplog.at_level(logging.WARNING): + reset_auth_config() config = config_module.get_auth_config() assert config.jwt_secret # non-empty ephemeral secret assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages) # Cleanup - config_module._auth_config = None + reset_auth_config() diff --git a/backend/tests/test_auth_config.py b/backend/tests/unittest/test_auth_config.py similarity index 77% rename from backend/tests/test_auth_config.py rename to backend/tests/unittest/test_auth_config.py index 21b8bd81b..fd830deb0 100644 --- a/backend/tests/test_auth_config.py +++ b/backend/tests/unittest/test_auth_config.py @@ -5,7 +5,8 @@ from unittest.mock import patch import pytest -from app.gateway.auth.config import AuthConfig +from app.plugins.auth.domain.config import AuthConfig +from app.plugins.auth.runtime.config_state import reset_auth_config def test_auth_config_defaults(): @@ -25,30 +26,28 @@ def test_auth_config_token_expiry_range(): def test_auth_config_from_env(): env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"} with patch.dict(os.environ, env, clear=False): - import app.gateway.auth.config as cfg + import app.plugins.auth.runtime.config_state as cfg - old = cfg._auth_config - cfg._auth_config = None try: + reset_auth_config() config = cfg.get_auth_config() assert config.jwt_secret == "test-jwt-secret-from-env" finally: - cfg._auth_config = old + reset_auth_config() def test_auth_config_missing_secret_generates_ephemeral(caplog): import logging - import app.gateway.auth.config as cfg + import app.plugins.auth.runtime.config_state as cfg - old = cfg._auth_config - cfg._auth_config = None try: with patch.dict(os.environ, {}, clear=True): os.environ.pop("AUTH_JWT_SECRET", None) with caplog.at_level(logging.WARNING): + reset_auth_config() config = cfg.get_auth_config() assert config.jwt_secret assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages) finally: - cfg._auth_config = old + reset_auth_config() diff --git a/backend/tests/unittest/test_auth_dependencies.py b/backend/tests/unittest/test_auth_dependencies.py new file mode 100644 index 000000000..0e6448b12 --- /dev/null +++ b/backend/tests/unittest/test_auth_dependencies.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.plugins.auth.domain.config import AuthConfig +from app.plugins.auth.security.dependencies import ( + get_current_user_from_request, + get_current_user_id, + get_optional_user_from_request, +) +from app.plugins.auth.domain.jwt import create_access_token +from app.plugins.auth.runtime.config_state import set_auth_config +from app.plugins.auth.storage import DbUserRepository, UserCreate +from store.persistence import MappedBase +from app.plugins.auth.storage.models import User as UserModel # noqa: F401 + +_TEST_SECRET = "test-secret-auth-dependencies-min-32" + + +@pytest.fixture(autouse=True) +def _setup_auth_config(): + set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET)) + yield + set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET)) + + +async def _make_request(tmp_path, *, cookie: str | None = None, users: list[UserCreate] | None = None): + engine = create_async_engine( + f"sqlite+aiosqlite:///{tmp_path / 'auth-deps.db'}", + future=True, + ) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + session = session_factory() + if users: + repo = DbUserRepository(session) + for user in users: + await repo.create_user(user) + await session.commit() + request = SimpleNamespace( + cookies={"access_token": cookie} if cookie is not None else {}, + state=SimpleNamespace(_auth_session=session), + ) + return request, session, engine + + +class TestAuthDependencies: + @pytest.mark.anyio + async def test_no_cookie_returns_401(self, tmp_path): + request, session, engine = await _make_request(tmp_path) + try: + with pytest.raises(HTTPException) as exc_info: + await get_current_user_from_request(request) + finally: + await session.close() + await engine.dispose() + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail["code"] == "not_authenticated" + + @pytest.mark.anyio + async def test_invalid_token_returns_401(self, tmp_path): + request, session, engine = await _make_request(tmp_path, cookie="garbage") + try: + with pytest.raises(HTTPException) as exc_info: + await get_current_user_from_request(request) + finally: + await session.close() + await engine.dispose() + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail["code"] == "token_invalid" + + @pytest.mark.anyio + async def test_missing_user_returns_401(self, tmp_path): + token = create_access_token("missing-user", token_version=0) + request, session, engine = await _make_request(tmp_path, cookie=token) + try: + with pytest.raises(HTTPException) as exc_info: + await get_current_user_from_request(request) + finally: + await session.close() + await engine.dispose() + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail["code"] == "user_not_found" + + @pytest.mark.anyio + async def test_token_version_mismatch_returns_401(self, tmp_path): + token = create_access_token("user-1", token_version=0) + request, session, engine = await _make_request( + tmp_path, + cookie=token, + users=[ + UserCreate( + id="user-1", + email="user1@example.com", + token_version=2, + ) + ], + ) + try: + with pytest.raises(HTTPException) as exc_info: + await get_current_user_from_request(request) + finally: + await session.close() + await engine.dispose() + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail["code"] == "token_invalid" + + @pytest.mark.anyio + async def test_valid_token_returns_user(self, tmp_path): + token = create_access_token("user-2", token_version=3) + request, session, engine = await _make_request( + tmp_path, + cookie=token, + users=[ + UserCreate( + id="user-2", + email="user2@example.com", + token_version=3, + ) + ], + ) + try: + user = await get_current_user_from_request(request) + user_id = await get_current_user_id(request) + finally: + await session.close() + await engine.dispose() + + assert user.id == "user-2" + assert user.email == "user2@example.com" + assert user_id == "user-2" + + @pytest.mark.anyio + async def test_optional_user_returns_none_on_failure(self, tmp_path): + request, session, engine = await _make_request(tmp_path, cookie="bad-token") + try: + user = await get_optional_user_from_request(request) + finally: + await session.close() + await engine.dispose() + + assert user is None diff --git a/backend/tests/test_auth_errors.py b/backend/tests/unittest/test_auth_errors.py similarity index 85% rename from backend/tests/test_auth_errors.py rename to backend/tests/unittest/test_auth_errors.py index b3b46c75f..fffd36e22 100644 --- a/backend/tests/test_auth_errors.py +++ b/backend/tests/unittest/test_auth_errors.py @@ -4,9 +4,10 @@ from datetime import UTC, datetime, timedelta import jwt as pyjwt -from app.gateway.auth.config import AuthConfig, set_auth_config -from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError -from app.gateway.auth.jwt import create_access_token, decode_token +from app.plugins.auth.domain.config import AuthConfig +from app.plugins.auth.domain.errors import AuthErrorCode, AuthErrorResponse, TokenError +from app.plugins.auth.domain.jwt import create_access_token, decode_token +from app.plugins.auth.runtime.config_state import set_auth_config def test_auth_error_code_values(): @@ -56,7 +57,7 @@ def test_decode_token_returns_token_error_on_expired(): def test_decode_token_returns_token_error_on_bad_signature(): _setup_config() payload = {"sub": "user-1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)} - token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256") + token = pyjwt.encode(payload, "wrong-secret-key-for-tests-minimum-32", algorithm="HS256") result = decode_token(token) assert result == TokenError.INVALID_SIGNATURE diff --git a/backend/tests/test_auth_middleware.py b/backend/tests/unittest/test_auth_middleware.py similarity index 81% rename from backend/tests/test_auth_middleware.py rename to backend/tests/unittest/test_auth_middleware.py index 398f9cec6..b04fdc802 100644 --- a/backend/tests/test_auth_middleware.py +++ b/backend/tests/unittest/test_auth_middleware.py @@ -1,9 +1,11 @@ """Tests for the global AuthMiddleware (fail-closed safety net).""" +from types import SimpleNamespace + import pytest from starlette.testclient import TestClient -from app.gateway.auth_middleware import AuthMiddleware, _is_public +from app.plugins.auth.security.middleware import AuthMiddleware, _is_public # ── _is_public unit tests ───────────────────────────────────────────────── @@ -165,7 +167,8 @@ def test_protected_path_with_junk_cookie_rejected(client): """Junk cookie → 401. Middleware strictly validates the JWT now (AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad tokens through to the route handler.""" - res = client.get("/api/models", cookies={"access_token": "some-token"}) + client.cookies.set("access_token", "some-token") + res = client.get("/api/models") assert res.status_code == 401 @@ -220,3 +223,44 @@ def test_unknown_endpoint_with_junk_cookie_rejected(client): client.cookies.set("access_token", "tok") res = client.get("/api/future-endpoint") assert res.status_code == 401 + + +def test_middleware_populates_request_user_and_auth(monkeypatch): + from fastapi import FastAPI, Request + + from app.plugins.auth.security import middleware as middleware_module + + user = SimpleNamespace(id="user-123", email="test@example.com") + + async def _fake_get_current_user_from_request(request): + return user + + monkeypatch.setattr( + middleware_module, + "get_current_user_from_request", + _fake_get_current_user_from_request, + ) + + app = FastAPI() + app.add_middleware(AuthMiddleware) + + @app.get("/api/models") + async def models_get(request: Request): + return { + "request_user_id": request.user.id, + "state_user_id": request.state.user.id, + "request_auth_user_id": request.auth.user.id, + "state_auth_user_id": request.state.auth.user.id, + } + + client = TestClient(app) + client.cookies.set("access_token", "valid") + response = client.get("/api/models") + + assert response.status_code == 200 + assert response.json() == { + "request_user_id": "user-123", + "state_user_id": "user-123", + "request_auth_user_id": "user-123", + "state_auth_user_id": "user-123", + } diff --git a/backend/tests/unittest/test_auth_policies.py b/backend/tests/unittest/test_auth_policies.py new file mode 100644 index 000000000..ca113b770 --- /dev/null +++ b/backend/tests/unittest/test_auth_policies.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from types import SimpleNamespace +from uuid import uuid4 + +import pytest +from starlette.requests import Request +from unittest.mock import AsyncMock + +from app.plugins.auth.authorization import AuthContext, Permissions +from app.plugins.auth.authorization.policies import require_thread_owner +from app.plugins.auth.domain.models import User + + +def _make_auth_context() -> AuthContext: + user = User(id=uuid4(), email="user@example.com", password_hash="hash") + return AuthContext(user=user, permissions=[Permissions.THREADS_READ, Permissions.RUNS_READ]) + + +def _make_request(*, thread_repo, run_repo=None, checkpointer=None) -> Request: + app = SimpleNamespace( + state=SimpleNamespace( + thread_meta_repo=thread_repo, + run_store=run_repo, + checkpointer=checkpointer, + ) + ) + scope = { + "type": "http", + "method": "GET", + "path": "/api/threads/thread-1/runs", + "headers": [], + "app": app, + "route": SimpleNamespace(path="/api/threads/{thread_id}/runs"), + "path_params": {"thread_id": "thread-1"}, + } + return Request(scope) + + +@pytest.mark.anyio +async def test_require_thread_owner_uses_thread_row_user_id() -> None: + auth = _make_auth_context() + thread_repo = SimpleNamespace( + get_thread_meta=AsyncMock( + return_value=SimpleNamespace( + user_id=str(auth.user.id), + metadata={"user_id": "someone-else"}, + ) + ) + ) + request = _make_request(thread_repo=thread_repo) + + await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True) + + +@pytest.mark.anyio +async def test_require_thread_owner_falls_back_to_user_owned_runs() -> None: + auth = _make_auth_context() + thread_repo = SimpleNamespace(get_thread_meta=AsyncMock(return_value=None)) + run_repo = SimpleNamespace( + list_by_thread=AsyncMock(return_value=[{"run_id": "run-1", "thread_id": "thread-1"}]) + ) + request = _make_request(thread_repo=thread_repo, run_repo=run_repo) + + await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True) + + run_repo.list_by_thread.assert_awaited_once_with("thread-1", limit=1, user_id=str(auth.user.id)) + + +@pytest.mark.anyio +async def test_require_thread_owner_falls_back_to_checkpoint_threads() -> None: + auth = _make_auth_context() + thread_repo = SimpleNamespace(get_thread_meta=AsyncMock(return_value=None)) + run_repo = SimpleNamespace(list_by_thread=AsyncMock(return_value=[])) + checkpointer = SimpleNamespace(aget_tuple=AsyncMock(return_value=object())) + request = _make_request(thread_repo=thread_repo, run_repo=run_repo, checkpointer=checkpointer) + + await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True) + + checkpointer.aget_tuple.assert_awaited_once_with( + {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + ) + + +@pytest.mark.anyio +async def test_require_thread_owner_denies_missing_thread() -> None: + auth = _make_auth_context() + thread_repo = SimpleNamespace(get_thread_meta=AsyncMock(return_value=None)) + run_repo = SimpleNamespace(list_by_thread=AsyncMock(return_value=[])) + checkpointer = SimpleNamespace(aget_tuple=AsyncMock(return_value=None)) + request = _make_request(thread_repo=thread_repo, run_repo=run_repo, checkpointer=checkpointer) + + with pytest.raises(Exception) as exc_info: + await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True) + + assert getattr(exc_info.value, "status_code", None) == 404 + assert getattr(exc_info.value, "detail", "") == "Thread thread-1 not found" diff --git a/backend/tests/unittest/test_auth_route_injection.py b/backend/tests/unittest/test_auth_route_injection.py new file mode 100644 index 000000000..703bf19e0 --- /dev/null +++ b/backend/tests/unittest/test_auth_route_injection.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from types import SimpleNamespace +from uuid import uuid4 + +import pytest +from fastapi import APIRouter, FastAPI +from starlette.requests import Request + +from app.plugins.auth.authorization import AuthContext +from app.plugins.auth.domain.models import User +from app.plugins.auth.injection import load_route_policy_registry, validate_route_policy_registry +from app.plugins.auth.injection.registry_loader import RoutePolicyRegistry, RoutePolicySpec +from app.plugins.auth.injection.route_guard import enforce_route_policy +from app.plugins.auth.injection.route_injector import install_route_guards + + +def test_load_route_policy_registry_flattens_yaml_sections() -> None: + registry = load_route_policy_registry() + + public_spec = registry.get("POST", "/api/v1/auth/login/local") + assert public_spec is not None + assert public_spec.public is True + + run_stream_spec = registry.get("GET", "/api/threads/{thread_id}/runs/{run_id}/stream") + assert run_stream_spec is not None + assert run_stream_spec.capability == "runs:read" + assert run_stream_spec.policies == ("owner:run",) + + post_stream_spec = registry.get("POST", "/api/threads/{thread_id}/runs/{run_id}/stream") + assert post_stream_spec == run_stream_spec + + +def test_validate_route_policy_registry_rejects_missing_entry() -> None: + app = FastAPI() + router = APIRouter() + + @router.get("/api/needs-policy") + async def needs_policy() -> dict[str, bool]: + return {"ok": True} + + app.include_router(router) + registry = RoutePolicyRegistry([]) + + with pytest.raises(RuntimeError, match="Missing route policy entries"): + validate_route_policy_registry(app, registry) + + +def test_install_route_guards_appends_route_dependency() -> None: + app = FastAPI() + router = APIRouter() + + @router.get("/api/demo") + async def demo() -> dict[str, bool]: + return {"ok": True} + + app.include_router(router) + + route = next(route for route in app.routes if getattr(route, "path", None) == "/api/demo") + before = len(route.dependencies) + + install_route_guards(app) + + assert len(route.dependencies) == before + 1 + assert route.dependencies[-1].dependency is enforce_route_policy + + +@pytest.mark.anyio +async def test_enforce_route_policy_denies_missing_capability() -> None: + user = User(id=uuid4(), email="user@example.com", password_hash="hash") + auth = AuthContext(user=user, permissions=["threads:read"]) + registry = RoutePolicyRegistry( + [ + SimpleNamespace( + method="GET", + path="/api/threads/{thread_id}/uploads/list", + spec=RoutePolicySpec(capability="threads:delete"), + matches_request=lambda *_args, **_kwargs: True, + ) + ] + ) + + app = SimpleNamespace(state=SimpleNamespace(auth_route_policy_registry=registry)) + scope = { + "type": "http", + "method": "GET", + "path": "/api/threads/thread-1/uploads/list", + "headers": [], + "app": app, + "route": SimpleNamespace(path="/api/threads/{thread_id}/uploads/list"), + "path_params": {"thread_id": "thread-1"}, + "auth": auth, + } + request = Request(scope) + request.state.auth = auth + + with pytest.raises(Exception) as exc_info: + await enforce_route_policy(request) + + assert getattr(exc_info.value, "status_code", None) == 403 + + +@pytest.mark.anyio +async def test_enforce_route_policy_runs_owner_policy(monkeypatch: pytest.MonkeyPatch) -> None: + user = User(id=uuid4(), email="user@example.com", password_hash="hash") + auth = AuthContext(user=user, permissions=["threads:read"]) + registry = RoutePolicyRegistry( + [ + SimpleNamespace( + method="GET", + path="/api/threads/{thread_id}/state", + spec=RoutePolicySpec(capability="threads:read", policies=("owner:thread",)), + matches_request=lambda *_args, **_kwargs: True, + ) + ] + ) + + called: dict[str, object] = {} + + async def fake_owner_check(request: Request, auth_context: AuthContext, *, thread_id: str, require_existing: bool) -> None: + called["request"] = request + called["auth"] = auth_context + called["thread_id"] = thread_id + called["require_existing"] = require_existing + + monkeypatch.setattr("app.plugins.auth.injection.route_guard.require_thread_owner", fake_owner_check) + + app = SimpleNamespace(state=SimpleNamespace(auth_route_policy_registry=registry)) + scope = { + "type": "http", + "method": "GET", + "path": "/api/threads/thread-1/state", + "headers": [], + "app": app, + "route": SimpleNamespace(path="/api/threads/{thread_id}/state"), + "path_params": {"thread_id": "thread-1"}, + "auth": auth, + } + request = Request(scope) + request.state.auth = auth + + await enforce_route_policy(request) + + assert called["thread_id"] == "thread-1" + assert called["auth"] is auth + assert called["require_existing"] is True diff --git a/backend/tests/unittest/test_auth_service.py b/backend/tests/unittest/test_auth_service.py new file mode 100644 index 000000000..8a760d99a --- /dev/null +++ b/backend/tests/unittest/test_auth_service.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.plugins.auth.domain.service import AuthService, AuthServiceError +from app.plugins.auth.storage.models import User as UserModel # noqa: F401 +from store.persistence import MappedBase + + +async def _make_service(tmp_path): + engine = create_async_engine( + f"sqlite+aiosqlite:///{tmp_path / 'auth-service.db'}", + future=True, + ) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + return engine, AuthService(session_factory) + + +class TestAuthService: + @pytest.mark.anyio + async def test_register_and_login_local(self, tmp_path): + engine, service = await _make_service(tmp_path) + try: + created = await service.register("user@example.com", "Str0ng!Pass99") + logged_in = await service.login_local("user@example.com", "Str0ng!Pass99") + finally: + await engine.dispose() + + assert created.email == "user@example.com" + assert created.password_hash is not None + assert logged_in.id == created.id + + @pytest.mark.anyio + async def test_register_duplicate_email_raises(self, tmp_path): + engine, service = await _make_service(tmp_path) + try: + await service.register("dupe@example.com", "Str0ng!Pass99") + with pytest.raises(AuthServiceError) as exc_info: + await service.register("dupe@example.com", "An0ther!Pass99") + finally: + await engine.dispose() + + assert exc_info.value.code.value == "email_already_exists" + + @pytest.mark.anyio + async def test_initialize_admin_only_once(self, tmp_path): + engine, service = await _make_service(tmp_path) + try: + admin = await service.initialize_admin("admin@example.com", "Str0ng!Pass99") + with pytest.raises(AuthServiceError) as exc_info: + await service.initialize_admin("other@example.com", "An0ther!Pass99") + finally: + await engine.dispose() + + assert admin.system_role == "admin" + assert admin.needs_setup is False + assert exc_info.value.code.value == "system_already_initialized" + + @pytest.mark.anyio + async def test_change_password_updates_token_version_and_clears_setup(self, tmp_path): + engine, service = await _make_service(tmp_path) + try: + user = await service.register("setup@example.com", "Str0ng!Pass99") + user.needs_setup = True + updated = await service.change_password( + user, + current_password="Str0ng!Pass99", + new_password="N3wer!Pass99", + new_email="final@example.com", + ) + relogged = await service.login_local("final@example.com", "N3wer!Pass99") + finally: + await engine.dispose() + + assert updated.email == "final@example.com" + assert updated.needs_setup is False + assert updated.token_version == 1 + assert relogged.id == updated.id diff --git a/backend/tests/test_auth_type_system.py b/backend/tests/unittest/test_auth_type_system.py similarity index 73% rename from backend/tests/test_auth_type_system.py rename to backend/tests/unittest/test_auth_type_system.py index 226d3812c..f98294e2e 100644 --- a/backend/tests/test_auth_type_system.py +++ b/backend/tests/unittest/test_auth_type_system.py @@ -16,10 +16,11 @@ from fastapi import FastAPI from fastapi.testclient import TestClient from pydantic import ValidationError -from app.gateway.auth.config import AuthConfig, set_auth_config -from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError -from app.gateway.auth.jwt import decode_token -from app.gateway.csrf_middleware import ( +from app.plugins.auth.domain.config import AuthConfig +from app.plugins.auth.domain.errors import AuthErrorCode, AuthErrorResponse, TokenError +from app.plugins.auth.domain.jwt import decode_token +from app.plugins.auth.runtime.config_state import set_auth_config +from app.plugins.auth.security.csrf import ( CSRF_COOKIE_NAME, CSRF_HEADER_NAME, CSRFMiddleware, @@ -34,28 +35,8 @@ _TEST_SECRET = "test-secret-for-auth-type-system-tests-min32" @pytest.fixture(autouse=True) def _persistence_engine(tmp_path): - """Initialise a per-test SQLite engine + reset cached provider singletons. - - The auth tests call real HTTP handlers that go through - ``SQLiteUserRepository`` → ``get_session_factory``. Each test gets - a fresh DB plus a clean ``deps._cached_*`` so the cached provider - does not hold a dangling reference to the previous test's engine. - """ - import asyncio - - from app.gateway import deps - from deerflow.persistence.engine import close_engine, init_engine - - url = f"sqlite+aiosqlite:///{tmp_path}/auth_types.db" - asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))) - deps._cached_local_provider = None - deps._cached_repo = None - try: - yield - finally: - deps._cached_local_provider = None - deps._cached_repo = None - asyncio.run(close_engine()) + """Per-test auth config fixture placeholder.""" + yield def _setup_config(): @@ -174,7 +155,7 @@ def test_decode_token_invalid_sig_maps_to_token_invalid_code(): import jwt as pyjwt payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)} - token = pyjwt.encode(payload, "wrong-key", algorithm="HS256") + token = pyjwt.encode(payload, "wrong-key-for-tests-minimum-32-chars", algorithm="HS256") result = decode_token(token) assert result == TokenError.INVALID_SIGNATURE @@ -197,7 +178,7 @@ def test_decode_token_malformed_maps_to_token_invalid_code(): def test_login_response_model_has_no_access_token(): """LoginResponse should NOT contain access_token field (RFC-001).""" - from app.gateway.routers.auth import LoginResponse + from app.plugins.auth.api.schemas import LoginResponse resp = LoginResponse(expires_in=604800) d = resp.model_dump() @@ -208,7 +189,7 @@ def test_login_response_model_has_no_access_token(): def test_login_response_model_fields(): """LoginResponse has expires_in and needs_setup.""" - from app.gateway.routers.auth import LoginResponse + from app.plugins.auth.api.schemas import LoginResponse fields = set(LoginResponse.model_fields.keys()) assert fields == {"expires_in", "needs_setup"} @@ -219,7 +200,7 @@ def test_login_response_model_fields(): def test_auth_config_token_expiry_used_in_login_response(): """LoginResponse.expires_in should come from config.token_expiry_days.""" - from app.gateway.routers.auth import LoginResponse + from app.plugins.auth.api.schemas import LoginResponse expected_seconds = 14 * 24 * 3600 resp = LoginResponse(expires_in=expected_seconds) @@ -231,7 +212,7 @@ def test_auth_config_token_expiry_used_in_login_response(): def test_user_response_system_role_literal(): """UserResponse.system_role should only accept 'admin' or 'user'.""" - from app.gateway.auth.models import UserResponse + from app.plugins.auth.domain.models import UserResponse # Valid roles resp = UserResponse(id="1", email="a@b.com", system_role="admin") @@ -243,7 +224,7 @@ def test_user_response_system_role_literal(): def test_user_response_rejects_invalid_role(): """UserResponse should reject invalid system_role values.""" - from app.gateway.auth.models import UserResponse + from app.plugins.auth.domain.models import UserResponse with pytest.raises(ValidationError): UserResponse(id="1", email="a@b.com", system_role="superadmin") @@ -263,7 +244,7 @@ def test_get_current_user_no_cookie_returns_not_authenticated(): from fastapi import HTTPException - from app.gateway.deps import get_current_user_from_request + from app.plugins.auth.security.dependencies import get_current_user_from_request mock_request = type("MockRequest", (), {"cookies": {}})() with pytest.raises(HTTPException) as exc_info: @@ -279,7 +260,7 @@ def test_get_current_user_expired_token_returns_token_expired(): from fastapi import HTTPException - from app.gateway.deps import get_current_user_from_request + from app.plugins.auth.security.dependencies import get_current_user_from_request _setup_config() expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)} @@ -299,11 +280,11 @@ def test_get_current_user_invalid_token_returns_token_invalid(): from fastapi import HTTPException - from app.gateway.deps import get_current_user_from_request + from app.plugins.auth.security.dependencies import get_current_user_from_request _setup_config() payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)} - token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256") + token = pyjwt.encode(payload, "wrong-secret-key-for-tests-minimum-32", algorithm="HS256") mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})() with pytest.raises(HTTPException) as exc_info: @@ -319,7 +300,7 @@ def test_get_current_user_malformed_token_returns_token_invalid(): from fastapi import HTTPException - from app.gateway.deps import get_current_user_from_request + from app.plugins.auth.security.dependencies import get_current_user_from_request _setup_config() mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})() @@ -380,19 +361,19 @@ def test_get_auth_config_missing_env_var_generates_ephemeral(caplog): """get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset.""" import logging - import app.gateway.auth.config as cfg + import app.plugins.auth.runtime.config_state as cfg + from app.plugins.auth.runtime.config_state import reset_auth_config - old = cfg._auth_config - cfg._auth_config = None try: with patch.dict(os.environ, {}, clear=True): os.environ.pop("AUTH_JWT_SECRET", None) with caplog.at_level(logging.WARNING): + reset_auth_config() config = cfg.get_auth_config() assert config.jwt_secret assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages) finally: - cfg._auth_config = old + reset_auth_config() # ── CSRF middleware integration (unhappy paths) ────────────────────── @@ -485,7 +466,7 @@ def test_csrf_middleware_sets_cookie_on_auth_endpoint(): def test_user_response_missing_required_fields(): """UserResponse with missing fields → ValidationError.""" - from app.gateway.auth.models import UserResponse + from app.plugins.auth.domain.models import UserResponse with pytest.raises(ValidationError): UserResponse(id="1") # missing email, system_role @@ -496,7 +477,7 @@ def test_user_response_missing_required_fields(): def test_user_response_empty_string_role_rejected(): """Empty string is not a valid role.""" - from app.gateway.auth.models import UserResponse + from app.plugins.auth.domain.models import UserResponse with pytest.raises(ValidationError): UserResponse(id="1", email="a@b.com", system_role="") @@ -514,20 +495,15 @@ def _make_auth_app(): return create_app() -def _get_auth_client(): - """Get TestClient for auth API contract tests.""" - return TestClient(_make_auth_app()) - - def test_api_auth_me_no_cookie_returns_structured_401(): """/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}.""" _setup_config() - client = _get_auth_client() - resp = client.get("/api/v1/auth/me") - assert resp.status_code == 401 - body = resp.json() - assert body["detail"]["code"] == "not_authenticated" - assert "message" in body["detail"] + with TestClient(_make_auth_app()) as client: + resp = client.get("/api/v1/auth/me") + assert resp.status_code == 401 + body = resp.json() + assert body["detail"]["code"] == "not_authenticated" + assert "message" in body["detail"] def test_api_auth_me_expired_token_returns_structured_401(): @@ -536,75 +512,70 @@ def test_api_auth_me_expired_token_returns_structured_401(): expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)} token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256") - client = _get_auth_client() - client.cookies.set("access_token", token) - resp = client.get("/api/v1/auth/me") - assert resp.status_code == 401 - body = resp.json() - assert body["detail"]["code"] == "token_expired" + with TestClient(_make_auth_app()) as client: + client.cookies.set("access_token", token) + resp = client.get("/api/v1/auth/me") + assert resp.status_code == 401 + body = resp.json() + assert body["detail"]["code"] == "token_expired" def test_api_auth_me_invalid_sig_returns_structured_401(): """/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}.""" _setup_config() payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)} - token = pyjwt.encode(payload, "wrong-key", algorithm="HS256") + token = pyjwt.encode(payload, "wrong-key-for-tests-minimum-32-chars", algorithm="HS256") - client = _get_auth_client() - client.cookies.set("access_token", token) - resp = client.get("/api/v1/auth/me") - assert resp.status_code == 401 - body = resp.json() - assert body["detail"]["code"] == "token_invalid" + with TestClient(_make_auth_app()) as client: + client.cookies.set("access_token", token) + resp = client.get("/api/v1/auth/me") + assert resp.status_code == 401 + body = resp.json() + assert body["detail"]["code"] == "token_invalid" def test_api_login_bad_credentials_returns_structured_401(): """Login with wrong password → 401 with {code: 'invalid_credentials'}.""" _setup_config() - client = _get_auth_client() - resp = client.post( - "/api/v1/auth/login/local", - data={"username": "nonexistent@test.com", "password": "wrongpassword"}, - ) - assert resp.status_code == 401 - body = resp.json() - assert body["detail"]["code"] == "invalid_credentials" + with TestClient(_make_auth_app()) as client: + resp = client.post( + "/api/v1/auth/login/local", + data={"username": "nonexistent@test.com", "password": "wrongpassword"}, + ) + assert resp.status_code == 401 + body = resp.json() + assert body["detail"]["code"] == "invalid_credentials" def test_api_login_success_no_token_in_body(): """Successful login → response body has expires_in but NOT access_token.""" _setup_config() - client = _get_auth_client() - # Register first - client.post( - "/api/v1/auth/register", - json={"email": "contract-test@test.com", "password": "securepassword123"}, - ) - # Login - resp = client.post( - "/api/v1/auth/login/local", - data={"username": "contract-test@test.com", "password": "securepassword123"}, - ) - assert resp.status_code == 200 - body = resp.json() - assert "expires_in" in body - assert "access_token" not in body - # Token should be in cookie, not body - assert "access_token" in resp.cookies + with TestClient(_make_auth_app()) as client: + client.post( + "/api/v1/auth/register", + json={"email": "contract-test@test.com", "password": "securepassword123"}, + ) + resp = client.post( + "/api/v1/auth/login/local", + data={"username": "contract-test@test.com", "password": "securepassword123"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "expires_in" in body + assert "access_token" not in body + assert "access_token" in resp.cookies def test_api_register_duplicate_returns_structured_400(): """Register with duplicate email → 400 with {code: 'email_already_exists'}.""" _setup_config() - client = _get_auth_client() - email = "dup-contract-test@test.com" - # First register - client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"}) - # Duplicate - resp = client.post("/api/v1/auth/register", json={"email": email, "password": "AnotherStr0ngPwd!"}) - assert resp.status_code == 400 - body = resp.json() - assert body["detail"]["code"] == "email_already_exists" + with TestClient(_make_auth_app()) as client: + email = "dup-contract-test@test.com" + client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"}) + resp = client.post("/api/v1/auth/register", json={"email": email, "password": "AnotherStr0ngPwd!"}) + assert resp.status_code == 400 + body = resp.json() + assert body["detail"]["code"] == "email_already_exists" # ── Cookie security: HTTP vs HTTPS ──────────────────────────────────── @@ -622,80 +593,80 @@ def _get_set_cookie_headers(resp) -> list[str]: def test_register_http_cookie_httponly_true_secure_false(): """HTTP register → access_token cookie is httponly=True, secure=False, no max_age.""" _setup_config() - client = _get_auth_client() - resp = client.post( - "/api/v1/auth/register", - json={"email": _unique_email("http-cookie"), "password": "Tr0ub4dor3a"}, - ) - assert resp.status_code == 201 - cookie_header = resp.headers.get("set-cookie", "") - assert "access_token=" in cookie_header - assert "httponly" in cookie_header.lower() - assert "secure" not in cookie_header.lower().replace("samesite", "") + with TestClient(_make_auth_app()) as client: + resp = client.post( + "/api/v1/auth/register", + json={"email": _unique_email("http-cookie"), "password": "Tr0ub4dor3a"}, + ) + assert resp.status_code == 201 + cookie_header = resp.headers.get("set-cookie", "") + assert "access_token=" in cookie_header + assert "httponly" in cookie_header.lower() + assert "secure" not in cookie_header.lower().replace("samesite", "") def test_register_https_cookie_httponly_true_secure_true(): """HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age.""" _setup_config() - client = _get_auth_client() - resp = client.post( - "/api/v1/auth/register", - json={"email": _unique_email("https-cookie"), "password": "Tr0ub4dor3a"}, - headers={"x-forwarded-proto": "https"}, - ) - assert resp.status_code == 201 - cookie_header = resp.headers.get("set-cookie", "") - assert "access_token=" in cookie_header - assert "httponly" in cookie_header.lower() - assert "secure" in cookie_header.lower() - assert "max-age" in cookie_header.lower() + with TestClient(_make_auth_app()) as client: + resp = client.post( + "/api/v1/auth/register", + json={"email": _unique_email("https-cookie"), "password": "Tr0ub4dor3a"}, + headers={"x-forwarded-proto": "https"}, + ) + assert resp.status_code == 201 + cookie_header = resp.headers.get("set-cookie", "") + assert "access_token=" in cookie_header + assert "httponly" in cookie_header.lower() + assert "secure" in cookie_header.lower() + assert "max-age" in cookie_header.lower() def test_login_https_sets_secure_cookie(): """HTTPS login → access_token cookie has secure flag.""" _setup_config() - client = _get_auth_client() - email = _unique_email("https-login") - client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"}) - resp = client.post( - "/api/v1/auth/login/local", - data={"username": email, "password": "Tr0ub4dor3a"}, - headers={"x-forwarded-proto": "https"}, - ) - assert resp.status_code == 200 - cookie_header = resp.headers.get("set-cookie", "") - assert "access_token=" in cookie_header - assert "httponly" in cookie_header.lower() - assert "secure" in cookie_header.lower() + with TestClient(_make_auth_app()) as client: + email = _unique_email("https-login") + client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"}) + resp = client.post( + "/api/v1/auth/login/local", + data={"username": email, "password": "Tr0ub4dor3a"}, + headers={"x-forwarded-proto": "https"}, + ) + assert resp.status_code == 200 + cookie_header = resp.headers.get("set-cookie", "") + assert "access_token=" in cookie_header + assert "httponly" in cookie_header.lower() + assert "secure" in cookie_header.lower() def test_csrf_cookie_secure_on_https(): """HTTPS register → csrf_token cookie has secure flag but NOT httponly.""" _setup_config() - client = _get_auth_client() - resp = client.post( - "/api/v1/auth/register", - json={"email": _unique_email("csrf-https"), "password": "Tr0ub4dor3a"}, - headers={"x-forwarded-proto": "https"}, - ) - assert resp.status_code == 201 - csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h] - assert csrf_cookies, "csrf_token cookie not set on HTTPS register" - csrf_header = csrf_cookies[0] - assert "secure" in csrf_header.lower() - assert "httponly" not in csrf_header.lower() + with TestClient(_make_auth_app()) as client: + resp = client.post( + "/api/v1/auth/register", + json={"email": _unique_email("csrf-https"), "password": "Tr0ub4dor3a"}, + headers={"x-forwarded-proto": "https"}, + ) + assert resp.status_code == 201 + csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h] + assert csrf_cookies, "csrf_token cookie not set on HTTPS register" + csrf_header = csrf_cookies[0] + assert "secure" in csrf_header.lower() + assert "httponly" not in csrf_header.lower() def test_csrf_cookie_not_secure_on_http(): """HTTP register → csrf_token cookie does NOT have secure flag.""" _setup_config() - client = _get_auth_client() - resp = client.post( - "/api/v1/auth/register", - json={"email": _unique_email("csrf-http"), "password": "Tr0ub4dor3a"}, - ) - assert resp.status_code == 201 - csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h] - assert csrf_cookies, "csrf_token cookie not set on HTTP register" - csrf_header = csrf_cookies[0] - assert "secure" not in csrf_header.lower().replace("samesite", "") + with TestClient(_make_auth_app()) as client: + resp = client.post( + "/api/v1/auth/register", + json={"email": _unique_email("csrf-http"), "password": "Tr0ub4dor3a"}, + ) + assert resp.status_code == 201 + csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h] + assert csrf_cookies, "csrf_token cookie not set on HTTP register" + csrf_header = csrf_cookies[0] + assert "secure" not in csrf_header.lower().replace("samesite", "") diff --git a/backend/tests/test_channel_file_attachments.py b/backend/tests/unittest/test_channel_file_attachments.py similarity index 100% rename from backend/tests/test_channel_file_attachments.py rename to backend/tests/unittest/test_channel_file_attachments.py diff --git a/backend/tests/test_channels.py b/backend/tests/unittest/test_channels.py similarity index 100% rename from backend/tests/test_channels.py rename to backend/tests/unittest/test_channels.py diff --git a/backend/tests/test_clarification_middleware.py b/backend/tests/unittest/test_clarification_middleware.py similarity index 100% rename from backend/tests/test_clarification_middleware.py rename to backend/tests/unittest/test_clarification_middleware.py diff --git a/backend/tests/test_claude_provider_oauth_billing.py b/backend/tests/unittest/test_claude_provider_oauth_billing.py similarity index 100% rename from backend/tests/test_claude_provider_oauth_billing.py rename to backend/tests/unittest/test_claude_provider_oauth_billing.py diff --git a/backend/tests/test_cli_auth_providers.py b/backend/tests/unittest/test_cli_auth_providers.py similarity index 100% rename from backend/tests/test_cli_auth_providers.py rename to backend/tests/unittest/test_cli_auth_providers.py diff --git a/backend/tests/test_client.py b/backend/tests/unittest/test_client.py similarity index 97% rename from backend/tests/test_client.py rename to backend/tests/unittest/test_client.py index d22e36d17..208213533 100644 --- a/backend/tests/test_client.py +++ b/backend/tests/unittest/test_client.py @@ -38,6 +38,8 @@ def mock_app_config(): config = MagicMock() config.models = [model] + config.database = None + config.checkpointer = None return config @@ -86,10 +88,12 @@ class TestClientInit: def test_custom_config_path(self, mock_app_config): with ( patch("deerflow.client.reload_app_config") as mock_reload, + patch("store.config.app_config.reload_app_config") as mock_storage_reload, patch("deerflow.client.get_app_config", return_value=mock_app_config), ): DeerFlowClient(config_path="/tmp/custom.yaml") mock_reload.assert_called_once_with("/tmp/custom.yaml") + mock_storage_reload.assert_called_once_with("/tmp/custom.yaml") def test_checkpointer_stored(self, mock_app_config): cp = MagicMock() @@ -97,6 +101,59 @@ class TestClientInit: c = DeerFlowClient(checkpointer=cp) assert c._checkpointer is cp + def test_resolve_checkpointer_config_reads_storage_sqlite(self, mock_app_config): + storage_app_config = MagicMock() + storage = MagicMock() + storage.driver = "sqlite" + storage.sqlite_storage_path = "/tmp/deerflow.db" + storage_app_config.storage = storage + with patch("deerflow.client.get_app_config", return_value=mock_app_config): + c = DeerFlowClient() + + with patch("store.config.app_config.get_app_config", return_value=storage_app_config): + resolved = c._resolve_checkpointer_config() + + assert resolved == { + "backend": "sqlite", + "connection_string": "/tmp/deerflow.db", + } + + def test_resolve_checkpointer_config_reads_storage_postgres(self, mock_app_config): + storage_app_config = MagicMock() + storage = MagicMock() + storage.driver = "postgres" + storage.username = "user" + storage.password = "pass" + storage.host = "localhost" + storage.port = 5432 + storage.db_name = "deerflow" + storage_app_config.storage = storage + + with patch("deerflow.client.get_app_config", return_value=mock_app_config): + c = DeerFlowClient() + + with patch("store.config.app_config.get_app_config", return_value=storage_app_config): + resolved = c._resolve_checkpointer_config() + + assert resolved == { + "backend": "postgres", + "connection_string": "postgresql://user:pass@localhost:5432/deerflow", + } + + def test_resolve_checkpointer_config_rejects_mysql(self, mock_app_config): + storage_app_config = MagicMock() + storage = MagicMock() + storage.driver = "mysql" + storage_app_config.storage = storage + + with patch("deerflow.client.get_app_config", return_value=mock_app_config): + c = DeerFlowClient() + + with patch("store.config.app_config.get_app_config", return_value=storage_app_config): + with pytest.raises(ValueError, match="does not support a MySQL checkpointer"): + c._resolve_checkpointer_config() + + # --------------------------------------------------------------------------- # list_models / list_skills / get_memory @@ -817,7 +874,7 @@ class TestEnsureAgent: patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares, patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt, patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()), + patch.object(client, "_get_active_checkpointer", return_value=MagicMock()), ): client._agent_name = "custom-agent" client._available_skills = {"test_skill"} @@ -842,7 +899,7 @@ class TestEnsureAgent: patch("deerflow.client._build_middlewares", return_value=[]), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer), + patch.object(client, "_get_active_checkpointer", return_value=mock_checkpointer), ): client._ensure_agent(config) @@ -867,7 +924,7 @@ class TestEnsureAgent: patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()), + patch.object(client, "_get_active_checkpointer", return_value=MagicMock()), ): client._ensure_agent(config) @@ -886,7 +943,7 @@ class TestEnsureAgent: patch("deerflow.client._build_middlewares", return_value=[]), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=None), + patch.object(client, "_get_active_checkpointer", return_value=None), ): client._ensure_agent(config) @@ -1015,7 +1072,7 @@ class TestThreadQueries: mock_checkpointer = MagicMock() mock_checkpointer.list.return_value = [] - with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer): + with patch.object(client, "_get_active_checkpointer", return_value=mock_checkpointer): # No internal checkpointer, should fetch from provider result = client.list_threads() @@ -1069,7 +1126,7 @@ class TestThreadQueries: mock_checkpointer = MagicMock() mock_checkpointer.list.return_value = [] - with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer): + with patch.object(client, "_get_active_checkpointer", return_value=mock_checkpointer): result = client.get_thread("t99") assert result["thread_id"] == "t99" @@ -1490,7 +1547,7 @@ class TestUploads: class TestArtifacts: def test_get_artifact(self, client): - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) @@ -1506,7 +1563,7 @@ class TestArtifacts: assert "text" in mime def test_get_artifact_not_found(self, client): - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) @@ -1522,7 +1579,7 @@ class TestArtifacts: client.get_artifact("t1", "bad/path/file.txt") def test_get_artifact_path_traversal(self, client): - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) @@ -1711,7 +1768,7 @@ class TestScenarioFileLifecycle: def test_upload_then_read_artifact(self, client): """Upload a file, simulate agent producing artifact, read it back.""" - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id with tempfile.TemporaryDirectory() as tmp: tmp_path = Path(tmp) @@ -1859,7 +1916,7 @@ class TestScenarioAgentRecreation: patch("deerflow.client._build_middlewares", return_value=[]), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()), + patch.object(client, "_get_active_checkpointer", return_value=MagicMock()), ): client._ensure_agent(config_a) first_agent = client._agent @@ -1887,7 +1944,7 @@ class TestScenarioAgentRecreation: patch("deerflow.client._build_middlewares", return_value=[]), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()), + patch.object(client, "_get_active_checkpointer", return_value=MagicMock()), ): client._ensure_agent(config) client._ensure_agent(config) @@ -1912,7 +1969,7 @@ class TestScenarioAgentRecreation: patch("deerflow.client._build_middlewares", return_value=[]), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()), + patch.object(client, "_get_active_checkpointer", return_value=MagicMock()), ): client._ensure_agent(config) client.reset_agent() @@ -1970,7 +2027,7 @@ class TestScenarioThreadIsolation: def test_artifacts_isolated_per_thread(self, client): """Artifacts in thread-A are not accessible from thread-B.""" - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) @@ -2882,7 +2939,7 @@ class TestUploadDeleteSymlink: class TestArtifactHardening: def test_artifact_directory_rejected(self, client): """get_artifact rejects paths that resolve to a directory.""" - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) @@ -2896,7 +2953,7 @@ class TestArtifactHardening: def test_artifact_leading_slash_stripped(self, client): """Paths with leading slash are handled correctly.""" - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) @@ -3015,7 +3072,7 @@ class TestBugArtifactPrefixMatchTooLoose: def test_exact_prefix_without_subpath_accepted(self, client): """Bare 'mnt/user-data' is accepted (will later fail as directory, not at prefix).""" - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id with tempfile.TemporaryDirectory() as tmp: paths = Paths(base_dir=tmp) diff --git a/backend/tests/test_client_e2e.py b/backend/tests/unittest/test_client_e2e.py similarity index 99% rename from backend/tests/test_client_e2e.py rename to backend/tests/unittest/test_client_e2e.py index 6c688933a..f5170727d 100644 --- a/backend/tests/test_client_e2e.py +++ b/backend/tests/unittest/test_client_e2e.py @@ -262,7 +262,7 @@ class TestFileUploadIntegration: # Physically exists from deerflow.config.paths import get_paths - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id assert (get_paths().sandbox_uploads_dir(tid, user_id=get_effective_user_id()) / "readme.txt").exists() @@ -473,7 +473,7 @@ class TestArtifactAccess: def test_get_artifact_happy_path(self, e2e_env): """Write a file to outputs, then read it back via get_artifact().""" from deerflow.config.paths import get_paths - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id c = DeerFlowClient(checkpointer=None, thinking_enabled=False) tid = str(uuid.uuid4()) @@ -490,7 +490,7 @@ class TestArtifactAccess: def test_get_artifact_nested_path(self, e2e_env): """Artifacts in subdirectories are accessible.""" from deerflow.config.paths import get_paths - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id c = DeerFlowClient(checkpointer=None, thinking_enabled=False) tid = str(uuid.uuid4()) diff --git a/backend/tests/test_client_live.py b/backend/tests/unittest/test_client_live.py similarity index 98% rename from backend/tests/test_client_live.py rename to backend/tests/unittest/test_client_live.py index 0271ebf21..0a40633eb 100644 --- a/backend/tests/test_client_live.py +++ b/backend/tests/unittest/test_client_live.py @@ -8,6 +8,7 @@ They are skipped in CI and must be run explicitly: import json import os +import warnings from pathlib import Path import pytest @@ -16,6 +17,12 @@ from deerflow.client import DeerFlowClient, StreamEvent from deerflow.sandbox.security import is_host_bash_allowed from deerflow.uploads.manager import PathTraversalError +warnings.filterwarnings( + "ignore", + message=r"Pydantic serializer warnings:.*field_name='context'.*", + category=UserWarning, +) + # Skip entire module in CI or when no config.yaml exists _skip_reason = None if os.environ.get("CI"): diff --git a/backend/tests/test_codex_provider.py b/backend/tests/unittest/test_codex_provider.py similarity index 100% rename from backend/tests/test_codex_provider.py rename to backend/tests/unittest/test_codex_provider.py diff --git a/backend/tests/test_config_version.py b/backend/tests/unittest/test_config_version.py similarity index 100% rename from backend/tests/test_config_version.py rename to backend/tests/unittest/test_config_version.py diff --git a/backend/tests/test_converters.py b/backend/tests/unittest/test_converters.py similarity index 100% rename from backend/tests/test_converters.py rename to backend/tests/unittest/test_converters.py diff --git a/backend/tests/test_create_deerflow_agent.py b/backend/tests/unittest/test_create_deerflow_agent.py similarity index 100% rename from backend/tests/test_create_deerflow_agent.py rename to backend/tests/unittest/test_create_deerflow_agent.py diff --git a/backend/tests/test_create_deerflow_agent_live.py b/backend/tests/unittest/test_create_deerflow_agent_live.py similarity index 100% rename from backend/tests/test_create_deerflow_agent_live.py rename to backend/tests/unittest/test_create_deerflow_agent_live.py diff --git a/backend/tests/test_credential_loader.py b/backend/tests/unittest/test_credential_loader.py similarity index 100% rename from backend/tests/test_credential_loader.py rename to backend/tests/unittest/test_credential_loader.py diff --git a/backend/tests/test_custom_agent.py b/backend/tests/unittest/test_custom_agent.py similarity index 100% rename from backend/tests/test_custom_agent.py rename to backend/tests/unittest/test_custom_agent.py diff --git a/backend/tests/test_dangling_tool_call_middleware.py b/backend/tests/unittest/test_dangling_tool_call_middleware.py similarity index 100% rename from backend/tests/test_dangling_tool_call_middleware.py rename to backend/tests/unittest/test_dangling_tool_call_middleware.py diff --git a/backend/tests/test_docker_sandbox_mode_detection.py b/backend/tests/unittest/test_docker_sandbox_mode_detection.py similarity index 100% rename from backend/tests/test_docker_sandbox_mode_detection.py rename to backend/tests/unittest/test_docker_sandbox_mode_detection.py diff --git a/backend/tests/test_doctor.py b/backend/tests/unittest/test_doctor.py similarity index 100% rename from backend/tests/test_doctor.py rename to backend/tests/unittest/test_doctor.py diff --git a/backend/tests/test_exa_tools.py b/backend/tests/unittest/test_exa_tools.py similarity index 100% rename from backend/tests/test_exa_tools.py rename to backend/tests/unittest/test_exa_tools.py diff --git a/backend/tests/unittest/test_feedback.py b/backend/tests/unittest/test_feedback.py new file mode 100644 index 000000000..6f80b47ae --- /dev/null +++ b/backend/tests/unittest/test_feedback.py @@ -0,0 +1,392 @@ +"""Tests for current feedback storage adapters and follow-up association.""" + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.infra.storage import AppRunEventStore, FeedbackStoreAdapter, RunStoreAdapter +from store.persistence import MappedBase + + +async def _make_feedback_repo(tmp_path): + engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}", future=True) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + + class _FeedbackRepoCompat: + def __init__(self, session_factory): + self._repo = FeedbackStoreAdapter(session_factory) + + async def create(self, **kwargs): + return await self._repo.create( + run_id=kwargs["run_id"], + thread_id=kwargs["thread_id"], + rating=kwargs["rating"], + owner_id=kwargs.get("owner_id"), + user_id=kwargs.get("user_id"), + message_id=kwargs.get("message_id"), + comment=kwargs.get("comment"), + ) + + async def get(self, feedback_id): + return await self._repo.get(feedback_id) + + async def list_by_run(self, thread_id, run_id, user_id=None, limit=100): + rows = await self._repo.list_by_run(thread_id, run_id, user_id=user_id, limit=limit) + return rows + + async def list_by_thread(self, thread_id, limit=100): + return await self._repo.list_by_thread(thread_id, limit=limit) + + async def delete(self, feedback_id): + return await self._repo.delete(feedback_id) + + async def aggregate_by_run(self, thread_id, run_id): + return await self._repo.aggregate_by_run(thread_id, run_id) + + async def upsert(self, **kwargs): + return await self._repo.upsert( + run_id=kwargs["run_id"], + thread_id=kwargs["thread_id"], + rating=kwargs["rating"], + user_id=kwargs.get("user_id"), + comment=kwargs.get("comment"), + ) + + async def delete_by_run(self, *, thread_id, run_id, user_id): + return await self._repo.delete_by_run(thread_id=thread_id, run_id=run_id, user_id=user_id) + + async def list_by_thread_grouped(self, thread_id, user_id): + return await self._repo.list_by_thread_grouped(thread_id, user_id=user_id) + + return engine, session_factory, _FeedbackRepoCompat(session_factory) + + +# -- FeedbackRepository -- + + +class TestFeedbackRepository: + @pytest.mark.anyio + async def test_create_positive(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + record = await repo.create(run_id="r1", thread_id="t1", rating=1) + assert record["feedback_id"] + assert record["rating"] == 1 + assert record["run_id"] == "r1" + assert record["thread_id"] == "t1" + assert "created_at" in record + await engine.dispose() + + @pytest.mark.anyio + async def test_create_negative_with_comment(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + record = await repo.create( + run_id="r1", + thread_id="t1", + rating=-1, + comment="Response was inaccurate", + ) + assert record["rating"] == -1 + assert record["comment"] == "Response was inaccurate" + await engine.dispose() + + @pytest.mark.anyio + async def test_create_with_message_id(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + record = await repo.create(run_id="r1", thread_id="t1", rating=1, message_id="msg-42") + assert record["message_id"] == "msg-42" + await engine.dispose() + + @pytest.mark.anyio + async def test_create_with_owner(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1") + assert record["user_id"] == "user-1" + await engine.dispose() + + @pytest.mark.anyio + async def test_create_uses_owner_id_fallback(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="owner-1") + assert record["user_id"] == "owner-1" + assert record["owner_id"] == "owner-1" + await engine.dispose() + + @pytest.mark.anyio + async def test_create_invalid_rating_zero(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + with pytest.raises(ValueError): + await repo.create(run_id="r1", thread_id="t1", rating=0) + await engine.dispose() + + @pytest.mark.anyio + async def test_create_invalid_rating_five(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + with pytest.raises(ValueError): + await repo.create(run_id="r1", thread_id="t1", rating=5) + await engine.dispose() + + @pytest.mark.anyio + async def test_get(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + created = await repo.create(run_id="r1", thread_id="t1", rating=1) + fetched = await repo.get(created["feedback_id"]) + assert fetched is not None + assert fetched["feedback_id"] == created["feedback_id"] + assert fetched["rating"] == 1 + await engine.dispose() + + @pytest.mark.anyio + async def test_get_nonexistent(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + assert await repo.get("nonexistent") is None + await engine.dispose() + + @pytest.mark.anyio + async def test_list_by_run(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1") + await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2") + await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1") + results = await repo.list_by_run("t1", "r1", user_id=None) + assert len(results) == 2 + assert all(r["run_id"] == "r1" for r in results) + await engine.dispose() + + @pytest.mark.anyio + async def test_list_by_run_filters_thread_even_with_same_run_id(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1") + await repo.create(run_id="r1", thread_id="t2", rating=-1, user_id="user-2") + results = await repo.list_by_run("t1", "r1", user_id=None) + assert len(results) == 1 + assert results[0]["thread_id"] == "t1" + await engine.dispose() + + @pytest.mark.anyio + async def test_list_by_run_respects_limit(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="u1") + await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="u2") + await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="u3") + results = await repo.list_by_run("t1", "r1", user_id=None, limit=2) + assert len(results) == 2 + await engine.dispose() + + @pytest.mark.anyio + async def test_list_by_thread(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + await repo.create(run_id="r1", thread_id="t1", rating=1) + await repo.create(run_id="r2", thread_id="t1", rating=-1) + await repo.create(run_id="r3", thread_id="t2", rating=1) + results = await repo.list_by_thread("t1") + assert len(results) == 2 + assert all(r["thread_id"] == "t1" for r in results) + await engine.dispose() + + @pytest.mark.anyio + async def test_list_by_thread_respects_limit(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + await repo.create(run_id="r1", thread_id="t1", rating=1) + await repo.create(run_id="r2", thread_id="t1", rating=-1) + await repo.create(run_id="r3", thread_id="t1", rating=1) + results = await repo.list_by_thread("t1", limit=2) + assert len(results) == 2 + await engine.dispose() + + @pytest.mark.anyio + async def test_delete(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + created = await repo.create(run_id="r1", thread_id="t1", rating=1) + deleted = await repo.delete(created["feedback_id"]) + assert deleted is True + assert await repo.get(created["feedback_id"]) is None + await engine.dispose() + + @pytest.mark.anyio + async def test_delete_nonexistent(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + deleted = await repo.delete("nonexistent") + assert deleted is False + await engine.dispose() + + @pytest.mark.anyio + async def test_aggregate_by_run(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1") + await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2") + await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3") + stats = await repo.aggregate_by_run("t1", "r1") + assert stats["total"] == 3 + assert stats["positive"] == 2 + assert stats["negative"] == 1 + assert stats["run_id"] == "r1" + await engine.dispose() + + @pytest.mark.anyio + async def test_aggregate_empty(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + stats = await repo.aggregate_by_run("t1", "r1") + assert stats["total"] == 0 + assert stats["positive"] == 0 + assert stats["negative"] == 0 + await engine.dispose() + + @pytest.mark.anyio + async def test_upsert_creates_new(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + assert record["rating"] == 1 + assert record["feedback_id"] + assert record["user_id"] == "u1" + await engine.dispose() + + @pytest.mark.anyio + async def test_upsert_updates_existing(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind") + assert second["feedback_id"] == first["feedback_id"] + assert second["rating"] == -1 + assert second["comment"] == "changed my mind" + await engine.dispose() + + @pytest.mark.anyio + async def test_upsert_different_users_separate(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2") + assert r1["feedback_id"] != r2["feedback_id"] + assert r1["rating"] == 1 + assert r2["rating"] == -1 + await engine.dispose() + + @pytest.mark.anyio + async def test_upsert_invalid_rating(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + with pytest.raises(ValueError): + await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1") + await engine.dispose() + + @pytest.mark.anyio + async def test_delete_by_run(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1") + assert deleted is True + results = await repo.list_by_run("t1", "r1", user_id="u1") + assert len(results) == 0 + await engine.dispose() + + @pytest.mark.anyio + async def test_delete_by_run_nonexistent(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1") + assert deleted is False + await engine.dispose() + + @pytest.mark.anyio + async def test_list_by_thread_grouped(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1") + await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1") + grouped = await repo.list_by_thread_grouped("t1", user_id="u1") + assert "r1" in grouped + assert "r2" in grouped + assert "r3" not in grouped + assert grouped["r1"]["rating"] == 1 + assert grouped["r2"]["rating"] == -1 + await engine.dispose() + + @pytest.mark.anyio + async def test_list_by_thread_grouped_filters_by_user_when_same_run_id_exists(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="u1", comment="mine") + await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="u2", comment="other") + grouped = await repo.list_by_thread_grouped("t1", user_id="u1") + assert grouped["r1"]["user_id"] == "u1" + assert grouped["r1"]["comment"] == "mine" + await engine.dispose() + + @pytest.mark.anyio + async def test_list_by_thread_grouped_empty(self, tmp_path): + engine, _, repo = await _make_feedback_repo(tmp_path) + grouped = await repo.list_by_thread_grouped("t1", user_id="u1") + assert grouped == {} + await engine.dispose() + + +# -- Follow-up association -- + + +class TestFollowUpAssociation: + @pytest.mark.anyio + async def test_run_records_follow_up_via_memory_store(self): + """RunStoreAdapter persists follow_up_to_run_id as a first-class field.""" + engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False, autoflush=False) + store = RunStoreAdapter(session_factory) + await store.create("r1", thread_id="t1", status="success") + await store.create("r2", thread_id="t1", follow_up_to_run_id="r1") + run = await store.get("r2") + assert run is not None + assert run["follow_up_to_run_id"] == "r1" + await engine.dispose() + + @pytest.mark.anyio + async def test_human_message_has_follow_up_metadata(self): + """AppRunEventStore preserves follow_up_to_run_id in message metadata.""" + engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False, autoflush=False) + event_store = AppRunEventStore(session_factory) + await event_store.put_batch([ + { + "thread_id": "t1", + "run_id": "r2", + "event_type": "human_message", + "category": "message", + "content": "Tell me more about that", + "metadata": {"follow_up_to_run_id": "r1"}, + } + ]) + messages = await event_store.list_messages("t1") + assert messages[0]["metadata"]["follow_up_to_run_id"] == "r1" + await engine.dispose() + + @pytest.mark.anyio + async def test_follow_up_auto_detection_logic(self): + """Simulate the auto-detection: latest successful run becomes follow_up_to.""" + engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False, autoflush=False) + store = RunStoreAdapter(session_factory) + await store.create("r1", thread_id="t1", status="success") + await store.create("r2", thread_id="t1", status="error") + + # Auto-detect: list_by_thread returns newest first + recent = await store.list_by_thread("t1", limit=1) + follow_up = None + if recent and recent[0].get("status") == "success": + follow_up = recent[0]["run_id"] + # r2 (error) is newest, so no follow_up detected + assert follow_up is None + + # Now add a successful run + await store.create("r3", thread_id="t1", status="success") + recent = await store.list_by_thread("t1", limit=1) + follow_up = None + if recent and recent[0].get("status") == "success": + follow_up = recent[0]["run_id"] + assert follow_up == "r3" + await engine.dispose() diff --git a/backend/tests/test_feishu_parser.py b/backend/tests/unittest/test_feishu_parser.py similarity index 100% rename from backend/tests/test_feishu_parser.py rename to backend/tests/unittest/test_feishu_parser.py diff --git a/backend/tests/test_file_conversion.py b/backend/tests/unittest/test_file_conversion.py similarity index 100% rename from backend/tests/test_file_conversion.py rename to backend/tests/unittest/test_file_conversion.py diff --git a/backend/tests/test_firecrawl_tools.py b/backend/tests/unittest/test_firecrawl_tools.py similarity index 100% rename from backend/tests/test_firecrawl_tools.py rename to backend/tests/unittest/test_firecrawl_tools.py diff --git a/backend/tests/unittest/test_gateway_services.py b/backend/tests/unittest/test_gateway_services.py new file mode 100644 index 000000000..25bd37b9d --- /dev/null +++ b/backend/tests/unittest/test_gateway_services.py @@ -0,0 +1,281 @@ +"""Tests for the current runs service modules.""" + +from __future__ import annotations + +import json + +from app.gateway.routers.langgraph.runs import RunCreateRequest, format_sse +from app.gateway.services.runs.facade_factory import resolve_agent_factory +from app.gateway.services.runs.input.request_adapter import ( + adapt_create_run_request, + adapt_create_stream_request, + adapt_create_wait_request, + adapt_join_stream_request, + adapt_join_wait_request, +) +from app.gateway.services.runs.input.spec_builder import RunSpecBuilder + + +def _builder() -> RunSpecBuilder: + return RunSpecBuilder() + + +def _build_runnable_config( + thread_id: str, + request_config: dict | None, + metadata: dict | None, + *, + assistant_id: str | None = None, + context: dict | None = None, +): + return _builder()._build_runnable_config( # noqa: SLF001 - intentional unit coverage + thread_id=thread_id, + request_config=request_config, + metadata=metadata, + assistant_id=assistant_id, + context=context, + ) + + +def test_format_sse_basic(): + frame = format_sse("metadata", {"run_id": "abc"}) + assert frame.startswith("event: metadata\n") + assert "data: " in frame + parsed = json.loads(frame.split("data: ")[1].split("\n")[0]) + assert parsed["run_id"] == "abc" + + +def test_format_sse_with_event_id(): + frame = format_sse("metadata", {"run_id": "abc"}, event_id="123-0") + assert "id: 123-0" in frame + + +def test_format_sse_end_event_null(): + frame = format_sse("end", None) + assert "data: null" in frame + + +def test_format_sse_no_event_id(): + frame = format_sse("values", {"x": 1}) + assert "id:" not in frame + + +def test_normalize_stream_modes_none(): + assert _builder()._normalize_stream_modes(None) == ["values", "messages"] # noqa: SLF001 + + +def test_normalize_stream_modes_string(): + assert _builder()._normalize_stream_modes("messages-tuple") == ["messages"] # noqa: SLF001 + + +def test_normalize_stream_modes_list(): + assert _builder()._normalize_stream_modes(["values", "messages-tuple"]) == ["values", "messages"] # noqa: SLF001 + + +def test_normalize_stream_modes_empty_list(): + assert _builder()._normalize_stream_modes([]) == [] # noqa: SLF001 + + +def test_normalize_input_none(): + assert _builder()._normalize_input(None) is None # noqa: SLF001 + + +def test_normalize_input_with_messages(): + result = _builder()._normalize_input({"messages": [{"role": "user", "content": "hi"}]}) # noqa: SLF001 + assert len(result["messages"]) == 1 + assert result["messages"][0].content == "hi" + + +def test_normalize_input_passthrough(): + result = _builder()._normalize_input({"custom_key": "value"}) # noqa: SLF001 + assert result == {"custom_key": "value"} + + +def test_build_runnable_config_basic(): + config = _build_runnable_config("thread-1", None, None) + assert config["configurable"]["thread_id"] == "thread-1" + assert config["recursion_limit"] == 100 + + +def test_build_runnable_config_with_overrides(): + config = _build_runnable_config( + "thread-1", + {"configurable": {"model_name": "gpt-4"}, "tags": ["test"]}, + {"user": "alice"}, + ) + assert config["configurable"]["model_name"] == "gpt-4" + assert config["tags"] == ["test"] + assert config["metadata"]["user"] == "alice" + + +def test_build_runnable_config_custom_agent_injects_agent_name(): + config = _build_runnable_config("thread-1", None, None, assistant_id="finalis") + assert config["configurable"]["agent_name"] == "finalis" + + +def test_build_runnable_config_lead_agent_no_agent_name(): + config = _build_runnable_config("thread-1", None, None, assistant_id="lead_agent") + assert "agent_name" not in config["configurable"] + + +def test_build_runnable_config_none_assistant_id_no_agent_name(): + config = _build_runnable_config("thread-1", None, None, assistant_id=None) + assert "agent_name" not in config["configurable"] + + +def test_build_runnable_config_explicit_agent_name_not_overwritten(): + config = _build_runnable_config( + "thread-1", + {"configurable": {"agent_name": "explicit-agent"}}, + None, + assistant_id="other-agent", + ) + assert config["configurable"]["agent_name"] == "explicit-agent" + + +def test_resolve_agent_factory_returns_make_lead_agent(): + from deerflow.agents.lead_agent.agent import make_lead_agent + + assert resolve_agent_factory(None) is make_lead_agent + assert resolve_agent_factory("lead_agent") is make_lead_agent + assert resolve_agent_factory("finalis") is make_lead_agent + assert resolve_agent_factory("custom-agent-123") is make_lead_agent + + +def test_run_create_request_accepts_context(): + body = RunCreateRequest( + input={"messages": [{"role": "user", "content": "hi"}]}, + context={ + "model_name": "deepseek-v3", + "thinking_enabled": True, + "is_plan_mode": True, + "subagent_enabled": True, + "thread_id": "some-thread-id", + }, + ) + assert body.context is not None + assert body.context["model_name"] == "deepseek-v3" + assert body.context["is_plan_mode"] is True + assert body.context["subagent_enabled"] is True + + +def test_run_create_request_context_defaults_to_none(): + body = RunCreateRequest(input=None) + assert body.context is None + + +def test_context_merges_into_configurable(): + config = _build_runnable_config( + "thread-1", + None, + None, + context={ + "model_name": "deepseek-v3", + "mode": "ultra", + "reasoning_effort": "high", + "thinking_enabled": True, + "is_plan_mode": True, + "subagent_enabled": True, + "max_concurrent_subagents": 5, + "thread_id": "should-be-ignored", + }, + ) + assert config["configurable"]["model_name"] == "deepseek-v3" + assert config["configurable"]["thinking_enabled"] is True + assert config["configurable"]["is_plan_mode"] is True + assert config["configurable"]["subagent_enabled"] is True + assert config["configurable"]["max_concurrent_subagents"] == 5 + assert config["configurable"]["reasoning_effort"] == "high" + assert config["configurable"]["mode"] == "ultra" + assert config["configurable"]["thread_id"] == "thread-1" + + +def test_context_does_not_override_existing_configurable(): + config = _build_runnable_config( + "thread-1", + {"configurable": {"model_name": "gpt-4", "is_plan_mode": False}}, + None, + context={ + "model_name": "deepseek-v3", + "is_plan_mode": True, + "subagent_enabled": True, + }, + ) + assert config["configurable"]["model_name"] == "gpt-4" + assert config["configurable"]["is_plan_mode"] is False + assert config["configurable"]["subagent_enabled"] is True + + +def test_build_runnable_config_with_context_wrapper_in_request_config(): + config = _build_runnable_config( + "thread-1", + {"context": {"user_id": "u-42", "thread_id": "thread-1"}}, + None, + ) + assert "context" in config + assert config["context"]["user_id"] == "u-42" + assert "configurable" not in config + assert config["recursion_limit"] == 100 + + +def test_build_runnable_config_context_plus_configurable_prefers_context(): + config = _build_runnable_config( + "thread-1", + { + "context": {"user_id": "u-42"}, + "configurable": {"model_name": "gpt-4"}, + }, + None, + ) + assert "context" in config + assert config["context"]["user_id"] == "u-42" + assert "configurable" not in config + + +def test_build_runnable_config_context_passthrough_other_keys(): + config = _build_runnable_config( + "thread-1", + {"context": {"thread_id": "thread-1"}, "tags": ["prod"]}, + None, + ) + assert config["context"]["thread_id"] == "thread-1" + assert "configurable" not in config + assert config["tags"] == ["prod"] + + +def test_build_runnable_config_no_request_config(): + config = _build_runnable_config("thread-abc", None, None) + assert config["configurable"] == {"thread_id": "thread-abc"} + assert "context" not in config + + +def test_request_adapter_create_background(): + adapted = adapt_create_run_request(thread_id="thread-1", body={"input": {"x": 1}}) + assert adapted.intent == "create_background" + assert adapted.thread_id == "thread-1" + assert adapted.run_id is None + + +def test_request_adapter_create_stream(): + adapted = adapt_create_stream_request(thread_id=None, body={"input": {"x": 1}}) + assert adapted.intent == "create_and_stream" + assert adapted.thread_id is None + assert adapted.is_stateless is True + + +def test_request_adapter_create_wait(): + adapted = adapt_create_wait_request(thread_id="thread-1", body={}) + assert adapted.intent == "create_and_wait" + assert adapted.thread_id == "thread-1" + + +def test_request_adapter_join_stream(): + adapted = adapt_join_stream_request(thread_id="thread-1", run_id="run-1", headers={"Last-Event-ID": "123"}) + assert adapted.intent == "join_stream" + assert adapted.last_event_id == "123" + + +def test_request_adapter_join_wait(): + adapted = adapt_join_wait_request(thread_id="thread-1", run_id="run-1") + assert adapted.intent == "join_wait" + assert adapted.run_id == "run-1" diff --git a/backend/tests/test_guardrail_middleware.py b/backend/tests/unittest/test_guardrail_middleware.py similarity index 100% rename from backend/tests/test_guardrail_middleware.py rename to backend/tests/unittest/test_guardrail_middleware.py diff --git a/backend/tests/test_harness_boundary.py b/backend/tests/unittest/test_harness_boundary.py similarity index 100% rename from backend/tests/test_harness_boundary.py rename to backend/tests/unittest/test_harness_boundary.py diff --git a/backend/tests/test_infoquest_client.py b/backend/tests/unittest/test_infoquest_client.py similarity index 100% rename from backend/tests/test_infoquest_client.py rename to backend/tests/unittest/test_infoquest_client.py diff --git a/backend/tests/test_initialize_admin.py b/backend/tests/unittest/test_initialize_admin.py similarity index 83% rename from backend/tests/test_initialize_admin.py rename to backend/tests/unittest/test_initialize_admin.py index 17bfaf0b6..906c1d14b 100644 --- a/backend/tests/test_initialize_admin.py +++ b/backend/tests/unittest/test_initialize_admin.py @@ -5,7 +5,6 @@ initialized, password strength validation, and public accessibility (no auth cookie required). """ -import asyncio import os import pytest @@ -13,41 +12,32 @@ from fastapi.testclient import TestClient os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-initialize-admin-min-32") -from app.gateway.auth.config import AuthConfig, set_auth_config +from store.config.app_config import AppConfig, set_app_config +from store.config.storage_config import StorageConfig +from app.plugins.auth.domain.config import AuthConfig +from app.plugins.auth.runtime.config_state import set_auth_config _TEST_SECRET = "test-secret-key-initialize-admin-min-32" @pytest.fixture(autouse=True) def _setup_auth(tmp_path): - """Fresh SQLite engine + auth config per test.""" - from app.gateway import deps - from deerflow.persistence.engine import close_engine, init_engine - + """Fresh SQLite app config + auth config per test.""" set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET)) - url = f"sqlite+aiosqlite:///{tmp_path}/init_admin.db" - asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))) - deps._cached_local_provider = None - deps._cached_repo = None - try: - yield - finally: - deps._cached_local_provider = None - deps._cached_repo = None - asyncio.run(close_engine()) + set_app_config(AppConfig(storage=StorageConfig(driver="sqlite", sqlite_dir=str(tmp_path)))) + yield @pytest.fixture() def client(_setup_auth): from app.gateway.app import create_app - from app.gateway.auth.config import AuthConfig, set_auth_config + from app.plugins.auth.domain.config import AuthConfig + from app.plugins.auth.runtime.config_state import set_auth_config set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET)) app = create_app() - # Do NOT use TestClient as a context manager — that would trigger the - # full lifespan which requires config.yaml. The auth endpoints work - # without the lifespan (persistence engine is set up by _setup_auth). - yield TestClient(app) + with TestClient(app) as test_client: + yield test_client def _init_payload(**extra): @@ -110,11 +100,7 @@ def test_initialize_register_does_not_block_initialization(client): def test_initialize_accessible_without_cookie(client): """No access_token cookie needed for /initialize.""" - resp = client.post( - "/api/v1/auth/initialize", - json=_init_payload(), - cookies={}, - ) + resp = client.post("/api/v1/auth/initialize", json=_init_payload()) assert resp.status_code == 201 diff --git a/backend/tests/test_invoke_acp_agent_tool.py b/backend/tests/unittest/test_invoke_acp_agent_tool.py similarity index 99% rename from backend/tests/test_invoke_acp_agent_tool.py rename to backend/tests/unittest/test_invoke_acp_agent_tool.py index 3c5f6f0ff..13cd875fb 100644 --- a/backend/tests/test_invoke_acp_agent_tool.py +++ b/backend/tests/unittest/test_invoke_acp_agent_tool.py @@ -152,7 +152,7 @@ def test_get_work_dir_uses_base_dir_when_no_thread_id(monkeypatch, tmp_path): def test_get_work_dir_uses_per_thread_path_when_thread_id_given(monkeypatch, tmp_path): """P1.1: _get_work_dir(thread_id) uses {base_dir}/threads/{thread_id}/acp-workspace/.""" from deerflow.config import paths as paths_module - from deerflow.runtime import user_context as uc_module + from deerflow.runtime import actor_context as uc_module monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path)) monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None) @@ -312,7 +312,7 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path): async def test_invoke_acp_agent_uses_per_thread_workspace_when_thread_id_in_config(monkeypatch, tmp_path): """P1.1: When thread_id is in the RunnableConfig, ACP agent uses per-thread workspace.""" from deerflow.config import paths as paths_module - from deerflow.runtime import user_context as uc_module + from deerflow.runtime import actor_context as uc_module monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path)) monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None) diff --git a/backend/tests/test_jina_client.py b/backend/tests/unittest/test_jina_client.py similarity index 100% rename from backend/tests/test_jina_client.py rename to backend/tests/unittest/test_jina_client.py diff --git a/backend/tests/test_langgraph_auth.py b/backend/tests/unittest/test_langgraph_auth.py similarity index 69% rename from backend/tests/test_langgraph_auth.py rename to backend/tests/unittest/test_langgraph_auth.py index 52d215751..cc9e71ab9 100644 --- a/backend/tests/test_langgraph_auth.py +++ b/backend/tests/unittest/test_langgraph_auth.py @@ -9,19 +9,24 @@ import os from datetime import timedelta from pathlib import Path from types import SimpleNamespace -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock from uuid import uuid4 import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32") from langgraph_sdk import Auth -from app.gateway.auth.config import AuthConfig, set_auth_config -from app.gateway.auth.jwt import create_access_token, decode_token -from app.gateway.auth.models import User -from app.gateway.langgraph_auth import add_owner_filter, authenticate +from app.plugins.auth.domain.config import AuthConfig +from app.plugins.auth.domain.jwt import create_access_token, decode_token +from app.plugins.auth.security.langgraph import add_owner_filter, authenticate +from app.plugins.auth.domain.models import User as AuthUser +from app.plugins.auth.runtime.config_state import set_auth_config +from app.plugins.auth.storage import DbUserRepository, UserCreate +from store.persistence import MappedBase +from app.plugins.auth.storage.models import User as UserModel # noqa: F401 # ── Helpers ─────────────────────────────────────────────────────────────── @@ -40,13 +45,40 @@ def _req(cookies=None, method="GET", headers=None): def _user(user_id=None, token_version=0): - return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version) + return AuthUser(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version) -def _mock_provider(user=None): - p = AsyncMock() - p.get_user = AsyncMock(return_value=user) - return p +async def _attach_auth_session(request, tmp_path, user: AuthUser | None = None): + engine = create_async_engine( + f"sqlite+aiosqlite:///{tmp_path / 'langgraph-auth.db'}", + future=True, + ) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + session = session_factory() + if user is not None: + repo = DbUserRepository(session) + await repo.create_user( + UserCreate( + id=str(user.id), + email=user.email, + password_hash=user.password_hash, + system_role=user.system_role, + oauth_provider=user.oauth_provider, + oauth_id=user.oauth_id, + needs_setup=user.needs_setup, + token_version=user.token_version, + ) + ) + await session.commit() + request._auth_session = session + return engine, session # ── @auth.authenticate ─────────────────────────────────────────────────── @@ -73,77 +105,103 @@ def test_expired_jwt_raises_401(): assert exc.value.status_code == 401 -def test_user_not_found_raises_401(): +@pytest.mark.anyio +async def test_user_not_found_raises_401(tmp_path): token = create_access_token("ghost") - with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)): + request = _req({"access_token": token}) + engine, session = await _attach_auth_session(request, tmp_path) + try: with pytest.raises(Auth.exceptions.HTTPException) as exc: - asyncio.run(authenticate(_req({"access_token": token}))) + await authenticate(request) assert exc.value.status_code == 401 assert "User not found" in str(exc.value.detail) + finally: + await session.close() + await engine.dispose() -def test_token_version_mismatch_raises_401(): +@pytest.mark.anyio +async def test_token_version_mismatch_raises_401(tmp_path): user = _user(token_version=2) token = create_access_token(str(user.id), token_version=1) - with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)): + request = _req({"access_token": token}) + engine, session = await _attach_auth_session(request, tmp_path, user) + try: with pytest.raises(Auth.exceptions.HTTPException) as exc: - asyncio.run(authenticate(_req({"access_token": token}))) + await authenticate(request) assert exc.value.status_code == 401 assert "revoked" in str(exc.value.detail).lower() + finally: + await session.close() + await engine.dispose() -def test_valid_token_returns_user_id(): +@pytest.mark.anyio +async def test_valid_token_returns_user_id(tmp_path): user = _user(token_version=0) token = create_access_token(str(user.id), token_version=0) - with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)): - result = asyncio.run(authenticate(_req({"access_token": token}))) + request = _req({"access_token": token}) + engine, session = await _attach_auth_session(request, tmp_path, user) + try: + result = await authenticate(request) + finally: + await session.close() + await engine.dispose() assert result == str(user.id) -def test_valid_token_matching_version(): +@pytest.mark.anyio +async def test_valid_token_matching_version(tmp_path): user = _user(token_version=5) token = create_access_token(str(user.id), token_version=5) - with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)): - result = asyncio.run(authenticate(_req({"access_token": token}))) + request = _req({"access_token": token}) + engine, session = await _attach_auth_session(request, tmp_path, user) + try: + result = await authenticate(request) + finally: + await session.close() + await engine.dispose() assert result == str(user.id) # ── @auth.authenticate edge cases ──────────────────────────────────────── -def test_provider_exception_propagates(): - """Provider raises → should not be swallowed silently.""" - token = create_access_token("user-1") - p = AsyncMock() - p.get_user = AsyncMock(side_effect=RuntimeError("DB down")) - with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p): - with pytest.raises(RuntimeError, match="DB down"): - asyncio.run(authenticate(_req({"access_token": token}))) - - -def test_jwt_missing_ver_defaults_to_zero(): +@pytest.mark.anyio +async def test_jwt_missing_ver_defaults_to_zero(tmp_path): """JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0.""" import jwt as pyjwt uid = str(uuid4()) raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256") user = _user(user_id=uid, token_version=0) - with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)): - result = asyncio.run(authenticate(_req({"access_token": raw}))) + request = _req({"access_token": raw}) + engine, session = await _attach_auth_session(request, tmp_path, user) + try: + result = await authenticate(request) + finally: + await session.close() + await engine.dispose() assert result == uid -def test_jwt_missing_ver_rejected_when_user_version_nonzero(): +@pytest.mark.anyio +async def test_jwt_missing_ver_rejected_when_user_version_nonzero(tmp_path): """JWT without 'ver' (defaults 0) vs user with token_version=1 → 401.""" import jwt as pyjwt uid = str(uuid4()) raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256") user = _user(user_id=uid, token_version=1) - with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)): + request = _req({"access_token": raw}) + engine, session = await _attach_auth_session(request, tmp_path, user) + try: with pytest.raises(Auth.exceptions.HTTPException) as exc: - asyncio.run(authenticate(_req({"access_token": raw}))) + await authenticate(request) assert exc.value.status_code == 401 + finally: + await session.close() + await engine.dispose() def test_wrong_secret_raises_401(): @@ -223,7 +281,7 @@ def test_filter_with_empty_metadata(): def test_shared_jwt_secret(): token = create_access_token("user-1", token_version=3) payload = decode_token(token) - from app.gateway.auth.errors import TokenError + from app.plugins.auth.domain.errors import TokenError assert not isinstance(payload, TokenError) assert payload.sub == "user-1" @@ -233,13 +291,13 @@ def test_shared_jwt_secret(): def test_langgraph_json_has_auth_path(): import json - config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text()) - assert "auth" in config - assert "langgraph_auth" in config["auth"]["path"] + config = json.loads((Path(__file__).resolve().parents[2] / "langgraph.json").read_text()) + assert "graphs" in config + assert "lead_agent" in config["graphs"] def test_auth_handler_has_both_layers(): - from app.gateway.langgraph_auth import auth + from app.plugins.auth.security.langgraph import auth assert auth._authenticate_handler is not None assert len(auth._global_handlers) == 1 diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/unittest/test_lead_agent_model_resolution.py similarity index 100% rename from backend/tests/test_lead_agent_model_resolution.py rename to backend/tests/unittest/test_lead_agent_model_resolution.py diff --git a/backend/tests/test_lead_agent_prompt.py b/backend/tests/unittest/test_lead_agent_prompt.py similarity index 100% rename from backend/tests/test_lead_agent_prompt.py rename to backend/tests/unittest/test_lead_agent_prompt.py diff --git a/backend/tests/test_lead_agent_skills.py b/backend/tests/unittest/test_lead_agent_skills.py similarity index 100% rename from backend/tests/test_lead_agent_skills.py rename to backend/tests/unittest/test_lead_agent_skills.py diff --git a/backend/tests/test_llm_error_handling_middleware.py b/backend/tests/unittest/test_llm_error_handling_middleware.py similarity index 100% rename from backend/tests/test_llm_error_handling_middleware.py rename to backend/tests/unittest/test_llm_error_handling_middleware.py diff --git a/backend/tests/test_local_bash_tool_loading.py b/backend/tests/unittest/test_local_bash_tool_loading.py similarity index 100% rename from backend/tests/test_local_bash_tool_loading.py rename to backend/tests/unittest/test_local_bash_tool_loading.py diff --git a/backend/tests/test_local_sandbox_encoding.py b/backend/tests/unittest/test_local_sandbox_encoding.py similarity index 100% rename from backend/tests/test_local_sandbox_encoding.py rename to backend/tests/unittest/test_local_sandbox_encoding.py diff --git a/backend/tests/test_local_sandbox_provider_mounts.py b/backend/tests/unittest/test_local_sandbox_provider_mounts.py similarity index 100% rename from backend/tests/test_local_sandbox_provider_mounts.py rename to backend/tests/unittest/test_local_sandbox_provider_mounts.py diff --git a/backend/tests/test_loop_detection_middleware.py b/backend/tests/unittest/test_loop_detection_middleware.py similarity index 100% rename from backend/tests/test_loop_detection_middleware.py rename to backend/tests/unittest/test_loop_detection_middleware.py diff --git a/backend/tests/test_mcp_client_config.py b/backend/tests/unittest/test_mcp_client_config.py similarity index 100% rename from backend/tests/test_mcp_client_config.py rename to backend/tests/unittest/test_mcp_client_config.py diff --git a/backend/tests/test_mcp_oauth.py b/backend/tests/unittest/test_mcp_oauth.py similarity index 100% rename from backend/tests/test_mcp_oauth.py rename to backend/tests/unittest/test_mcp_oauth.py diff --git a/backend/tests/test_mcp_sync_wrapper.py b/backend/tests/unittest/test_mcp_sync_wrapper.py similarity index 100% rename from backend/tests/test_mcp_sync_wrapper.py rename to backend/tests/unittest/test_mcp_sync_wrapper.py diff --git a/backend/tests/test_memory_prompt_injection.py b/backend/tests/unittest/test_memory_prompt_injection.py similarity index 100% rename from backend/tests/test_memory_prompt_injection.py rename to backend/tests/unittest/test_memory_prompt_injection.py diff --git a/backend/tests/test_memory_queue.py b/backend/tests/unittest/test_memory_queue.py similarity index 100% rename from backend/tests/test_memory_queue.py rename to backend/tests/unittest/test_memory_queue.py diff --git a/backend/tests/test_memory_queue_user_isolation.py b/backend/tests/unittest/test_memory_queue_user_isolation.py similarity index 100% rename from backend/tests/test_memory_queue_user_isolation.py rename to backend/tests/unittest/test_memory_queue_user_isolation.py diff --git a/backend/tests/test_memory_router.py b/backend/tests/unittest/test_memory_router.py similarity index 100% rename from backend/tests/test_memory_router.py rename to backend/tests/unittest/test_memory_router.py diff --git a/backend/tests/test_memory_storage.py b/backend/tests/unittest/test_memory_storage.py similarity index 100% rename from backend/tests/test_memory_storage.py rename to backend/tests/unittest/test_memory_storage.py diff --git a/backend/tests/test_memory_storage_user_isolation.py b/backend/tests/unittest/test_memory_storage_user_isolation.py similarity index 100% rename from backend/tests/test_memory_storage_user_isolation.py rename to backend/tests/unittest/test_memory_storage_user_isolation.py diff --git a/backend/tests/test_memory_updater.py b/backend/tests/unittest/test_memory_updater.py similarity index 100% rename from backend/tests/test_memory_updater.py rename to backend/tests/unittest/test_memory_updater.py diff --git a/backend/tests/test_memory_updater_user_isolation.py b/backend/tests/unittest/test_memory_updater_user_isolation.py similarity index 100% rename from backend/tests/test_memory_updater_user_isolation.py rename to backend/tests/unittest/test_memory_updater_user_isolation.py diff --git a/backend/tests/test_memory_upload_filtering.py b/backend/tests/unittest/test_memory_upload_filtering.py similarity index 100% rename from backend/tests/test_memory_upload_filtering.py rename to backend/tests/unittest/test_memory_upload_filtering.py diff --git a/backend/tests/test_migration_user_isolation.py b/backend/tests/unittest/test_migration_user_isolation.py similarity index 100% rename from backend/tests/test_migration_user_isolation.py rename to backend/tests/unittest/test_migration_user_isolation.py diff --git a/backend/tests/test_model_config.py b/backend/tests/unittest/test_model_config.py similarity index 100% rename from backend/tests/test_model_config.py rename to backend/tests/unittest/test_model_config.py diff --git a/backend/tests/test_model_factory.py b/backend/tests/unittest/test_model_factory.py similarity index 100% rename from backend/tests/test_model_factory.py rename to backend/tests/unittest/test_model_factory.py diff --git a/backend/tests/unittest/test_owner_isolation.py b/backend/tests/unittest/test_owner_isolation.py new file mode 100644 index 000000000..2d83a913e --- /dev/null +++ b/backend/tests/unittest/test_owner_isolation.py @@ -0,0 +1,236 @@ +"""Cross-user isolation tests for current app-owned storage adapters. + +These tests exercise isolation by binding different ``ActorContext`` +values around the app-layer storage adapters. The safety property is: + + data written under user A is not visible to user B through the same + adapter surface unless a call explicitly opts out with ``user_id=None``. +""" + +from __future__ import annotations + +from contextlib import contextmanager + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.infra.storage import AppRunEventStore, FeedbackStoreAdapter, RunStoreAdapter, ThreadMetaStorage, ThreadMetaStoreAdapter +from deerflow.runtime.actor_context import AUTO, ActorContext, bind_actor_context, reset_actor_context +from store.persistence import MappedBase + + +USER_A = "user-a" +USER_B = "user-b" + + +async def _make_components(tmp_path): + engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'isolation.db'}", future=True) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + + thread_store = ThreadMetaStorage(ThreadMetaStoreAdapter(session_factory)) + return ( + engine, + thread_store, + RunStoreAdapter(session_factory), + FeedbackStoreAdapter(session_factory), + AppRunEventStore(session_factory), + ) + + +@contextmanager +def _as_user(user_id: str): + token = bind_actor_context(ActorContext(user_id=user_id)) + try: + yield + finally: + reset_actor_context(token) + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_thread_meta_cross_user_isolation(tmp_path): + engine, thread_store, _, _, _ = await _make_components(tmp_path) + try: + with _as_user(USER_A): + await thread_store.ensure_thread(thread_id="t-alpha") + with _as_user(USER_B): + await thread_store.ensure_thread(thread_id="t-beta") + + with _as_user(USER_A): + assert (await thread_store.get_thread("t-alpha")) is not None + assert await thread_store.get_thread("t-beta") is None + rows = await thread_store.search_threads() + assert [row.thread_id for row in rows] == ["t-alpha"] + + with _as_user(USER_B): + assert (await thread_store.get_thread("t-beta")) is not None + assert await thread_store.get_thread("t-alpha") is None + rows = await thread_store.search_threads() + assert [row.thread_id for row in rows] == ["t-beta"] + finally: + await engine.dispose() + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_runs_cross_user_isolation(tmp_path): + engine, thread_store, run_store, _, _ = await _make_components(tmp_path) + try: + with _as_user(USER_A): + await thread_store.ensure_thread(thread_id="t-alpha") + await run_store.create("run-a1", "t-alpha") + await run_store.create("run-a2", "t-alpha") + + with _as_user(USER_B): + await thread_store.ensure_thread(thread_id="t-beta") + await run_store.create("run-b1", "t-beta") + + with _as_user(USER_A): + assert (await run_store.get("run-a1")) is not None + assert await run_store.get("run-b1") is None + rows = await run_store.list_by_thread("t-alpha") + assert {row["run_id"] for row in rows} == {"run-a1", "run-a2"} + assert await run_store.list_by_thread("t-beta") == [] + + with _as_user(USER_B): + assert await run_store.get("run-a1") is None + rows = await run_store.list_by_thread("t-beta") + assert [row["run_id"] for row in rows] == ["run-b1"] + finally: + await engine.dispose() + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_run_events_cross_user_isolation(tmp_path): + engine, thread_store, _, _, event_store = await _make_components(tmp_path) + try: + with _as_user(USER_A): + await thread_store.ensure_thread(thread_id="t-alpha") + await event_store.put_batch( + [ + { + "thread_id": "t-alpha", + "run_id": "run-a1", + "event_type": "human_message", + "category": "message", + "content": "User A private question", + }, + { + "thread_id": "t-alpha", + "run_id": "run-a1", + "event_type": "ai_message", + "category": "message", + "content": "User A private answer", + }, + ] + ) + + with _as_user(USER_B): + await thread_store.ensure_thread(thread_id="t-beta") + await event_store.put_batch( + [ + { + "thread_id": "t-beta", + "run_id": "run-b1", + "event_type": "human_message", + "category": "message", + "content": "User B private question", + } + ] + ) + + with _as_user(USER_A): + msgs = await event_store.list_messages("t-alpha") + contents = [msg["content"] for msg in msgs] + assert "User A private question" in contents + assert "User A private answer" in contents + assert "User B private question" not in contents + assert await event_store.list_messages("t-beta") == [] + assert await event_store.list_events("t-beta", "run-b1") == [] + assert await event_store.count_messages("t-beta") == 0 + + with _as_user(USER_B): + msgs = await event_store.list_messages("t-beta") + contents = [msg["content"] for msg in msgs] + assert "User B private question" in contents + assert "User A private question" not in contents + assert await event_store.count_messages("t-alpha") == 0 + finally: + await engine.dispose() + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_feedback_cross_user_isolation(tmp_path): + engine, thread_store, _, feedback_store, _ = await _make_components(tmp_path) + try: + with _as_user(USER_A): + await thread_store.ensure_thread(thread_id="t-alpha") + a_feedback = await feedback_store.create( + run_id="run-a1", + thread_id="t-alpha", + rating=1, + user_id=USER_A, + comment="A liked this", + ) + + with _as_user(USER_B): + await thread_store.ensure_thread(thread_id="t-beta") + b_feedback = await feedback_store.create( + run_id="run-b1", + thread_id="t-beta", + rating=-1, + user_id=USER_B, + comment="B disliked this", + ) + + with _as_user(USER_A): + assert (await feedback_store.get(a_feedback["feedback_id"])) is not None + assert await feedback_store.get(b_feedback["feedback_id"]) is not None + assert await feedback_store.list_by_run("t-beta", "run-b1", user_id=USER_A) == [] + + with _as_user(USER_B): + assert await feedback_store.list_by_run("t-alpha", "run-a1", user_id=USER_B) == [] + rows = await feedback_store.list_by_run("t-beta", "run-b1", user_id=USER_B) + assert len(rows) == 1 + assert rows[0]["comment"] == "B disliked this" + finally: + await engine.dispose() + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_repository_without_context_raises(tmp_path): + engine, thread_store, _, _, _ = await _make_components(tmp_path) + try: + with pytest.raises(RuntimeError, match="no actor context is set"): + await thread_store.search_threads(user_id=AUTO) + finally: + await engine.dispose() + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_explicit_none_bypasses_filter(tmp_path): + engine, thread_store, _, _, _ = await _make_components(tmp_path) + try: + with _as_user(USER_A): + await thread_store.ensure_thread(thread_id="t-alpha") + with _as_user(USER_B): + await thread_store.ensure_thread(thread_id="t-beta") + + rows = await thread_store.search_threads(user_id=None) + assert {row.thread_id for row in rows} == {"t-alpha", "t-beta"} + assert await thread_store.get_thread("t-alpha", user_id=None) is not None + assert await thread_store.get_thread("t-beta", user_id=None) is not None + finally: + await engine.dispose() diff --git a/backend/tests/test_patched_deepseek.py b/backend/tests/unittest/test_patched_deepseek.py similarity index 100% rename from backend/tests/test_patched_deepseek.py rename to backend/tests/unittest/test_patched_deepseek.py diff --git a/backend/tests/test_patched_minimax.py b/backend/tests/unittest/test_patched_minimax.py similarity index 100% rename from backend/tests/test_patched_minimax.py rename to backend/tests/unittest/test_patched_minimax.py diff --git a/backend/tests/test_patched_openai.py b/backend/tests/unittest/test_patched_openai.py similarity index 100% rename from backend/tests/test_patched_openai.py rename to backend/tests/unittest/test_patched_openai.py diff --git a/backend/tests/test_paths_user_isolation.py b/backend/tests/unittest/test_paths_user_isolation.py similarity index 100% rename from backend/tests/test_paths_user_isolation.py rename to backend/tests/unittest/test_paths_user_isolation.py diff --git a/backend/tests/test_present_file_tool_core_logic.py b/backend/tests/unittest/test_present_file_tool_core_logic.py similarity index 100% rename from backend/tests/test_present_file_tool_core_logic.py rename to backend/tests/unittest/test_present_file_tool_core_logic.py diff --git a/backend/tests/test_provisioner_kubeconfig.py b/backend/tests/unittest/test_provisioner_kubeconfig.py similarity index 100% rename from backend/tests/test_provisioner_kubeconfig.py rename to backend/tests/unittest/test_provisioner_kubeconfig.py diff --git a/backend/tests/test_provisioner_pvc_volumes.py b/backend/tests/unittest/test_provisioner_pvc_volumes.py similarity index 100% rename from backend/tests/test_provisioner_pvc_volumes.py rename to backend/tests/unittest/test_provisioner_pvc_volumes.py diff --git a/backend/tests/test_readability.py b/backend/tests/unittest/test_readability.py similarity index 100% rename from backend/tests/test_readability.py rename to backend/tests/unittest/test_readability.py diff --git a/backend/tests/test_reflection_resolvers.py b/backend/tests/unittest/test_reflection_resolvers.py similarity index 100% rename from backend/tests/test_reflection_resolvers.py rename to backend/tests/unittest/test_reflection_resolvers.py diff --git a/backend/tests/unittest/test_run_callbacks_builder.py b/backend/tests/unittest/test_run_callbacks_builder.py new file mode 100644 index 000000000..f6efb62d1 --- /dev/null +++ b/backend/tests/unittest/test_run_callbacks_builder.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from langchain_core.messages import HumanMessage + +from deerflow.runtime.runs.callbacks.builder import build_run_callbacks +from deerflow.runtime.runs.types import RunRecord, RunStatus + + +def _record() -> RunRecord: + return RunRecord( + run_id="run-1", + thread_id="thread-1", + assistant_id=None, + status=RunStatus.pending, + temporary=False, + multitask_strategy="reject", + metadata={}, + created_at="", + updated_at="", + ) + + +def test_build_run_callbacks_sets_first_human_message_from_string_content(): + artifacts = build_run_callbacks( + record=_record(), + graph_input={"messages": [HumanMessage(content="hello world")]}, + event_store=None, + ) + + assert artifacts.completion_data().first_human_message == "hello world" + + +def test_build_run_callbacks_sets_first_human_message_from_content_blocks(): + artifacts = build_run_callbacks( + record=_record(), + graph_input={ + "messages": [ + HumanMessage( + content=[ + {"type": "text", "text": "hello "}, + {"type": "text", "text": "world"}, + ] + ) + ] + }, + event_store=None, + ) + + assert artifacts.completion_data().first_human_message == "hello world" + + +def test_build_run_callbacks_sets_first_human_message_from_dict_payload(): + artifacts = build_run_callbacks( + record=_record(), + graph_input={ + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "hi from dict"}], + } + ] + }, + event_store=None, + ) + + assert artifacts.completion_data().first_human_message == "hi from dict" diff --git a/backend/tests/unittest/test_run_create_store.py b/backend/tests/unittest/test_run_create_store.py new file mode 100644 index 000000000..38dbd824e --- /dev/null +++ b/backend/tests/unittest/test_run_create_store.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from app.gateway.services.runs.store.create_store import AppRunCreateStore +from deerflow.runtime.runs.types import RunRecord, RunStatus + + +@pytest.mark.anyio +async def test_create_run_syncs_thread_meta_assistant_id(): + repo = AsyncMock() + thread_meta_storage = AsyncMock() + thread_meta_storage.ensure_thread.return_value.assistant_id = None + + store = AppRunCreateStore(repo, thread_meta_storage=thread_meta_storage) + record = RunRecord( + run_id="run-1", + thread_id="thread-1", + assistant_id="lead_agent", + status=RunStatus.pending, + temporary=False, + multitask_strategy="reject", + ) + + await store.create_run(record) + + repo.create.assert_awaited_once() + thread_meta_storage.ensure_thread.assert_awaited_once_with( + thread_id="thread-1", + assistant_id="lead_agent", + ) + thread_meta_storage.sync_thread_assistant_id.assert_awaited_once_with( + thread_id="thread-1", + assistant_id="lead_agent", + ) diff --git a/backend/tests/unittest/test_run_event_store.py b/backend/tests/unittest/test_run_event_store.py new file mode 100644 index 000000000..73464ebb5 --- /dev/null +++ b/backend/tests/unittest/test_run_event_store.py @@ -0,0 +1,275 @@ +"""Tests for current run event store backends.""" + +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.infra.run_events import JsonlRunEventStore, build_run_event_store +from app.infra.storage import AppRunEventStore, ThreadMetaStorage, ThreadMetaStoreAdapter +from deerflow.runtime.actor_context import ActorContext, bind_actor_context, reset_actor_context +from store.persistence import MappedBase + + +@pytest.fixture +def jsonl_store(tmp_path): + return JsonlRunEventStore(base_dir=tmp_path / "jsonl") + + +async def _make_db_store(tmp_path): + engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}", future=True) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + thread_store = ThreadMetaStorage(ThreadMetaStoreAdapter(session_factory)) + return engine, thread_store, AppRunEventStore(session_factory), session_factory + + +class _RunEventStoreContract: + async def _exercise_basic_contract(self, store): + first = await store.put_batch( + [ + {"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message", "content": "a"}, + {"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message", "content": "b"}, + {"thread_id": "t1", "run_id": "r1", "event_type": "llm_end", "category": "trace", "metadata": {"m": 1}}, + ] + ) + assert [row["seq"] for row in first] == [1, 2, 3] + + messages = await store.list_messages("t1") + assert [row["seq"] for row in messages] == [1, 2] + assert messages[0]["content"] == "a" + + events = await store.list_events("t1", "r1") + assert len(events) == 3 + + by_run = await store.list_messages_by_run("t1", "r1") + assert [row["seq"] for row in by_run] == [1, 2] + assert await store.count_messages("t1") == 2 + + deleted = await store.delete_by_run("t1", "r1") + assert deleted == 3 + assert await store.list_messages("t1") == [] + + +class TestJsonlRunEventStore(_RunEventStoreContract): + @pytest.mark.anyio + async def test_basic_contract(self, jsonl_store): + await self._exercise_basic_contract(jsonl_store) + + @pytest.mark.anyio + async def test_file_at_correct_path(self, tmp_path): + store = JsonlRunEventStore(base_dir=tmp_path / "jsonl") + await store.put_batch( + [{"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message"}] + ) + assert (tmp_path / "jsonl" / "threads" / "t1" / "events.jsonl").exists() + + +class TestAppRunEventStore(_RunEventStoreContract): + @pytest.mark.anyio + async def test_basic_contract(self, tmp_path): + engine, thread_store, store, _ = await _make_db_store(tmp_path) + try: + await thread_store.ensure_thread(thread_id="t1", user_id=None) + await self._exercise_basic_contract(store) + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_actor_isolation_by_thread_owner(self, tmp_path): + engine, thread_store, store, _ = await _make_db_store(tmp_path) + try: + token = bind_actor_context(ActorContext(user_id="user-a")) + try: + await thread_store.ensure_thread(thread_id="t-alpha") + await store.put_batch( + [ + { + "thread_id": "t-alpha", + "run_id": "run-a1", + "event_type": "human_message", + "category": "message", + "content": "private-a", + } + ] + ) + finally: + reset_actor_context(token) + + token = bind_actor_context(ActorContext(user_id="user-b")) + try: + await thread_store.ensure_thread(thread_id="t-beta") + await store.put_batch( + [ + { + "thread_id": "t-beta", + "run_id": "run-b1", + "event_type": "human_message", + "category": "message", + "content": "private-b", + } + ] + ) + assert await store.list_messages("t-alpha") == [] + assert await store.list_events("t-alpha", "run-a1") == [] + assert await store.count_messages("t-alpha") == 0 + assert await store.delete_by_thread("t-alpha") == 0 + finally: + reset_actor_context(token) + + token = bind_actor_context(ActorContext(user_id="user-a")) + try: + rows = await store.list_messages("t-alpha") + assert [row["content"] for row in rows] == ["private-a"] + finally: + reset_actor_context(token) + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_put_batch_preserves_structured_content_metadata_and_created_at(self, tmp_path): + engine, thread_store, store, _ = await _make_db_store(tmp_path) + try: + await thread_store.ensure_thread(thread_id="t1", user_id=None) + created_at = datetime(2026, 4, 20, 8, 30, tzinfo=UTC) + rows = await store.put_batch( + [ + { + "thread_id": "t1", + "run_id": "r1", + "event_type": "tool_end", + "category": "trace", + "content": {"type": "tool", "content": "ok"}, + "metadata": {"tool": "search"}, + "created_at": created_at.isoformat(), + } + ] + ) + + assert rows[0]["content"] == {"type": "tool", "content": "ok"} + assert rows[0]["metadata"]["tool"] == "search" + assert "content_is_dict" not in rows[0]["metadata"] + assert rows[0]["created_at"] == created_at.isoformat() + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_list_messages_supports_before_and_after_pagination(self, tmp_path): + engine, thread_store, store, _ = await _make_db_store(tmp_path) + try: + await thread_store.ensure_thread(thread_id="t1", user_id=None) + await store.put_batch( + [ + { + "thread_id": "t1", + "run_id": "r1", + "event_type": "human_message", + "category": "message", + "content": str(i), + } + for i in range(10) + ] + ) + + before = await store.list_messages("t1", before_seq=6, limit=3) + after = await store.list_messages("t1", after_seq=7, limit=3) + + assert [message["seq"] for message in before] == [3, 4, 5] + assert [message["seq"] for message in after] == [8, 9, 10] + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_list_events_filters_by_run_and_event_type(self, tmp_path): + engine, thread_store, store, _ = await _make_db_store(tmp_path) + try: + await thread_store.ensure_thread(thread_id="t1", user_id=None) + await store.put_batch( + [ + {"thread_id": "t1", "run_id": "r1", "event_type": "llm_start", "category": "trace"}, + {"thread_id": "t1", "run_id": "r1", "event_type": "llm_end", "category": "trace"}, + {"thread_id": "t1", "run_id": "r2", "event_type": "llm_end", "category": "trace"}, + ] + ) + + events = await store.list_events("t1", "r1", event_types=["llm_end"]) + assert len(events) == 1 + assert events[0]["run_id"] == "r1" + assert events[0]["event_type"] == "llm_end" + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_put_batch_denies_write_to_other_users_thread(self, tmp_path): + engine, thread_store, store, _ = await _make_db_store(tmp_path) + try: + token = bind_actor_context(ActorContext(user_id="user-a")) + try: + await thread_store.ensure_thread(thread_id="t-alpha") + finally: + reset_actor_context(token) + + token = bind_actor_context(ActorContext(user_id="user-b")) + try: + with pytest.raises(PermissionError, match="not allowed to append events"): + await store.put_batch( + [ + { + "thread_id": "t-alpha", + "run_id": "run-a1", + "event_type": "human_message", + "category": "message", + "content": "forbidden", + } + ] + ) + finally: + reset_actor_context(token) + finally: + await engine.dispose() + + +class TestBuildRunEventStore: + @pytest.mark.anyio + async def test_db_backend(self, tmp_path, monkeypatch): + from types import SimpleNamespace + + engine, _, _, session_factory = await _make_db_store(tmp_path) + try: + monkeypatch.setattr( + "app.infra.run_events.factory.get_app_config", + lambda: SimpleNamespace(run_events=SimpleNamespace(backend="db", jsonl_base_dir="", max_trace_content=0)), + ) + store = build_run_event_store(session_factory) + assert isinstance(store, AppRunEventStore) + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_jsonl_backend(self, tmp_path, monkeypatch): + from types import SimpleNamespace + + engine, _, _, session_factory = await _make_db_store(tmp_path) + try: + monkeypatch.setattr( + "app.infra.run_events.factory.get_app_config", + lambda: SimpleNamespace( + run_events=SimpleNamespace( + backend="jsonl", + jsonl_base_dir=str(tmp_path / "jsonl"), + max_trace_content=0, + ) + ), + ) + store = build_run_event_store(session_factory) + assert isinstance(store, JsonlRunEventStore) + finally: + await engine.dispose() diff --git a/backend/tests/unittest/test_run_execution_artifacts.py b/backend/tests/unittest/test_run_execution_artifacts.py new file mode 100644 index 000000000..74b5d9530 --- /dev/null +++ b/backend/tests/unittest/test_run_execution_artifacts.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from deerflow.runtime.runs.internal.execution.artifacts import build_run_artifacts + + +class _Agent: + pass + + +def test_build_run_artifacts_uses_store_as_reference_store(): + store = object() + + def agent_factory(*, config): + return _Agent() + + artifacts = build_run_artifacts( + thread_id="thread-1", + run_id="run-1", + checkpointer=None, + store=store, + agent_factory=agent_factory, + config={}, + bridge=None, # type: ignore[arg-type] + ) + + assert artifacts.reference_store is store diff --git a/backend/tests/test_run_manager.py b/backend/tests/unittest/test_run_manager.py similarity index 91% rename from backend/tests/test_run_manager.py rename to backend/tests/unittest/test_run_manager.py index 58ecf1f26..c53cebd7b 100644 --- a/backend/tests/test_run_manager.py +++ b/backend/tests/unittest/test_run_manager.py @@ -75,27 +75,27 @@ async def test_cancel_not_inflight(manager: RunManager): @pytest.mark.anyio async def test_list_by_thread(manager: RunManager): - """Same thread should return multiple runs.""" + """Same thread should return multiple runs newest first.""" r1 = await manager.create("thread-1") r2 = await manager.create("thread-1") await manager.create("thread-2") runs = await manager.list_by_thread("thread-1") assert len(runs) == 2 - assert runs[0].run_id == r1.run_id - assert runs[1].run_id == r2.run_id + assert runs[0].run_id == r2.run_id + assert runs[1].run_id == r1.run_id @pytest.mark.anyio async def test_list_by_thread_is_stable_when_timestamps_tie(manager: RunManager, monkeypatch: pytest.MonkeyPatch): - """Ordering should be stable (insertion order) even when timestamps tie.""" - monkeypatch.setattr("deerflow.runtime.runs.manager._now_iso", lambda: "2026-01-01T00:00:00+00:00") + """Ordering should be stable newest-first even when timestamps tie.""" + monkeypatch.setattr("deerflow.runtime.runs.internal.manager._now_iso", lambda: "2026-01-01T00:00:00+00:00") r1 = await manager.create("thread-1") r2 = await manager.create("thread-1") runs = await manager.list_by_thread("thread-1") - assert [run.run_id for run in runs] == [r1.run_id, r2.run_id] + assert [run.run_id for run in runs] == [r2.run_id, r1.run_id] @pytest.mark.anyio diff --git a/backend/tests/unittest/test_run_repository.py b/backend/tests/unittest/test_run_repository.py new file mode 100644 index 000000000..c21fed770 --- /dev/null +++ b/backend/tests/unittest/test_run_repository.py @@ -0,0 +1,267 @@ +"""Tests for RunStoreAdapter (current SQLAlchemy-backed run store).""" + +from __future__ import annotations + +from contextlib import contextmanager + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.infra.storage import RunStoreAdapter +from deerflow.runtime.actor_context import ActorContext, bind_actor_context, reset_actor_context +from store.persistence import MappedBase + + +async def _make_repo(tmp_path): + engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}", future=True) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + return engine, RunStoreAdapter(session_factory) + + +@contextmanager +def _as_user(user_id: str): + token = bind_actor_context(ActorContext(user_id=user_id)) + try: + yield + finally: + reset_actor_context(token) + + +class TestRunStoreAdapter: + @pytest.mark.anyio + async def test_create_and_get(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", status="pending", user_id=None) + row = await repo.get("r1", user_id=None) + assert row is not None + assert row["run_id"] == "r1" + assert row["thread_id"] == "t1" + assert row["status"] == "pending" + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_get_missing_returns_none(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + assert await repo.get("nope", user_id=None) is None + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_update_status(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", user_id=None) + await repo.update_status("r1", "running") + row = await repo.get("r1", user_id=None) + assert row is not None + assert row["status"] == "running" + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_set_error(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", user_id=None) + await repo.set_error("r1", "boom") + row = await repo.get("r1", user_id=None) + assert row is not None + assert row["status"] == "error" + assert row["error"] == "boom" + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_list_by_thread(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", user_id=None) + await repo.create("r2", "t1", user_id=None) + await repo.create("r3", "t2", user_id=None) + rows = await repo.list_by_thread("t1", user_id=None) + assert len(rows) == 2 + assert all(r["thread_id"] == "t1" for r in rows) + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_list_by_thread_owner_filter(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", user_id="alice") + await repo.create("r2", "t1", user_id="bob") + rows = await repo.list_by_thread("t1", user_id="alice") + assert len(rows) == 1 + assert rows[0]["user_id"] == "alice" + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_delete(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", user_id=None) + assert await repo.delete("r1", user_id=None) is True + assert await repo.get("r1", user_id=None) is None + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_delete_nonexistent_is_false(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + assert await repo.delete("nope", user_id=None) is False + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_update_run_completion(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", status="running", user_id=None) + await repo.update_run_completion( + "r1", + status="success", + total_input_tokens=100, + total_output_tokens=50, + total_tokens=150, + llm_call_count=2, + lead_agent_tokens=120, + subagent_tokens=20, + middleware_tokens=10, + message_count=3, + last_ai_message="The answer is 42", + first_human_message="What is the meaning?", + ) + row = await repo.get("r1", user_id=None) + assert row is not None + assert row["status"] == "success" + assert row["total_tokens"] == 150 + assert row["llm_call_count"] == 2 + assert row["lead_agent_tokens"] == 120 + assert row["message_count"] == 3 + assert row["last_ai_message"] == "The answer is 42" + assert row["first_human_message"] == "What is the meaning?" + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_metadata_preserved(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", user_id=None, metadata={"key": "value"}) + row = await repo.get("r1", user_id=None) + assert row is not None + assert row["metadata"] == {"key": "value"} + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_kwargs_with_non_serializable(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + + class Dummy: + pass + + try: + await repo.create("r1", "t1", user_id=None, kwargs={"obj": Dummy()}) + row = await repo.get("r1", user_id=None) + assert row is not None + assert "obj" in row["kwargs"] + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_update_run_completion_preserves_existing_fields(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", assistant_id="agent1", status="running", user_id=None) + await repo.update_run_completion("r1", status="success", total_tokens=100) + row = await repo.get("r1", user_id=None) + assert row is not None + assert row["thread_id"] == "t1" + assert row["assistant_id"] == "agent1" + assert row["total_tokens"] == 100 + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_list_by_thread_limit(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + for i in range(5): + await repo.create(f"r{i}", "t1", user_id=None) + rows = await repo.list_by_thread("t1", limit=2, user_id=None) + assert len(rows) == 2 + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_owner_none_returns_all(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", user_id="alice") + await repo.create("r2", "t1", user_id="bob") + rows = await repo.list_by_thread("t1", user_id=None) + assert len(rows) == 2 + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_create_uses_actor_context_by_default(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + with _as_user("alice"): + await repo.create("r1", "t1") + row = await repo.get("r1") + assert row is not None + assert row["user_id"] == "alice" + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_get_with_auto_filters_by_actor(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", user_id="alice") + await repo.create("r2", "t1", user_id="bob") + with _as_user("alice"): + assert await repo.get("r1") is not None + assert await repo.get("r2") is None + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_delete_with_wrong_actor_returns_false(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", user_id="alice") + with _as_user("bob"): + assert await repo.delete("r1") is False + assert await repo.get("r1", user_id=None) is not None + finally: + await engine.dispose() + + @pytest.mark.anyio + @pytest.mark.no_auto_user + async def test_auto_user_id_requires_actor_context(self, tmp_path): + engine, repo = await _make_repo(tmp_path) + try: + await repo.create("r1", "t1", user_id="alice") + await repo.create("r2", "t1", user_id="bob") + with pytest.raises(RuntimeError, match="no actor context is set"): + await repo.list_by_thread("t1") + with pytest.raises(RuntimeError, match="no actor context is set"): + await repo.delete("r1") + finally: + await engine.dispose() diff --git a/backend/tests/test_sandbox_audit_middleware.py b/backend/tests/unittest/test_sandbox_audit_middleware.py similarity index 100% rename from backend/tests/test_sandbox_audit_middleware.py rename to backend/tests/unittest/test_sandbox_audit_middleware.py diff --git a/backend/tests/test_sandbox_orphan_reconciliation.py b/backend/tests/unittest/test_sandbox_orphan_reconciliation.py similarity index 100% rename from backend/tests/test_sandbox_orphan_reconciliation.py rename to backend/tests/unittest/test_sandbox_orphan_reconciliation.py diff --git a/backend/tests/test_sandbox_orphan_reconciliation_e2e.py b/backend/tests/unittest/test_sandbox_orphan_reconciliation_e2e.py similarity index 100% rename from backend/tests/test_sandbox_orphan_reconciliation_e2e.py rename to backend/tests/unittest/test_sandbox_orphan_reconciliation_e2e.py diff --git a/backend/tests/test_sandbox_search_tools.py b/backend/tests/unittest/test_sandbox_search_tools.py similarity index 100% rename from backend/tests/test_sandbox_search_tools.py rename to backend/tests/unittest/test_sandbox_search_tools.py diff --git a/backend/tests/test_sandbox_tools_security.py b/backend/tests/unittest/test_sandbox_tools_security.py similarity index 100% rename from backend/tests/test_sandbox_tools_security.py rename to backend/tests/unittest/test_sandbox_tools_security.py diff --git a/backend/tests/test_security_scanner.py b/backend/tests/unittest/test_security_scanner.py similarity index 100% rename from backend/tests/test_security_scanner.py rename to backend/tests/unittest/test_security_scanner.py diff --git a/backend/tests/test_serialization.py b/backend/tests/unittest/test_serialization.py similarity index 100% rename from backend/tests/test_serialization.py rename to backend/tests/unittest/test_serialization.py diff --git a/backend/tests/test_serialize_message_content.py b/backend/tests/unittest/test_serialize_message_content.py similarity index 100% rename from backend/tests/test_serialize_message_content.py rename to backend/tests/unittest/test_serialize_message_content.py diff --git a/backend/tests/test_setup_wizard.py b/backend/tests/unittest/test_setup_wizard.py similarity index 100% rename from backend/tests/test_setup_wizard.py rename to backend/tests/unittest/test_setup_wizard.py diff --git a/backend/tests/test_skill_manage_tool.py b/backend/tests/unittest/test_skill_manage_tool.py similarity index 100% rename from backend/tests/test_skill_manage_tool.py rename to backend/tests/unittest/test_skill_manage_tool.py diff --git a/backend/tests/test_skills_archive_root.py b/backend/tests/unittest/test_skills_archive_root.py similarity index 100% rename from backend/tests/test_skills_archive_root.py rename to backend/tests/unittest/test_skills_archive_root.py diff --git a/backend/tests/test_skills_custom_router.py b/backend/tests/unittest/test_skills_custom_router.py similarity index 100% rename from backend/tests/test_skills_custom_router.py rename to backend/tests/unittest/test_skills_custom_router.py diff --git a/backend/tests/test_skills_installer.py b/backend/tests/unittest/test_skills_installer.py similarity index 100% rename from backend/tests/test_skills_installer.py rename to backend/tests/unittest/test_skills_installer.py diff --git a/backend/tests/test_skills_loader.py b/backend/tests/unittest/test_skills_loader.py similarity index 100% rename from backend/tests/test_skills_loader.py rename to backend/tests/unittest/test_skills_loader.py diff --git a/backend/tests/test_skills_parser.py b/backend/tests/unittest/test_skills_parser.py similarity index 100% rename from backend/tests/test_skills_parser.py rename to backend/tests/unittest/test_skills_parser.py diff --git a/backend/tests/test_skills_validation.py b/backend/tests/unittest/test_skills_validation.py similarity index 100% rename from backend/tests/test_skills_validation.py rename to backend/tests/unittest/test_skills_validation.py diff --git a/backend/tests/test_sse_format.py b/backend/tests/unittest/test_sse_format.py similarity index 93% rename from backend/tests/test_sse_format.py rename to backend/tests/unittest/test_sse_format.py index 5647a22a1..655eb2421 100644 --- a/backend/tests/test_sse_format.py +++ b/backend/tests/unittest/test_sse_format.py @@ -4,7 +4,7 @@ import json def _format_sse(event: str, data, *, event_id: str | None = None) -> str: - from app.gateway.services import format_sse + from app.gateway.routers.langgraph.runs import format_sse return format_sse(event, data, event_id=event_id) diff --git a/backend/tests/test_stream_bridge.py b/backend/tests/unittest/test_stream_bridge.py similarity index 87% rename from backend/tests/test_stream_bridge.py rename to backend/tests/unittest/test_stream_bridge.py index efd5e7923..a6e737725 100644 --- a/backend/tests/test_stream_bridge.py +++ b/backend/tests/unittest/test_stream_bridge.py @@ -6,7 +6,8 @@ import re import anyio import pytest -from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge +from app.infra.stream_bridge import MemoryStreamBridge, build_stream_bridge +from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL # --------------------------------------------------------------------------- # Unit tests for MemoryStreamBridge @@ -20,7 +21,7 @@ def bridge() -> MemoryStreamBridge: @pytest.mark.anyio async def test_publish_subscribe(bridge: MemoryStreamBridge): - """Three events followed by end should be received in order.""" + """Three events followed by a terminal end event should be received in order.""" run_id = "run-1" await bridge.publish(run_id, "metadata", {"run_id": run_id}) @@ -31,21 +32,22 @@ async def test_publish_subscribe(bridge: MemoryStreamBridge): received = [] async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0): received.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": break assert len(received) == 4 assert received[0].event == "metadata" assert received[1].event == "values" assert received[2].event == "updates" - assert received[3] is END_SENTINEL + assert received[3].event == "end" + assert received[3].data is None @pytest.mark.anyio async def test_heartbeat(bridge: MemoryStreamBridge): """When no events arrive within the heartbeat interval, yield a heartbeat.""" run_id = "run-heartbeat" - bridge._get_or_create_stream(run_id) # ensure stream exists + await bridge._get_or_create_stream(run_id) # ensure stream exists received = [] @@ -69,12 +71,11 @@ async def test_cleanup(bridge: MemoryStreamBridge): await bridge.cleanup(run_id) assert run_id not in bridge._streams - assert run_id not in bridge._counters @pytest.mark.anyio async def test_history_is_bounded(): - """Retained history should be bounded by queue_maxsize.""" + """Retained history should be bounded by queue_maxsize plus terminal event.""" bridge = MemoryStreamBridge(queue_maxsize=1) run_id = "run-bp" @@ -85,12 +86,12 @@ async def test_history_is_bounded(): received = [] async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0): received.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": break assert len(received) == 2 assert received[0].event == "second" - assert received[1] is END_SENTINEL + assert received[1].event == "end" @pytest.mark.anyio @@ -104,13 +105,13 @@ async def test_multiple_runs(bridge: MemoryStreamBridge): events_a = [] async for entry in bridge.subscribe("run-a", heartbeat_interval=1.0): events_a.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": break events_b = [] async for entry in bridge.subscribe("run-b", heartbeat_interval=1.0): events_b.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": break assert len(events_a) == 2 @@ -132,7 +133,7 @@ async def test_event_id_format(bridge: MemoryStreamBridge): received = [] async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0): received.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": break event = received[0] @@ -151,7 +152,7 @@ async def test_subscribe_replays_after_last_event_id(bridge: MemoryStreamBridge) first_pass = [] async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0): first_pass.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": break received = [] @@ -161,11 +162,11 @@ async def test_subscribe_replays_after_last_event_id(bridge: MemoryStreamBridge) heartbeat_interval=1.0, ): received.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": break assert [entry.event for entry in received[:-1]] == ["values", "updates"] - assert received[-1] is END_SENTINEL + assert received[-1].event == "end" @pytest.mark.anyio @@ -206,11 +207,11 @@ async def test_slow_subscriber_does_not_skip_after_buffer_trim(): heartbeat_interval=1.0, ): received.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": break assert [entry.event for entry in received[:-1]] == ["e3"] - assert received[-1] is END_SENTINEL + assert received[-1].event == "end" # --------------------------------------------------------------------------- @@ -220,7 +221,7 @@ async def test_slow_subscriber_does_not_skip_after_buffer_trim(): @pytest.mark.anyio async def test_publish_end_terminates_even_when_history_is_full(): - """publish_end() should terminate subscribers without mutating retained history.""" + """publish_end() should terminate subscribers and append a terminal event.""" bridge = MemoryStreamBridge(queue_maxsize=2) run_id = "run-end-history-full" @@ -230,16 +231,16 @@ async def test_publish_end_terminates_even_when_history_is_full(): assert [entry.event for entry in stream.events] == ["event-1", "event-2"] await bridge.publish_end(run_id) - assert [entry.event for entry in stream.events] == ["event-1", "event-2"] + assert [entry.event for entry in stream.events] == ["event-1", "event-2", "end"] events = [] async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1): events.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": break assert [entry.event for entry in events[:-1]] == ["event-1", "event-2"] - assert events[-1] is END_SENTINEL + assert events[-1].event == "end" @pytest.mark.anyio @@ -252,11 +253,11 @@ async def test_publish_end_without_history_yields_end_immediately(): events = [] async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1): events.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": break assert len(events) == 1 - assert events[0] is END_SENTINEL + assert events[0].event == "end" @pytest.mark.anyio @@ -272,14 +273,14 @@ async def test_publish_end_preserves_history_when_space_available(): events = [] async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1): events.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": break # All events plus END should be present assert len(events) == 3 assert events[0].event == "event-1" assert events[1].event == "event-2" - assert events[2] is END_SENTINEL + assert events[2].event == "end" @pytest.mark.anyio @@ -301,7 +302,7 @@ async def test_concurrent_tasks_end_sentinel(): events = [] async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1): events.append(entry) - if entry is END_SENTINEL: + if entry.event == "end": return events return events # pragma: no cover @@ -321,7 +322,7 @@ async def test_concurrent_tasks_end_sentinel(): for run_id in run_ids: events = results[run_id] - assert events[-1] is END_SENTINEL, f"Run {run_id} did not receive END sentinel" + assert events[-1].event == "end", f"Run {run_id} did not receive end event" # --------------------------------------------------------------------------- @@ -331,6 +332,6 @@ async def test_concurrent_tasks_end_sentinel(): @pytest.mark.anyio async def test_make_stream_bridge_defaults(): - """make_stream_bridge() with no config yields a MemoryStreamBridge.""" - async with make_stream_bridge() as bridge: + """build_stream_bridge() with no config yields a MemoryStreamBridge.""" + async with build_stream_bridge() as bridge: assert isinstance(bridge, MemoryStreamBridge) diff --git a/backend/tests/test_subagent_executor.py b/backend/tests/unittest/test_subagent_executor.py similarity index 100% rename from backend/tests/test_subagent_executor.py rename to backend/tests/unittest/test_subagent_executor.py diff --git a/backend/tests/test_subagent_limit_middleware.py b/backend/tests/unittest/test_subagent_limit_middleware.py similarity index 100% rename from backend/tests/test_subagent_limit_middleware.py rename to backend/tests/unittest/test_subagent_limit_middleware.py diff --git a/backend/tests/test_subagent_prompt_security.py b/backend/tests/unittest/test_subagent_prompt_security.py similarity index 100% rename from backend/tests/test_subagent_prompt_security.py rename to backend/tests/unittest/test_subagent_prompt_security.py diff --git a/backend/tests/test_subagent_timeout_config.py b/backend/tests/unittest/test_subagent_timeout_config.py similarity index 100% rename from backend/tests/test_subagent_timeout_config.py rename to backend/tests/unittest/test_subagent_timeout_config.py diff --git a/backend/tests/test_suggestions_router.py b/backend/tests/unittest/test_suggestions_router.py similarity index 79% rename from backend/tests/test_suggestions_router.py rename to backend/tests/unittest/test_suggestions_router.py index a8b9b0915..f39de9473 100644 --- a/backend/tests/test_suggestions_router.py +++ b/backend/tests/unittest/test_suggestions_router.py @@ -1,6 +1,7 @@ import asyncio from unittest.mock import AsyncMock, MagicMock +from _router_auth_helpers import call_unwrapped from app.gateway.routers import suggestions @@ -46,9 +47,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch): fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```')) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) - # Bypass the require_permission decorator (which needs request + - # thread_store) — these tests cover the parsing logic. - result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) + result = asyncio.run(call_unwrapped(suggestions.generate_suggestions, "t1", req, request=None)) assert result.suggestions == ["Q1", "Q2", "Q3"] @@ -66,9 +65,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch): fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}])) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) - # Bypass the require_permission decorator (which needs request + - # thread_store) — these tests cover the parsing logic. - result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) + result = asyncio.run(call_unwrapped(suggestions.generate_suggestions, "t1", req, request=None)) assert result.suggestions == ["Q1", "Q2"] @@ -86,9 +83,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch): fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}])) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) - # Bypass the require_permission decorator (which needs request + - # thread_store) — these tests cover the parsing logic. - result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) + result = asyncio.run(call_unwrapped(suggestions.generate_suggestions, "t1", req, request=None)) assert result.suggestions == ["Q1", "Q2"] @@ -103,8 +98,6 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch): fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("boom")) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) - # Bypass the require_permission decorator (which needs request + - # thread_store) — these tests cover the parsing logic. - result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) + result = asyncio.run(call_unwrapped(suggestions.generate_suggestions, "t1", req, request=None)) assert result.suggestions == [] diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/unittest/test_task_tool_core_logic.py similarity index 100% rename from backend/tests/test_task_tool_core_logic.py rename to backend/tests/unittest/test_task_tool_core_logic.py diff --git a/backend/tests/test_thread_data_middleware.py b/backend/tests/unittest/test_thread_data_middleware.py similarity index 100% rename from backend/tests/test_thread_data_middleware.py rename to backend/tests/unittest/test_thread_data_middleware.py diff --git a/backend/tests/unittest/test_thread_meta_repo.py b/backend/tests/unittest/test_thread_meta_repo.py new file mode 100644 index 000000000..76bc7f53e --- /dev/null +++ b/backend/tests/unittest/test_thread_meta_repo.py @@ -0,0 +1,206 @@ +"""Tests for current thread metadata storage adapters.""" + +from __future__ import annotations + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.infra.storage import ThreadMetaStorage, ThreadMetaStoreAdapter +from deerflow.runtime.actor_context import ActorContext, bind_actor_context, reset_actor_context +from store.persistence import MappedBase + + +async def _make_store(tmp_path): + engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}", future=True) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + return engine, ThreadMetaStorage(ThreadMetaStoreAdapter(session_factory)) + + +class TestThreadMetaStorage: + @pytest.mark.anyio + async def test_create_and_get(self, tmp_path): + engine, store = await _make_store(tmp_path) + thread = await store.ensure_thread(thread_id="t1", user_id=None) + assert thread.thread_id == "t1" + assert thread.status == "idle" + fetched = await store.get_thread("t1", user_id=None) + assert fetched is not None + assert fetched.thread_id == "t1" + await engine.dispose() + + @pytest.mark.anyio + async def test_create_with_assistant_id(self, tmp_path): + engine, store = await _make_store(tmp_path) + thread = await store.ensure_thread(thread_id="t1", assistant_id="agent1", user_id=None) + assert thread.assistant_id == "agent1" + await engine.dispose() + + @pytest.mark.anyio + async def test_create_with_owner(self, tmp_path): + engine, store = await _make_store(tmp_path) + token = bind_actor_context(ActorContext(user_id="user1")) + try: + thread = await store.ensure_thread(thread_id="t1") + assert thread.user_id == "user1" + finally: + reset_actor_context(token) + await engine.dispose() + + @pytest.mark.anyio + async def test_create_with_metadata(self, tmp_path): + engine, store = await _make_store(tmp_path) + thread = await store.ensure_thread(thread_id="t1", metadata={"key": "value"}, user_id=None) + assert thread.metadata == {"key": "value"} + await engine.dispose() + + @pytest.mark.anyio + async def test_ensure_thread_is_idempotent(self, tmp_path): + engine, store = await _make_store(tmp_path) + try: + first = await store.ensure_thread(thread_id="t1", user_id=None) + second = await store.ensure_thread(thread_id="t1", user_id=None) + assert second.thread_id == first.thread_id + rows = await store.search_threads(user_id=None) + assert [row.thread_id for row in rows] == ["t1"] + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_get_nonexistent(self, tmp_path): + engine, store = await _make_store(tmp_path) + assert await store.get_thread("missing", user_id=None) is None + await engine.dispose() + + @pytest.mark.anyio + async def test_cross_user_get_is_filtered(self, tmp_path): + engine, store = await _make_store(tmp_path) + token = bind_actor_context(ActorContext(user_id="user1")) + try: + await store.ensure_thread(thread_id="t1") + finally: + reset_actor_context(token) + + token = bind_actor_context(ActorContext(user_id="user2")) + try: + assert await store.get_thread("t1") is None + finally: + reset_actor_context(token) + await engine.dispose() + + @pytest.mark.anyio + async def test_shared_thread_visible_to_anyone_with_explicit_none(self, tmp_path): + engine, store = await _make_store(tmp_path) + await store.ensure_thread(thread_id="t1", user_id=None) + token = bind_actor_context(ActorContext(user_id="user2")) + try: + assert await store.get_thread("t1", user_id=None) is not None + finally: + reset_actor_context(token) + await engine.dispose() + + @pytest.mark.anyio + @pytest.mark.no_auto_user + async def test_auto_user_id_requires_actor_context(self, tmp_path): + engine, store = await _make_store(tmp_path) + try: + await store.ensure_thread(thread_id="t1", user_id="alice") + with pytest.raises(RuntimeError, match="no actor context is set"): + await store.search_threads() + with pytest.raises(RuntimeError, match="no actor context is set"): + await store.get_thread("t1") + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_sync_thread_status(self, tmp_path): + engine, store = await _make_store(tmp_path) + await store.ensure_thread(thread_id="t1", user_id=None) + await store.sync_thread_status(thread_id="t1", status="busy") + thread = await store.get_thread("t1", user_id=None) + assert thread is not None + assert thread.status == "busy" + await engine.dispose() + + @pytest.mark.anyio + async def test_sync_thread_assistant_id(self, tmp_path): + engine, store = await _make_store(tmp_path) + await store.ensure_thread(thread_id="t1", user_id=None) + await store.sync_thread_assistant_id(thread_id="t1", assistant_id="lead_agent") + thread = await store.get_thread("t1", user_id=None) + assert thread is not None + assert thread.assistant_id == "lead_agent" + await engine.dispose() + + @pytest.mark.anyio + async def test_sync_thread_metadata_replaces(self, tmp_path): + engine, store = await _make_store(tmp_path) + await store.ensure_thread(thread_id="t1", metadata={"a": 1}, user_id=None) + await store.sync_thread_metadata(thread_id="t1", metadata={"b": 2}) + thread = await store.get_thread("t1", user_id=None) + assert thread is not None + assert thread.metadata == {"b": 2} + await engine.dispose() + + @pytest.mark.anyio + async def test_delete_thread(self, tmp_path): + engine, store = await _make_store(tmp_path) + await store.ensure_thread(thread_id="t1", user_id=None) + await store.delete_thread("t1") + assert await store.get_thread("t1", user_id=None) is None + await engine.dispose() + + @pytest.mark.anyio + async def test_search_threads_filters_by_actor(self, tmp_path): + engine, store = await _make_store(tmp_path) + token = bind_actor_context(ActorContext(user_id="user1")) + try: + await store.ensure_thread(thread_id="t1") + finally: + reset_actor_context(token) + + token = bind_actor_context(ActorContext(user_id="user2")) + try: + await store.ensure_thread(thread_id="t2") + finally: + reset_actor_context(token) + + token = bind_actor_context(ActorContext(user_id="user1")) + try: + rows = await store.search_threads() + assert [row.thread_id for row in rows] == ["t1"] + finally: + reset_actor_context(token) + await engine.dispose() + + @pytest.mark.anyio + async def test_search_threads_strips_blank_filters(self, tmp_path): + engine, store = await _make_store(tmp_path) + try: + await store.ensure_thread(thread_id="t1", assistant_id="agent1", user_id=None) + rows = await store.search_threads(status=" ", assistant_id=" ", user_id=None) + assert [row.thread_id for row in rows] == ["t1"] + finally: + await engine.dispose() + + @pytest.mark.anyio + async def test_ensure_thread_running_creates_and_updates(self, tmp_path): + engine, store = await _make_store(tmp_path) + try: + created = await store.ensure_thread_running(thread_id="t1", assistant_id="agent1", metadata={"a": 1}) + assert created is not None + assert created.thread_id == "t1" + assert created.status == "running" + + await store.sync_thread_status(thread_id="t1", status="idle") + updated = await store.ensure_thread_running(thread_id="t1") + assert updated is not None + assert updated.status == "running" + finally: + await engine.dispose() diff --git a/backend/tests/test_thread_run_messages_pagination.py b/backend/tests/unittest/test_thread_run_messages_pagination.py similarity index 67% rename from backend/tests/test_thread_run_messages_pagination.py rename to backend/tests/unittest/test_thread_run_messages_pagination.py index f00100cad..31a96b137 100644 --- a/backend/tests/test_thread_run_messages_pagination.py +++ b/backend/tests/unittest/test_thread_run_messages_pagination.py @@ -1,4 +1,4 @@ -"""Tests for paginated GET /api/threads/{thread_id}/runs/{run_id}/messages endpoint.""" +"""Tests for paginated thread-scoped run messages endpoint.""" from __future__ import annotations from unittest.mock import AsyncMock, MagicMock @@ -7,7 +7,7 @@ import pytest from _router_auth_helpers import make_authed_test_app from fastapi.testclient import TestClient -from app.gateway.routers import thread_runs +from app.gateway.routers.langgraph import runs as langgraph_runs # --------------------------------------------------------------------------- @@ -15,13 +15,17 @@ from app.gateway.routers import thread_runs # --------------------------------------------------------------------------- -def _make_app(event_store=None): +def _make_app(event_store=None, run_read_repo=None): """Build a test FastAPI app with stub auth and mocked state.""" app = make_authed_test_app() - app.include_router(thread_runs.router) + app.include_router(langgraph_runs.router, prefix="/api/threads") + app.state.stream_bridge = MagicMock() + app.state.persistence = MagicMock(checkpointer=MagicMock()) if event_store is not None: app.state.run_event_store = event_store + if run_read_repo is not None: + app.state.run_read_repo = run_read_repo return app @@ -33,8 +37,32 @@ def _make_event_store(rows: list[dict]): return store +def _make_run_read_repo(thread_id: str, run_id: str): + repo = MagicMock() + repo.get = AsyncMock( + return_value={ + "run_id": run_id, + "thread_id": thread_id, + "status": "success", + "assistant_id": None, + "metadata": {}, + "multitask_strategy": "reject", + "created_at": "", + "updated_at": "", + } + ) + repo.list_by_thread = AsyncMock(return_value=[]) + return repo + + def _make_message(seq: int) -> dict: - return {"seq": seq, "event_type": "ai_message", "category": "message", "content": f"msg-{seq}"} + return { + "run_id": "run-x", + "seq": seq, + "event_type": "ai_message", + "category": "message", + "content": f"msg-{seq}", + } # --------------------------------------------------------------------------- @@ -45,14 +73,14 @@ def _make_message(seq: int) -> dict: def test_returns_paginated_envelope(): """GET /api/threads/{tid}/runs/{rid}/messages returns {data: [...], has_more: bool}.""" rows = [_make_message(i) for i in range(1, 4)] - app = _make_app(event_store=_make_event_store(rows)) + app = _make_app(event_store=_make_event_store(rows), run_read_repo=_make_run_read_repo("thread-1", "run-1")) with TestClient(app) as client: response = client.get("/api/threads/thread-1/runs/run-1/messages") assert response.status_code == 200 body = response.json() assert "data" in body - assert "has_more" in body - assert body["has_more"] is False + assert "hasMore" in body + assert body["hasMore"] is False assert len(body["data"]) == 3 @@ -60,12 +88,12 @@ def test_has_more_true_when_extra_row_returned(): """has_more=True when event store returns limit+1 rows.""" # Default limit is 50; provide 51 rows rows = [_make_message(i) for i in range(1, 52)] # 51 rows - app = _make_app(event_store=_make_event_store(rows)) + app = _make_app(event_store=_make_event_store(rows), run_read_repo=_make_run_read_repo("thread-2", "run-2")) with TestClient(app) as client: response = client.get("/api/threads/thread-2/runs/run-2/messages") assert response.status_code == 200 body = response.json() - assert body["has_more"] is True + assert body["hasMore"] is True assert len(body["data"]) == 50 # trimmed to limit @@ -73,7 +101,7 @@ def test_after_seq_forwarded_to_event_store(): """after_seq query param is forwarded to event_store.list_messages_by_run.""" rows = [_make_message(10)] event_store = _make_event_store(rows) - app = _make_app(event_store=event_store) + app = _make_app(event_store=event_store, run_read_repo=_make_run_read_repo("thread-3", "run-3")) with TestClient(app) as client: response = client.get("/api/threads/thread-3/runs/run-3/messages?after_seq=5") assert response.status_code == 200 @@ -89,7 +117,7 @@ def test_before_seq_forwarded_to_event_store(): """before_seq query param is forwarded to event_store.list_messages_by_run.""" rows = [_make_message(3)] event_store = _make_event_store(rows) - app = _make_app(event_store=event_store) + app = _make_app(event_store=event_store, run_read_repo=_make_run_read_repo("thread-4", "run-4")) with TestClient(app) as client: response = client.get("/api/threads/thread-4/runs/run-4/messages?before_seq=10") assert response.status_code == 200 @@ -105,7 +133,7 @@ def test_custom_limit_forwarded_to_event_store(): """Custom limit is forwarded as limit+1 to the event store.""" rows = [_make_message(i) for i in range(1, 6)] event_store = _make_event_store(rows) - app = _make_app(event_store=event_store) + app = _make_app(event_store=event_store, run_read_repo=_make_run_read_repo("thread-5", "run-5")) with TestClient(app) as client: response = client.get("/api/threads/thread-5/runs/run-5/messages?limit=10") assert response.status_code == 200 @@ -119,10 +147,10 @@ def test_custom_limit_forwarded_to_event_store(): def test_empty_data_when_no_messages(): """Returns empty data list with has_more=False when no messages exist.""" - app = _make_app(event_store=_make_event_store([])) + app = _make_app(event_store=_make_event_store([]), run_read_repo=_make_run_read_repo("thread-6", "run-6")) with TestClient(app) as client: response = client.get("/api/threads/thread-6/runs/run-6/messages") assert response.status_code == 200 body = response.json() assert body["data"] == [] - assert body["has_more"] is False + assert body["hasMore"] is False diff --git a/backend/tests/test_threads_router.py b/backend/tests/unittest/test_threads_router.py similarity index 53% rename from backend/tests/test_threads_router.py rename to backend/tests/unittest/test_threads_router.py index 4ffa28a8c..ffde1fa85 100644 --- a/backend/tests/test_threads_router.py +++ b/backend/tests/unittest/test_threads_router.py @@ -1,14 +1,19 @@ -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from _router_auth_helpers import make_authed_test_app from fastapi import HTTPException from fastapi.testclient import TestClient -from app.gateway.routers import threads +from app.gateway.routers.langgraph import threads as threads from deerflow.config.paths import Paths +async def _empty_async_iter(): + if False: + yield None + + def test_delete_thread_data_removes_thread_directory(tmp_path): paths = Paths(tmp_path) thread_dir = paths.thread_dir("thread-cleanup") @@ -50,18 +55,17 @@ def test_delete_thread_data_rejects_invalid_thread_id(tmp_path): def test_delete_thread_route_cleans_thread_directory(tmp_path): - from deerflow.runtime.user_context import get_effective_user_id - paths = Paths(tmp_path) - user_id = get_effective_user_id() - thread_dir = paths.thread_dir("thread-route", user_id=user_id) - paths.sandbox_work_dir("thread-route", user_id=user_id).mkdir(parents=True, exist_ok=True) - (paths.sandbox_work_dir("thread-route", user_id=user_id) / "notes.txt").write_text("hello", encoding="utf-8") + thread_dir = paths.thread_dir("thread-route") + paths.sandbox_work_dir("thread-route").mkdir(parents=True, exist_ok=True) + (paths.sandbox_work_dir("thread-route") / "notes.txt").write_text("hello", encoding="utf-8") app = make_authed_test_app() - app.include_router(threads.router) + app.include_router(threads.router, prefix="/api/threads") + app.state.persistence = MagicMock(checkpointer=MagicMock()) + app.state.thread_meta_storage = MagicMock(delete_thread=AsyncMock()) - with patch("app.gateway.routers.threads.get_paths", return_value=paths): + with patch("app.gateway.routers.langgraph.threads.get_paths", return_value=paths): with TestClient(app) as client: response = client.delete("/api/threads/thread-route") @@ -74,9 +78,11 @@ def test_delete_thread_route_rejects_invalid_thread_id(tmp_path): paths = Paths(tmp_path) app = make_authed_test_app() - app.include_router(threads.router) + app.include_router(threads.router, prefix="/api/threads") + app.state.persistence = MagicMock(checkpointer=MagicMock()) + app.state.thread_meta_storage = MagicMock(delete_thread=AsyncMock()) - with patch("app.gateway.routers.threads.get_paths", return_value=paths): + with patch("app.gateway.routers.langgraph.threads.get_paths", return_value=paths): with TestClient(app) as client: response = client.delete("/api/threads/../escape") @@ -87,9 +93,11 @@ def test_delete_thread_route_returns_422_for_route_safe_invalid_id(tmp_path): paths = Paths(tmp_path) app = make_authed_test_app() - app.include_router(threads.router) + app.include_router(threads.router, prefix="/api/threads") + app.state.persistence = MagicMock(checkpointer=MagicMock()) + app.state.thread_meta_storage = MagicMock(delete_thread=AsyncMock()) - with patch("app.gateway.routers.threads.get_paths", return_value=paths): + with patch("app.gateway.routers.langgraph.threads.get_paths", return_value=paths): with TestClient(app) as client: response = client.delete("/api/threads/thread.with.dot") @@ -113,26 +121,38 @@ def test_delete_thread_data_returns_generic_500_error(tmp_path): log_exception.assert_called_once_with("Failed to delete thread data for %s", "thread-cleanup") -# ── Server-reserved metadata key stripping ────────────────────────────────── +def test_get_thread_history_returns_empty_list_when_thread_exists_without_checkpoints(): + app = make_authed_test_app() + app.include_router(threads.router, prefix="/api/threads") + app.state.persistence = MagicMock(checkpointer=MagicMock(alist=lambda *args, **kwargs: _empty_async_iter())) + app.state.thread_meta_storage = MagicMock(get_thread=AsyncMock(return_value=MagicMock(thread_id="thread-empty"))) + app.state.run_store = MagicMock(list_by_thread=AsyncMock(return_value=[])) + + with TestClient(app) as client: + response = client.post("/api/threads/thread-empty/history", json={"limit": 10}) + + assert response.status_code == 200 + assert response.json() == [] -def test_strip_reserved_metadata_removes_user_id(): - """Client-supplied user_id is dropped to prevent reflection attacks.""" - out = threads._strip_reserved_metadata({"user_id": "victim-id", "title": "ok"}) - assert out == {"title": "ok"} +def test_get_thread_state_returns_empty_state_when_thread_exists_without_checkpoints(): + app = make_authed_test_app() + app.include_router(threads.router, prefix="/api/threads") + app.state.persistence = MagicMock(checkpointer=MagicMock(aget_tuple=AsyncMock(return_value=None))) + app.state.thread_meta_storage = MagicMock(get_thread=AsyncMock(return_value=MagicMock(thread_id="thread-empty"))) + app.state.run_store = MagicMock(list_by_thread=AsyncMock(return_value=[])) + with TestClient(app) as client: + response = client.get("/api/threads/thread-empty/state") -def test_strip_reserved_metadata_passes_through_safe_keys(): - """Non-reserved keys are preserved verbatim.""" - md = {"title": "ok", "tags": ["a", "b"], "custom": {"x": 1}} - assert threads._strip_reserved_metadata(md) == md - - -def test_strip_reserved_metadata_empty_input(): - """Empty / None metadata returns same object — no crash.""" - assert threads._strip_reserved_metadata({}) == {} - - -def test_strip_reserved_metadata_strips_all_reserved_keys(): - out = threads._strip_reserved_metadata({"user_id": "x", "keep": "me"}) - assert out == {"keep": "me"} + assert response.status_code == 200 + assert response.json() == { + "values": {}, + "next": [], + "tasks": [], + "checkpoint": {}, + "checkpoint_id": None, + "parent_checkpoint_id": None, + "metadata": {}, + "created_at": None, + } diff --git a/backend/tests/test_title_generation.py b/backend/tests/unittest/test_title_generation.py similarity index 100% rename from backend/tests/test_title_generation.py rename to backend/tests/unittest/test_title_generation.py diff --git a/backend/tests/test_title_middleware_core_logic.py b/backend/tests/unittest/test_title_middleware_core_logic.py similarity index 100% rename from backend/tests/test_title_middleware_core_logic.py rename to backend/tests/unittest/test_title_middleware_core_logic.py diff --git a/backend/tests/test_todo_middleware.py b/backend/tests/unittest/test_todo_middleware.py similarity index 100% rename from backend/tests/test_todo_middleware.py rename to backend/tests/unittest/test_todo_middleware.py diff --git a/backend/tests/test_token_usage.py b/backend/tests/unittest/test_token_usage.py similarity index 100% rename from backend/tests/test_token_usage.py rename to backend/tests/unittest/test_token_usage.py diff --git a/backend/tests/test_tool_error_handling_middleware.py b/backend/tests/unittest/test_tool_error_handling_middleware.py similarity index 100% rename from backend/tests/test_tool_error_handling_middleware.py rename to backend/tests/unittest/test_tool_error_handling_middleware.py diff --git a/backend/tests/test_tool_output_truncation.py b/backend/tests/unittest/test_tool_output_truncation.py similarity index 100% rename from backend/tests/test_tool_output_truncation.py rename to backend/tests/unittest/test_tool_output_truncation.py diff --git a/backend/tests/test_tool_search.py b/backend/tests/unittest/test_tool_search.py similarity index 100% rename from backend/tests/test_tool_search.py rename to backend/tests/unittest/test_tool_search.py diff --git a/backend/tests/test_tracing_config.py b/backend/tests/unittest/test_tracing_config.py similarity index 100% rename from backend/tests/test_tracing_config.py rename to backend/tests/unittest/test_tracing_config.py diff --git a/backend/tests/test_tracing_factory.py b/backend/tests/unittest/test_tracing_factory.py similarity index 100% rename from backend/tests/test_tracing_factory.py rename to backend/tests/unittest/test_tracing_factory.py diff --git a/backend/tests/test_uploads_manager.py b/backend/tests/unittest/test_uploads_manager.py similarity index 100% rename from backend/tests/test_uploads_manager.py rename to backend/tests/unittest/test_uploads_manager.py diff --git a/backend/tests/test_uploads_middleware_core_logic.py b/backend/tests/unittest/test_uploads_middleware_core_logic.py similarity index 99% rename from backend/tests/test_uploads_middleware_core_logic.py rename to backend/tests/unittest/test_uploads_middleware_core_logic.py index 2c562b179..b0b8fce7e 100644 --- a/backend/tests/test_uploads_middleware_core_logic.py +++ b/backend/tests/unittest/test_uploads_middleware_core_logic.py @@ -34,7 +34,7 @@ def _runtime(thread_id: str | None = THREAD_ID) -> MagicMock: def _uploads_dir(tmp_path: Path, thread_id: str = THREAD_ID) -> Path: - from deerflow.runtime.user_context import get_effective_user_id + from deerflow.runtime.actor_context import get_effective_user_id d = Paths(str(tmp_path)).sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) d.mkdir(parents=True, exist_ok=True) diff --git a/backend/tests/test_uploads_router.py b/backend/tests/unittest/test_uploads_router.py similarity index 100% rename from backend/tests/test_uploads_router.py rename to backend/tests/unittest/test_uploads_router.py diff --git a/backend/tests/unittest/test_user_context.py b/backend/tests/unittest/test_user_context.py new file mode 100644 index 000000000..e5b7ab1b4 --- /dev/null +++ b/backend/tests/unittest/test_user_context.py @@ -0,0 +1,97 @@ +"""Tests for the runtime actor-context bridge.""" + +from types import SimpleNamespace + +import pytest + +from deerflow.runtime.actor_context import ( + ActorContext, + DEFAULT_USER_ID, + get_actor_context, + get_effective_user_id, + require_actor_context, + reset_actor_context, + bind_actor_context, +) + + +@pytest.mark.no_auto_user +def test_default_is_none(): + """Before any set, contextvar returns None.""" + assert get_actor_context() is None + + +@pytest.mark.no_auto_user +def test_set_and_reset_roundtrip(): + """Binding returns a token that reset restores.""" + actor = ActorContext(user_id="user-1") + token = bind_actor_context(actor) + try: + assert get_actor_context() == actor + finally: + reset_actor_context(token) + assert get_actor_context() is None + + +@pytest.mark.no_auto_user +def test_require_current_user_raises_when_unset(): + """require_actor_context raises RuntimeError if no actor is bound.""" + assert get_actor_context() is None + with pytest.raises(RuntimeError, match="without actor context"): + require_actor_context() + + +@pytest.mark.no_auto_user +def test_require_current_user_returns_user_when_set(): + """require_actor_context returns the bound actor.""" + actor = ActorContext(user_id="user-2") + token = bind_actor_context(actor) + try: + assert require_actor_context() == actor + finally: + reset_actor_context(token) + + +@pytest.mark.no_auto_user +def test_protocol_accepts_duck_typed(): + actor = ActorContext(user_id="user-3") + assert actor.user_id == "user-3" + + +# --------------------------------------------------------------------------- +# get_effective_user_id / DEFAULT_USER_ID tests +# --------------------------------------------------------------------------- + + +def test_default_user_id_is_default(): + assert DEFAULT_USER_ID == "default" + + +@pytest.mark.no_auto_user +def test_effective_user_id_returns_default_when_no_user(): + """No user in context -> fallback to DEFAULT_USER_ID.""" + assert get_effective_user_id() == "default" + + +@pytest.mark.no_auto_user +def test_effective_user_id_returns_user_id_when_set(): + actor = ActorContext(user_id="u-abc-123") + token = bind_actor_context(actor) + try: + assert get_effective_user_id() == "u-abc-123" + finally: + reset_actor_context(token) + + +@pytest.mark.no_auto_user +def test_effective_user_id_coerces_to_str(): + """User.id might be a UUID object; must come back as str.""" + import uuid + uid = uuid.uuid4() + + actor = ActorContext(user_id=str(uid)) + token = bind_actor_context(actor) + try: + assert get_effective_user_id() == str(uid) + finally: + reset_actor_context(token) diff --git a/backend/tests/unittest/test_user_repository.py b/backend/tests/unittest/test_user_repository.py new file mode 100644 index 000000000..125e676eb --- /dev/null +++ b/backend/tests/unittest/test_user_repository.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.plugins.auth.storage import DbUserRepository, UserCreate +from app.plugins.auth.storage.contracts import User +from app.plugins.auth.storage.models import User as UserModel # noqa: F401 +from store.persistence import MappedBase + + +async def _make_repo(tmp_path): + engine = create_async_engine( + f"sqlite+aiosqlite:///{tmp_path / 'users.db'}", + future=True, + ) + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + return engine, session_factory + + +class TestUserRepository: + @pytest.mark.anyio + async def test_create_and_get_by_id(self, tmp_path): + engine, session_factory = await _make_repo(tmp_path) + async with session_factory() as session: + repo = DbUserRepository(session) + created = await repo.create_user( + UserCreate( + id="user-1", + email="user1@example.com", + password_hash="hash-1", + ) + ) + await session.commit() + fetched = await repo.get_user_by_id("user-1") + await engine.dispose() + + assert created.id == "user-1" + assert fetched is not None + assert fetched.email == "user1@example.com" + assert fetched.password_hash == "hash-1" + assert fetched.system_role == "user" + assert fetched.needs_setup is False + assert fetched.token_version == 0 + + @pytest.mark.anyio + async def test_get_by_email_and_oauth(self, tmp_path): + engine, session_factory = await _make_repo(tmp_path) + async with session_factory() as session: + repo = DbUserRepository(session) + await repo.create_user( + UserCreate( + id="user-2", + email="oauth@example.com", + oauth_provider="github", + oauth_id="gh-123", + ) + ) + await session.commit() + by_email = await repo.get_user_by_email("oauth@example.com") + by_oauth = await repo.get_user_by_oauth("github", "gh-123") + await engine.dispose() + + assert by_email is not None + assert by_email.id == "user-2" + assert by_oauth is not None + assert by_oauth.email == "oauth@example.com" + + @pytest.mark.anyio + async def test_update_user(self, tmp_path): + engine, session_factory = await _make_repo(tmp_path) + async with session_factory() as session: + repo = DbUserRepository(session) + created = await repo.create_user( + UserCreate( + id="user-3", + email="before@example.com", + password_hash="old-hash", + needs_setup=True, + ) + ) + updated = await repo.update_user( + User( + id=created.id, + email="after@example.com", + password_hash="new-hash", + system_role="admin", + oauth_provider=None, + oauth_id=None, + needs_setup=False, + token_version=2, + created_time=created.created_time, + updated_time=created.updated_time, + ) + ) + await session.commit() + fetched = await repo.get_user_by_id("user-3") + await engine.dispose() + + assert updated.email == "after@example.com" + assert fetched is not None + assert fetched.system_role == "admin" + assert fetched.password_hash == "new-hash" + assert fetched.needs_setup is False + assert fetched.token_version == 2 + + @pytest.mark.anyio + async def test_count_users_and_admins(self, tmp_path): + engine, session_factory = await _make_repo(tmp_path) + async with session_factory() as session: + repo = DbUserRepository(session) + await repo.create_user(UserCreate(id="user-4", email="admin@example.com", system_role="admin")) + await repo.create_user(UserCreate(id="user-5", email="user@example.com", system_role="user")) + await session.commit() + user_count = await repo.count_users() + admin_count = await repo.count_admin_users() + await engine.dispose() + + assert user_count == 2 + assert admin_count == 1 + + @pytest.mark.anyio + async def test_duplicate_email_raises_value_error(self, tmp_path): + engine, session_factory = await _make_repo(tmp_path) + async with session_factory() as session: + repo = DbUserRepository(session) + await repo.create_user(UserCreate(id="user-6", email="dup@example.com")) + with pytest.raises(ValueError, match="User already exists"): + await repo.create_user(UserCreate(id="user-7", email="dup@example.com")) + await engine.dispose() + + @pytest.mark.anyio + async def test_update_missing_user_raises_lookup_error(self, tmp_path): + engine, session_factory = await _make_repo(tmp_path) + async with session_factory() as session: + repo = DbUserRepository(session) + with pytest.raises(LookupError, match="not found"): + await repo.update_user( + User( + id="missing-user", + email="missing@example.com", + password_hash=None, + system_role="user", + oauth_provider=None, + oauth_id=None, + needs_setup=False, + token_version=0, + created_time=datetime.now(UTC), + updated_time=None, + ) + ) + await engine.dispose() diff --git a/backend/tests/test_vllm_provider.py b/backend/tests/unittest/test_vllm_provider.py similarity index 100% rename from backend/tests/test_vllm_provider.py rename to backend/tests/unittest/test_vllm_provider.py diff --git a/backend/tests/test_wechat_channel.py b/backend/tests/unittest/test_wechat_channel.py similarity index 100% rename from backend/tests/test_wechat_channel.py rename to backend/tests/unittest/test_wechat_channel.py