refactor(tests): reorganize tests into unittest/ and e2e/ directories

- Move all unit tests from tests/ to tests/unittest/
- Add tests/e2e/ directory for end-to-end tests
- Update conftest.py for new test structure
- Add new tests for auth dependencies, policies, route injection
- Add new tests for run callbacks, create store, execution artifacts
- Remove obsolete tests for deleted persistence layer

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
rayhpeng 2026-04-22 11:24:53 +08:00
parent 38a6ec496f
commit 2fe0856e33
149 changed files with 3450 additions and 4664 deletions

View File

@ -1,4 +1,4 @@
"""Test configuration for the backend test suite.
"""Test configuration shared by unit and end-to-end tests.
Sets up sys.path and pre-mocks modules that would cause circular import
issues when unit-testing lightweight config/registry code in isolation.
@ -6,8 +6,8 @@ issues when unit-testing lightweight config/registry code in isolation.
import importlib.util
import sys
import warnings
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
@ -38,14 +38,51 @@ _executor_mock.get_background_task_result = MagicMock()
sys.modules["deerflow.subagents.executor"] = _executor_mock
def pytest_configure(config):
    """Register the ``no_auto_user`` marker and silence known noisy warnings.

    Runs once at pytest start-up. The warning filters hide deprecation /
    serializer noise emitted by third-party dependencies, not by our code.
    """
    config.addinivalue_line(
        "markers",
        "no_auto_user: disable the conftest autouse contextvar fixture for this test",
    )
    # (message regex, warning category) pairs to ignore, applied in order.
    ignored = (
        (r"Pydantic serializer warnings:.*field_name='context'.*", UserWarning),
        (r"pkg_resources is deprecated as an API.*", UserWarning),
        (r"Deprecated call to `pkg_resources\.declare_namespace\(.*", DeprecationWarning),
        (r"datetime\.datetime\.utcfromtimestamp\(\) is deprecated.*", DeprecationWarning),
        (r"websockets\.InvalidStatusCode is deprecated", DeprecationWarning),
        (r"websockets\.legacy is deprecated.*", DeprecationWarning),
        (r"websockets\.client\.WebSocketClientProtocol is deprecated", DeprecationWarning),
    )
    for message, category in ignored:
        warnings.filterwarnings("ignore", message=message, category=category)
@pytest.fixture()
def provisioner_module():
"""Load docker/provisioner/app.py as an importable test module.
Shared by test_provisioner_kubeconfig and test_provisioner_pvc_volumes so
that any change to the provisioner entry-point path or module name only
needs to be updated in one place.
"""
"""Load docker/provisioner/app.py as an importable test module."""
repo_root = Path(__file__).resolve().parents[2]
module_path = repo_root / "docker" / "provisioner" / "app.py"
spec = importlib.util.spec_from_file_location("provisioner_app_test", module_path)
@ -56,42 +93,25 @@ def provisioner_module():
return module
# ---------------------------------------------------------------------------
# Auto-set user context for every test unless marked no_auto_user
# ---------------------------------------------------------------------------
#
# Repository methods read ``user_id`` from a contextvar by default
# (see ``deerflow.runtime.user_context``). Without this fixture, every
# pre-existing persistence test would raise RuntimeError because the
# contextvar is unset. The fixture sets a default test user on every
# test; tests that explicitly want to verify behaviour *without* a user
# context should mark themselves ``@pytest.mark.no_auto_user``.
@pytest.fixture(autouse=True)
def _auto_user_context(request):
    """Inject a default ``test-user-autouse`` actor into the contextvar.

    Opt out with ``@pytest.mark.no_auto_user``. The actor-context import is
    lazy so tests that never touch the runtime layer do not pay for it.

    NOTE: the rendered source fused the removed (``user_context``) and added
    (``actor_context``) diff lines into an unparseable block; this is the
    reconstructed post-change version.
    """
    if request.node.get_closest_marker("no_auto_user"):
        yield
        return
    try:
        from deerflow.runtime.actor_context import (
            ActorContext,
            bind_actor_context,
            reset_actor_context,
        )
    except ImportError:
        # Runtime layer unavailable (e.g. lightweight config-only test runs).
        yield
        return
    token = bind_actor_context(ActorContext(user_id="test-user-autouse"))
    try:
        yield
    finally:
        # Always unbind, even if the test body raised.
        reset_actor_context(token)

View File

@ -0,0 +1,33 @@
"""Shared fixtures for end-to-end API tests."""
import pytest
from fastapi.testclient import TestClient
from app.plugins.auth.api.schemas import _login_attempts
from app.plugins.auth.domain.config import AuthConfig
from app.plugins.auth.runtime.config_state import reset_auth_config, set_auth_config
from store.config.app_config import AppConfig, reset_app_config, set_app_config
from store.config.storage_config import StorageConfig
_TEST_SECRET = "test-secret-key-e2e-auth-minimum-32"
@pytest.fixture()
def client(tmp_path):
    """Create a full app client backed by an isolated SQLite directory."""
    # Imported inside the fixture so config state is installed before the
    # app module wiring runs.
    from app.gateway.app import create_app

    # Reset process-global state so every test starts from a clean slate.
    _login_attempts.clear()
    reset_auth_config()
    reset_app_config()
    set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
    set_app_config(AppConfig(storage=StorageConfig(driver="sqlite", sqlite_dir=str(tmp_path))))
    app = create_app()
    # TestClient as a context manager runs the app's lifespan events.
    with TestClient(app) as test_client:
        yield test_client
    # Teardown: clear globals again so later tests are not polluted.
    _login_attempts.clear()
    reset_auth_config()
    reset_app_config()

View File

@ -0,0 +1,163 @@
"""End-to-end auth API tests for the main auth user journeys."""
from app.plugins.auth.security.csrf import CSRF_HEADER_NAME
def _initialize_payload(**overrides):
return {
"email": "admin@example.com",
"password": "Str0ng!Pass99",
**overrides,
}
def _register_payload(**overrides):
return {
"email": "user@example.com",
"password": "Str0ng!Pass99",
**overrides,
}
def _login(client, *, email="user@example.com", password="Str0ng!Pass99"):
return client.post(
"/api/v1/auth/login/local",
data={"username": email, "password": password},
)
def _csrf_headers(client) -> dict[str, str]:
    """Build the CSRF header from the client's ``csrf_token`` cookie."""
    csrf_value = client.cookies.get("csrf_token")
    assert csrf_value, "csrf_token cookie is required before calling protected POST endpoints"
    return {CSRF_HEADER_NAME: csrf_value}
def test_initialize_returns_admin_and_sets_session_cookie(client):
    # First-boot initialize creates the admin account and logs it in.
    response = client.post("/api/v1/auth/initialize", json=_initialize_payload())
    assert response.status_code == 201
    assert response.json()["email"] == "admin@example.com"
    assert response.json()["system_role"] == "admin"
    # Session cookie is set on the response AND persisted on the client jar.
    assert "access_token" in response.cookies
    assert "access_token" in client.cookies


def test_me_returns_initialized_admin_identity(client):
    initialize = client.post("/api/v1/auth/initialize", json=_initialize_payload())
    assert initialize.status_code == 201
    response = client.get("/api/v1/auth/me")
    assert response.status_code == 200
    # Exact-dict compare pins the full /me schema; id is dynamic, so echo it
    # back so the comparison only fails on missing/extra keys or wrong values.
    assert response.json() == {
        "id": response.json()["id"],
        "email": "admin@example.com",
        "system_role": "admin",
        "needs_setup": False,
    }
def test_setup_status_flips_after_initialize(client):
    # Before any admin exists the API reports that setup is required.
    before = client.get("/api/v1/auth/setup-status")
    assert before.status_code == 200
    assert before.json() == {"needs_setup": True}
    initialize = client.post("/api/v1/auth/initialize", json=_initialize_payload())
    assert initialize.status_code == 201
    # After initialize, setup-status must flip to False.
    after = client.get("/api/v1/auth/setup-status")
    assert after.status_code == 200
    assert after.json() == {"needs_setup": False}


def test_register_logs_in_user_and_me_returns_identity(client):
    response = client.post("/api/v1/auth/register", json=_register_payload())
    assert response.status_code == 201
    assert response.json()["email"] == "user@example.com"
    assert response.json()["system_role"] == "user"
    # Registration also logs the user in: session + CSRF cookies are set.
    assert "access_token" in client.cookies
    assert "csrf_token" in client.cookies
    me = client.get("/api/v1/auth/me")
    assert me.status_code == 200
    assert me.json()["email"] == "user@example.com"
    assert me.json()["system_role"] == "user"
    assert me.json()["needs_setup"] is False
def test_me_requires_authentication(client):
    # No session cookie → 401 with a machine-readable error code.
    response = client.get("/api/v1/auth/me")
    assert response.status_code == 401
    assert response.json()["detail"]["code"] == "not_authenticated"


def test_logout_clears_session_and_me_is_denied(client):
    register = client.post("/api/v1/auth/register", json=_register_payload())
    assert register.status_code == 201
    logout = client.post("/api/v1/auth/logout")
    assert logout.status_code == 200
    assert logout.json() == {"message": "Successfully logged out"}
    # The cleared cookie must no longer authenticate /me.
    me = client.get("/api/v1/auth/me")
    assert me.status_code == 401
    assert me.json()["detail"]["code"] == "not_authenticated"


def test_login_local_restores_session_after_logout(client):
    register = client.post("/api/v1/auth/register", json=_register_payload())
    assert register.status_code == 201
    assert client.post("/api/v1/auth/logout").status_code == 200
    login = _login(client)
    assert login.status_code == 200
    assert login.json()["needs_setup"] is False
    # Fresh session + CSRF cookies are issued on login.
    assert "access_token" in client.cookies
    assert "csrf_token" in client.cookies
    me = client.get("/api/v1/auth/me")
    assert me.status_code == 200
    assert me.json()["email"] == "user@example.com"
def test_change_password_updates_credentials_and_rotates_login(client):
    register = client.post("/api/v1/auth/register", json=_register_payload())
    assert register.status_code == 201
    # State-changing POST requires the CSRF header taken from the cookie.
    change = client.post(
        "/api/v1/auth/change-password",
        json={
            "current_password": "Str0ng!Pass99",
            "new_password": "An0ther!Pass88",
            "new_email": "renamed@example.com",
        },
        headers=_csrf_headers(client),
    )
    assert change.status_code == 200
    assert change.json() == {"message": "Password changed successfully"}
    assert client.post("/api/v1/auth/logout").status_code == 200
    # Old credentials are rejected after the change...
    old_login = _login(client)
    assert old_login.status_code == 401
    assert old_login.json()["detail"]["code"] == "invalid_credentials"
    # ...and the new email + password authenticate successfully.
    new_login = _login(client, email="renamed@example.com", password="An0ther!Pass88")
    assert new_login.status_code == 200
    me = client.get("/api/v1/auth/me")
    assert me.status_code == 200
    assert me.json()["email"] == "renamed@example.com"


def test_oauth_endpoints_expose_current_placeholder_behavior(client):
    # Unknown providers are rejected outright (400).
    unsupported = client.get("/api/v1/auth/oauth/not-a-provider")
    assert unsupported.status_code == 400
    # Known providers are recognised but not implemented yet (501).
    github = client.get("/api/v1/auth/oauth/github")
    assert github.status_code == 501
    callback = client.get("/api/v1/auth/callback/github", params={"code": "abc", "state": "xyz"})
    assert callback.status_code == 501

View File

@ -1,304 +0,0 @@
"""Unit tests for checkpointer config and singleton factory."""
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import deerflow.config.app_config as app_config_module
from deerflow.config.checkpointer_config import (
CheckpointerConfig,
get_checkpointer_config,
load_checkpointer_config_from_dict,
set_checkpointer_config,
)
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
@pytest.fixture(autouse=True)
def reset_state():
    """Reset singleton state before each test."""
    # NOTE(review): reaches into module-private _app_config; assumes no public
    # reset helper exists for it — confirm.
    app_config_module._app_config = None
    set_checkpointer_config(None)
    reset_checkpointer()
    yield
    # Leave the globals clean for whatever runs next.
    app_config_module._app_config = None
    set_checkpointer_config(None)
    reset_checkpointer()
# ---------------------------------------------------------------------------
# Config tests
# ---------------------------------------------------------------------------
class TestCheckpointerConfig:
    """Loading/setting the global checkpointer config from dicts."""

    def test_load_memory_config(self):
        load_checkpointer_config_from_dict({"type": "memory"})
        config = get_checkpointer_config()
        assert config is not None
        assert config.type == "memory"
        assert config.connection_string is None

    def test_load_sqlite_config(self):
        load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
        config = get_checkpointer_config()
        assert config is not None
        assert config.type == "sqlite"
        assert config.connection_string == "/tmp/test.db"

    def test_load_postgres_config(self):
        load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
        config = get_checkpointer_config()
        assert config is not None
        assert config.type == "postgres"
        assert config.connection_string == "postgresql://localhost/db"

    def test_default_connection_string_is_none(self):
        # connection_string is optional and defaults to None.
        config = CheckpointerConfig(type="memory")
        assert config.connection_string is None

    def test_set_config_to_none(self):
        load_checkpointer_config_from_dict({"type": "memory"})
        set_checkpointer_config(None)
        assert get_checkpointer_config() is None

    def test_invalid_type_raises(self):
        # Unknown backend names must be rejected at load time.
        with pytest.raises(Exception):
            load_checkpointer_config_from_dict({"type": "unknown"})
# ---------------------------------------------------------------------------
# Factory tests
# ---------------------------------------------------------------------------
class TestGetCheckpointer:
    """Behaviour of the get_checkpointer() singleton factory."""

    def test_returns_in_memory_saver_when_not_configured(self):
        """get_checkpointer should return InMemorySaver when not configured."""
        from langgraph.checkpoint.memory import InMemorySaver

        # App config entirely missing (FileNotFoundError) → in-memory fallback.
        with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
            cp = get_checkpointer()
            assert cp is not None
            assert isinstance(cp, InMemorySaver)

    def test_memory_returns_in_memory_saver(self):
        # type=memory maps to InMemorySaver.
        load_checkpointer_config_from_dict({"type": "memory"})
        from langgraph.checkpoint.memory import InMemorySaver

        cp = get_checkpointer()
        assert isinstance(cp, InMemorySaver)

    def test_memory_singleton(self):
        # Repeated calls return the same cached instance.
        load_checkpointer_config_from_dict({"type": "memory"})
        cp1 = get_checkpointer()
        cp2 = get_checkpointer()
        assert cp1 is cp2

    def test_reset_clears_singleton(self):
        # reset_checkpointer() forces a fresh instance on the next call.
        load_checkpointer_config_from_dict({"type": "memory"})
        cp1 = get_checkpointer()
        reset_checkpointer()
        cp2 = get_checkpointer()
        assert cp1 is not cp2

    def test_sqlite_raises_when_package_missing(self):
        # Mapping a module name to None in sys.modules makes `import` raise.
        load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
        with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
            reset_checkpointer()
            with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
                get_checkpointer()

    def test_postgres_raises_when_package_missing(self):
        load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
        with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}):
            reset_checkpointer()
            with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
                get_checkpointer()

    def test_postgres_raises_when_connection_string_missing(self):
        # Postgres without a connection string is a configuration error.
        load_checkpointer_config_from_dict({"type": "postgres"})
        mock_saver = MagicMock()
        mock_module = MagicMock()
        mock_module.PostgresSaver = mock_saver
        with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}):
            reset_checkpointer()
            with pytest.raises(ValueError, match="connection_string is required"):
                get_checkpointer()

    def test_sqlite_creates_saver(self):
        """SQLite checkpointer is created when package is available."""
        load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
        mock_saver_instance = MagicMock()
        # from_conn_string returns a context manager yielding the saver.
        mock_cm = MagicMock()
        mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
        mock_cm.__exit__ = MagicMock(return_value=False)
        mock_saver_cls = MagicMock()
        mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
        mock_module = MagicMock()
        mock_module.SqliteSaver = mock_saver_cls
        with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}):
            reset_checkpointer()
            cp = get_checkpointer()
            assert cp is mock_saver_instance
            mock_saver_cls.from_conn_string.assert_called_once()
            mock_saver_instance.setup.assert_called_once()

    def test_postgres_creates_saver(self):
        """Postgres checkpointer is created when packages are available."""
        load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
        mock_saver_instance = MagicMock()
        mock_cm = MagicMock()
        mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
        mock_cm.__exit__ = MagicMock(return_value=False)
        mock_saver_cls = MagicMock()
        mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
        mock_pg_module = MagicMock()
        mock_pg_module.PostgresSaver = mock_saver_cls
        with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}):
            reset_checkpointer()
            cp = get_checkpointer()
            assert cp is mock_saver_instance
            mock_saver_cls.from_conn_string.assert_called_once_with("postgresql://localhost/db")
            mock_saver_instance.setup.assert_called_once()
class TestAsyncCheckpointer:
    """Async provider behaviour (make_checkpointer)."""

    @pytest.mark.anyio
    async def test_sqlite_creates_parent_dir_via_to_thread(self):
        """Async SQLite setup should move mkdir off the event loop."""
        from deerflow.runtime.checkpointer.async_provider import make_checkpointer

        mock_config = MagicMock()
        mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db")
        # from_conn_string returns an async context manager yielding the saver.
        mock_saver = AsyncMock()
        mock_cm = AsyncMock()
        mock_cm.__aenter__.return_value = mock_saver
        mock_cm.__aexit__.return_value = False
        mock_saver_cls = MagicMock()
        mock_saver_cls.from_conn_string.return_value = mock_cm
        mock_module = MagicMock()
        mock_module.AsyncSqliteSaver = mock_saver_cls
        with (
            patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
            patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
            patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
            patch(
                "deerflow.runtime.checkpointer.async_provider.resolve_sqlite_conn_str",
                return_value="/tmp/resolved/test.db",
            ),
        ):
            async with make_checkpointer() as saver:
                assert saver is mock_saver
            # Parent-dir creation must be dispatched via asyncio.to_thread
            # with the *resolved* path.
            mock_to_thread.assert_awaited_once()
            called_fn, called_path = mock_to_thread.await_args.args
            assert called_fn.__name__ == "ensure_sqlite_parent_dir"
            assert called_path == "/tmp/resolved/test.db"
            mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db")
            mock_saver.setup.assert_awaited_once()
# ---------------------------------------------------------------------------
# app_config.py integration
# ---------------------------------------------------------------------------
class TestAppConfigLoadsCheckpointer:
    """app_config integration: the checkpointer section feeds the global."""

    def test_load_checkpointer_section(self):
        """load_checkpointer_config_from_dict populates the global config."""
        set_checkpointer_config(None)
        load_checkpointer_config_from_dict({"type": "memory"})
        cfg = get_checkpointer_config()
        assert cfg is not None
        assert cfg.type == "memory"
# ---------------------------------------------------------------------------
# DeerFlowClient falls back to config checkpointer
# ---------------------------------------------------------------------------
class TestClientCheckpointerFallback:
    """DeerFlowClient falls back to the config-driven checkpointer."""

    def test_client_uses_config_checkpointer_when_none_provided(self):
        """DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
        from langgraph.checkpoint.memory import InMemorySaver

        from deerflow.client import DeerFlowClient

        load_checkpointer_config_from_dict({"type": "memory"})
        # Capture the kwargs create_agent receives so we can inspect them.
        captured_kwargs = {}

        def fake_create_agent(**kwargs):
            captured_kwargs.update(kwargs)
            return MagicMock()

        model_mock = MagicMock()
        config_mock = MagicMock()
        config_mock.models = [model_mock]
        config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
        config_mock.checkpointer = None
        with (
            patch("deerflow.client.get_app_config", return_value=config_mock),
            patch("deerflow.client.create_agent", side_effect=fake_create_agent),
            patch("deerflow.client.create_chat_model", return_value=MagicMock()),
            patch("deerflow.client._build_middlewares", return_value=[]),
            patch("deerflow.client.apply_prompt_template", return_value=""),
            patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
        ):
            client = DeerFlowClient(checkpointer=None)
            config = client._get_runnable_config("test-thread")
            client._ensure_agent(config)
            # With no explicit checkpointer, the config-driven InMemorySaver
            # must be what reaches create_agent.
            assert "checkpointer" in captured_kwargs
            assert isinstance(captured_kwargs["checkpointer"], InMemorySaver)

    def test_client_explicit_checkpointer_takes_precedence(self):
        """An explicitly provided checkpointer is used even when config checkpointer is set."""
        from deerflow.client import DeerFlowClient

        load_checkpointer_config_from_dict({"type": "memory"})
        explicit_cp = MagicMock()
        captured_kwargs = {}

        def fake_create_agent(**kwargs):
            captured_kwargs.update(kwargs)
            return MagicMock()

        model_mock = MagicMock()
        config_mock = MagicMock()
        config_mock.models = [model_mock]
        config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
        config_mock.checkpointer = None
        with (
            patch("deerflow.client.get_app_config", return_value=config_mock),
            patch("deerflow.client.create_agent", side_effect=fake_create_agent),
            patch("deerflow.client.create_chat_model", return_value=MagicMock()),
            patch("deerflow.client._build_middlewares", return_value=[]),
            patch("deerflow.client.apply_prompt_template", return_value=""),
            patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
        ):
            client = DeerFlowClient(checkpointer=explicit_cp)
            config = client._get_runnable_config("test-thread")
            client._ensure_agent(config)
            assert captured_kwargs["checkpointer"] is explicit_cp

View File

@ -1,55 +0,0 @@
"""Test for issue #1016: checkpointer should not return None."""
from unittest.mock import MagicMock, patch
import pytest
from langgraph.checkpoint.memory import InMemorySaver
class TestCheckpointerNoneFix:
    """Tests that checkpointer context managers return InMemorySaver instead of None."""

    @pytest.mark.anyio
    async def test_async_make_checkpointer_returns_in_memory_saver_when_not_configured(self):
        """make_checkpointer should return InMemorySaver when config.checkpointer is None."""
        from deerflow.runtime.checkpointer.async_provider import make_checkpointer

        # Mock get_app_config to return a config with checkpointer=None and database=None
        mock_config = MagicMock()
        mock_config.checkpointer = None
        mock_config.database = None
        with patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config):
            async with make_checkpointer() as checkpointer:
                # Should return InMemorySaver, not None
                assert checkpointer is not None
                assert isinstance(checkpointer, InMemorySaver)
                # Should be able to call alist() without AttributeError
                # This is what LangGraph does and what was failing in issue #1016
                result = []
                async for item in checkpointer.alist(config={"configurable": {"thread_id": "test"}}):
                    result.append(item)
                # Empty list is expected for a fresh checkpointer
                assert result == []

    def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self):
        """checkpointer_context should return InMemorySaver when config.checkpointer is None."""
        from deerflow.runtime.checkpointer.provider import checkpointer_context

        # Mock get_app_config to return a config with checkpointer=None
        mock_config = MagicMock()
        mock_config.checkpointer = None
        with patch("deerflow.runtime.checkpointer.provider.get_app_config", return_value=mock_config):
            with checkpointer_context() as checkpointer:
                # Should return InMemorySaver, not None
                assert checkpointer is not None
                assert isinstance(checkpointer, InMemorySaver)
                # Should be able to call list() without AttributeError
                result = list(checkpointer.list(config={"configurable": {"thread_id": "test"}}))
                # Empty list is expected for a fresh checkpointer
                assert result == []

View File

@ -1,296 +0,0 @@
"""Tests for _ensure_admin_user() in app.py.
Covers: first-boot no-op (admin creation removed), orphan migration
when admin exists, no-op on no admin found, and edge cases.
"""
import asyncio
import os
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
from app.gateway.auth.config import AuthConfig, set_auth_config
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
@pytest.fixture(autouse=True)
def _setup_auth_config():
    """Install a deterministic AuthConfig for every test in this module."""
    set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
    yield
    # Re-install the same config on teardown so later tests see known state.
    set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
def _make_app_stub(store=None):
"""Minimal app-like object with state.store."""
app = SimpleNamespace()
app.state = SimpleNamespace()
app.state.store = store
return app
def _make_provider(admin_count=0):
p = AsyncMock()
p.count_users = AsyncMock(return_value=admin_count)
p.count_admin_users = AsyncMock(return_value=admin_count)
p.create_user = AsyncMock()
p.update_user = AsyncMock(side_effect=lambda u: u)
return p
def _make_session_factory(admin_row=None):
"""Build a mock async session factory that returns a row from execute()."""
row_result = MagicMock()
row_result.scalar_one_or_none.return_value = admin_row
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = admin_row
session = AsyncMock()
session.execute = AsyncMock(return_value=execute_result)
# Async context manager
session_cm = AsyncMock()
session_cm.__aenter__ = AsyncMock(return_value=session)
session_cm.__aexit__ = AsyncMock(return_value=False)
sf = MagicMock()
sf.return_value = session_cm
return sf
# ── First boot: no admin → return early ──────────────────────────────────
def test_first_boot_does_not_create_admin():
    """admin_count==0 → do NOT create admin automatically."""
    provider = _make_provider(admin_count=0)
    app = _make_app_stub()
    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(app))
    provider.create_user.assert_not_called()


def test_first_boot_skips_migration():
    """No admin → return early before any migration attempt."""
    provider = _make_provider(admin_count=0)
    store = AsyncMock()
    store.asearch = AsyncMock(return_value=[])
    app = _make_app_stub(store=store)
    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(app))
    # Early return means the store is never even queried.
    store.asearch.assert_not_called()
# ── Admin exists: migration runs when admin row found ────────────────────
def test_admin_exists_triggers_migration():
    """Admin exists and admin row found → _migrate_orphaned_threads called."""
    from uuid import uuid4

    admin_row = MagicMock()
    admin_row.id = uuid4()
    provider = _make_provider(admin_count=1)
    sf = _make_session_factory(admin_row=admin_row)
    store = AsyncMock()
    store.asearch = AsyncMock(return_value=[])
    app = _make_app_stub(store=store)
    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))
    # Migration walked the store at least once.
    store.asearch.assert_called_once()


def test_admin_exists_no_admin_row_skips_migration():
    """Admin count > 0 but DB row missing (edge case) → skip migration gracefully."""
    provider = _make_provider(admin_count=2)
    sf = _make_session_factory(admin_row=None)
    store = AsyncMock()
    app = _make_app_stub(store=store)
    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))
    store.asearch.assert_not_called()
def test_admin_exists_no_store_skips_migration():
    """Admin exists, row found, but no store → no crash, no migration."""
    from uuid import uuid4

    admin_row = MagicMock()
    admin_row.id = uuid4()
    provider = _make_provider(admin_count=1)
    sf = _make_session_factory(admin_row=admin_row)
    app = _make_app_stub(store=None)
    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))
    # No assertion needed — just verify no crash


def test_admin_exists_session_factory_none_skips_migration():
    """get_session_factory() returns None → return early, no crash."""
    provider = _make_provider(admin_count=1)
    store = AsyncMock()
    app = _make_app_stub(store=store)
    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("deerflow.persistence.engine.get_session_factory", return_value=None):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))
    store.asearch.assert_not_called()


def test_migration_failure_is_non_fatal():
    """_migrate_orphaned_threads exception is caught and logged."""
    from uuid import uuid4

    admin_row = MagicMock()
    admin_row.id = uuid4()
    provider = _make_provider(admin_count=1)
    sf = _make_session_factory(admin_row=admin_row)
    store = AsyncMock()
    # The store blowing up mid-migration must not take down startup.
    store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
    app = _make_app_stub(store=store)
    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
            from app.gateway.app import _ensure_admin_user

            # Should not raise
            asyncio.run(_ensure_admin_user(app))
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
    """First boot finds Store-only legacy threads → stamps admin's id.

    Validates the **TC-UPG-02 upgrade story**: an operator running main
    (no auth) accumulates threads in the LangGraph Store namespace
    ``("threads",)`` with no ``metadata.user_id``. After upgrading to
    feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
    rewrite each unowned item with the freshly created admin's id.
    """
    from app.gateway.app import _migrate_orphaned_threads

    # Three orphan items + one already-owned item that should be left alone.
    items = [
        SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
        SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
        SimpleNamespace(key="t3", value={"metadata": {}}),
        SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}}),
    ]
    store = AsyncMock()
    # asearch returns the entire batch on first call, then an empty page
    # to terminate _iter_store_items.
    store.asearch = AsyncMock(side_effect=[items, []])
    aput_calls: list[tuple[tuple, str, dict]] = []

    async def _record_aput(namespace, key, value):
        aput_calls.append((namespace, key, value))

    store.aput = AsyncMock(side_effect=_record_aput)
    migrated = asyncio.run(_migrate_orphaned_threads(store, "admin-id-42"))
    # Three orphan rows migrated, one preserved.
    assert migrated == 3
    assert len(aput_calls) == 3
    rewritten_keys = {call[1] for call in aput_calls}
    assert rewritten_keys == {"t1", "t2", "t3"}
    # Each rewrite carries the new user_id; titles preserved where present.
    by_key = {call[1]: call[2] for call in aput_calls}
    assert by_key["t1"]["metadata"]["user_id"] == "admin-id-42"
    assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
    assert by_key["t3"]["metadata"]["user_id"] == "admin-id-42"
    # The pre-owned item must NOT have been rewritten.
    assert "t4" not in rewritten_keys
def test_migrate_orphaned_threads_empty_store_is_noop():
    """A store with no threads → migrated == 0, no aput calls."""
    from app.gateway.app import _migrate_orphaned_threads

    store = AsyncMock()
    store.asearch = AsyncMock(return_value=[])
    store.aput = AsyncMock()
    migrated = asyncio.run(_migrate_orphaned_threads(store, "admin-id-42"))
    assert migrated == 0
    store.aput.assert_not_called()
def test_iter_store_items_walks_multiple_pages():
"""Cursor-style iterator pulls every page until a short page terminates.
Closes the regression where the old hardcoded ``limit=1000`` could
silently drop orphans on a large pre-upgrade dataset. The migration
code path uses the default ``page_size=500``; this test pins the
iterator with ``page_size=2`` so it stays fast.
"""
from app.gateway.app import _iter_store_items
page_a = [SimpleNamespace(key=f"t{i}", value={"metadata": {}}) for i in range(2)]
page_b = [SimpleNamespace(key=f"t{i + 2}", value={"metadata": {}}) for i in range(2)]
page_c: list = [] # short page → loop terminates
store = AsyncMock()
store.asearch = AsyncMock(side_effect=[page_a, page_b, page_c])
async def _collect():
return [item.key async for item in _iter_store_items(store, ("threads",), page_size=2)]
keys = asyncio.run(_collect())
assert keys == ["t0", "t1", "t2", "t3"]
# Three asearch calls: full batch, full batch, empty terminator
assert store.asearch.await_count == 3
def test_iter_store_items_terminates_on_short_page():
    """A short page (len < page_size) ends the loop without an extra call."""
    from app.gateway.app import _iter_store_items

    short_page = [SimpleNamespace(key=f"t{n}", value={}) for n in range(3)]
    store = AsyncMock()
    store.asearch = AsyncMock(return_value=short_page)

    async def _drain():
        return [item.key async for item in _iter_store_items(store, ("threads",), page_size=10)]

    assert asyncio.run(_drain()) == ["t0", "t1", "t2"]
    # 3 < page_size=10, so a single asearch call suffices — no terminator probe.
    assert store.asearch.await_count == 1

View File

@ -1,289 +0,0 @@
"""Tests for FeedbackRepository and follow-up association.
Uses temp SQLite DB for ORM tests.
"""
import pytest
from deerflow.persistence.feedback import FeedbackRepository
async def _make_feedback_repo(tmp_path):
    """Create a FeedbackRepository backed by a fresh SQLite file in *tmp_path*.

    Initializes the shared async engine as a side effect; callers must await
    ``_cleanup()`` at the end of the test so the next test gets a clean engine.
    """
    from deerflow.persistence.engine import get_session_factory, init_engine
    url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
    await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
    return FeedbackRepository(get_session_factory())
async def _cleanup():
    """Close the shared engine opened by ``_make_feedback_repo``."""
    from deerflow.persistence.engine import close_engine
    await close_engine()
# -- FeedbackRepository --
class TestFeedbackRepository:
    """CRUD, validation, upsert, and aggregation coverage for FeedbackRepository.

    Each test builds its own SQLite-backed repository via ``_make_feedback_repo``
    and closes the engine with ``_cleanup()`` as its final statement.

    NOTE(review): ``await _cleanup()`` runs only on the success path — a failing
    assertion skips it and leaves the engine open for the next test. Consider a
    fixture or try/finally if cross-test engine state ever causes flakiness.
    """
    @pytest.mark.anyio
    async def test_create_positive(self, tmp_path):
        # A rating of +1 round-trips with ids and a server-set timestamp.
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1)
        assert record["feedback_id"]
        assert record["rating"] == 1
        assert record["run_id"] == "r1"
        assert record["thread_id"] == "t1"
        assert "created_at" in record
        await _cleanup()
    @pytest.mark.anyio
    async def test_create_negative_with_comment(self, tmp_path):
        # Negative feedback may carry an explanatory free-text comment.
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(
            run_id="r1",
            thread_id="t1",
            rating=-1,
            comment="Response was inaccurate",
        )
        assert record["rating"] == -1
        assert record["comment"] == "Response was inaccurate"
        await _cleanup()
    @pytest.mark.anyio
    async def test_create_with_message_id(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1, message_id="msg-42")
        assert record["message_id"] == "msg-42"
        await _cleanup()
    @pytest.mark.anyio
    async def test_create_with_owner(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
        assert record["user_id"] == "user-1"
        await _cleanup()
    @pytest.mark.anyio
    async def test_create_invalid_rating_zero(self, tmp_path):
        # Only +1 / -1 ratings are accepted; 0 is rejected.
        repo = await _make_feedback_repo(tmp_path)
        with pytest.raises(ValueError):
            await repo.create(run_id="r1", thread_id="t1", rating=0)
        await _cleanup()
    @pytest.mark.anyio
    async def test_create_invalid_rating_five(self, tmp_path):
        # Out-of-range positive values are rejected too.
        repo = await _make_feedback_repo(tmp_path)
        with pytest.raises(ValueError):
            await repo.create(run_id="r1", thread_id="t1", rating=5)
        await _cleanup()
    @pytest.mark.anyio
    async def test_get(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        created = await repo.create(run_id="r1", thread_id="t1", rating=1)
        fetched = await repo.get(created["feedback_id"])
        assert fetched is not None
        assert fetched["feedback_id"] == created["feedback_id"]
        assert fetched["rating"] == 1
        await _cleanup()
    @pytest.mark.anyio
    async def test_get_nonexistent(self, tmp_path):
        # Missing ids return None rather than raising.
        repo = await _make_feedback_repo(tmp_path)
        assert await repo.get("nonexistent") is None
        await _cleanup()
    @pytest.mark.anyio
    async def test_list_by_run(self, tmp_path):
        # user_id=None lists feedback from all users for the given run.
        repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
        await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2")
        await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1")
        results = await repo.list_by_run("t1", "r1", user_id=None)
        assert len(results) == 2
        assert all(r["run_id"] == "r1" for r in results)
        await _cleanup()
    @pytest.mark.anyio
    async def test_list_by_thread(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1)
        await repo.create(run_id="r2", thread_id="t1", rating=-1)
        await repo.create(run_id="r3", thread_id="t2", rating=1)
        results = await repo.list_by_thread("t1")
        assert len(results) == 2
        assert all(r["thread_id"] == "t1" for r in results)
        await _cleanup()
    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        created = await repo.create(run_id="r1", thread_id="t1", rating=1)
        deleted = await repo.delete(created["feedback_id"])
        assert deleted is True
        assert await repo.get(created["feedback_id"]) is None
        await _cleanup()
    @pytest.mark.anyio
    async def test_delete_nonexistent(self, tmp_path):
        # Deleting a missing row reports False instead of raising.
        repo = await _make_feedback_repo(tmp_path)
        deleted = await repo.delete("nonexistent")
        assert deleted is False
        await _cleanup()
    @pytest.mark.anyio
    async def test_aggregate_by_run(self, tmp_path):
        # Aggregation counts totals plus positive/negative splits per run.
        repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
        await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2")
        await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3")
        stats = await repo.aggregate_by_run("t1", "r1")
        assert stats["total"] == 3
        assert stats["positive"] == 2
        assert stats["negative"] == 1
        assert stats["run_id"] == "r1"
        await _cleanup()
    @pytest.mark.anyio
    async def test_aggregate_empty(self, tmp_path):
        # Aggregating a run with no feedback yields all-zero counters.
        repo = await _make_feedback_repo(tmp_path)
        stats = await repo.aggregate_by_run("t1", "r1")
        assert stats["total"] == 0
        assert stats["positive"] == 0
        assert stats["negative"] == 0
        await _cleanup()
    @pytest.mark.anyio
    async def test_upsert_creates_new(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
        assert record["rating"] == 1
        assert record["feedback_id"]
        assert record["user_id"] == "u1"
        await _cleanup()
    @pytest.mark.anyio
    async def test_upsert_updates_existing(self, tmp_path):
        # Same (run, thread, user) key updates in place — feedback_id is stable.
        repo = await _make_feedback_repo(tmp_path)
        first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
        second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind")
        assert second["feedback_id"] == first["feedback_id"]
        assert second["rating"] == -1
        assert second["comment"] == "changed my mind"
        await _cleanup()
    @pytest.mark.anyio
    async def test_upsert_different_users_separate(self, tmp_path):
        # Distinct users get distinct rows for the same run.
        repo = await _make_feedback_repo(tmp_path)
        r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
        r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2")
        assert r1["feedback_id"] != r2["feedback_id"]
        assert r1["rating"] == 1
        assert r2["rating"] == -1
        await _cleanup()
    @pytest.mark.anyio
    async def test_upsert_invalid_rating(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        with pytest.raises(ValueError):
            await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1")
        await _cleanup()
    @pytest.mark.anyio
    async def test_delete_by_run(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
        deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
        assert deleted is True
        results = await repo.list_by_run("t1", "r1", user_id="u1")
        assert len(results) == 0
        await _cleanup()
    @pytest.mark.anyio
    async def test_delete_by_run_nonexistent(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
        assert deleted is False
        await _cleanup()
    @pytest.mark.anyio
    async def test_list_by_thread_grouped(self, tmp_path):
        # Grouped listing keys feedback by run_id, restricted to one user.
        repo = await _make_feedback_repo(tmp_path)
        await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
        await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1")
        await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1")
        grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
        assert "r1" in grouped
        assert "r2" in grouped
        assert "r3" not in grouped
        assert grouped["r1"]["rating"] == 1
        assert grouped["r2"]["rating"] == -1
        await _cleanup()
    @pytest.mark.anyio
    async def test_list_by_thread_grouped_empty(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
        assert grouped == {}
        await _cleanup()
# -- Follow-up association --
class TestFollowUpAssociation:
    """Follow-up run linkage via in-memory stores (no SQL engine needed)."""
    @pytest.mark.anyio
    async def test_run_records_follow_up_via_memory_store(self):
        """MemoryRunStore stores follow_up_to_run_id in kwargs."""
        from deerflow.runtime.runs.store.memory import MemoryRunStore
        store = MemoryRunStore()
        await store.put("r1", thread_id="t1", status="success")
        # MemoryRunStore doesn't have follow_up_to_run_id as a top-level param,
        # but it can be passed via metadata
        await store.put("r2", thread_id="t1", metadata={"follow_up_to_run_id": "r1"})
        run = await store.get("r2")
        assert run["metadata"]["follow_up_to_run_id"] == "r1"
    @pytest.mark.anyio
    async def test_human_message_has_follow_up_metadata(self):
        """human_message event metadata includes follow_up_to_run_id."""
        from deerflow.runtime.events.store.memory import MemoryRunEventStore
        event_store = MemoryRunEventStore()
        await event_store.put(
            thread_id="t1",
            run_id="r2",
            event_type="human_message",
            category="message",
            content="Tell me more about that",
            metadata={"follow_up_to_run_id": "r1"},
        )
        messages = await event_store.list_messages("t1")
        # The metadata survives the store round-trip on the listed message.
        assert messages[0]["metadata"]["follow_up_to_run_id"] == "r1"
    @pytest.mark.anyio
    async def test_follow_up_auto_detection_logic(self):
        """Simulate the auto-detection: latest successful run becomes follow_up_to."""
        from deerflow.runtime.runs.store.memory import MemoryRunStore
        store = MemoryRunStore()
        await store.put("r1", thread_id="t1", status="success")
        await store.put("r2", thread_id="t1", status="error")
        # Auto-detect: list_by_thread returns newest first
        recent = await store.list_by_thread("t1", limit=1)
        follow_up = None
        if recent and recent[0].get("status") == "success":
            follow_up = recent[0]["run_id"]
        # r2 (error) is newest, so no follow_up detected
        assert follow_up is None
        # Now add a successful run
        await store.put("r3", thread_id="t1", status="success")
        recent = await store.list_by_thread("t1", limit=1)
        follow_up = None
        if recent and recent[0].get("status") == "success":
            follow_up = recent[0]["run_id"]
        assert follow_up == "r3"

View File

@ -1,342 +0,0 @@
"""Tests for app.gateway.services — run lifecycle service layer."""
from __future__ import annotations
import json
def test_format_sse_basic():
    """A minimal SSE frame carries the event name and a JSON-encoded payload."""
    from app.gateway.services import format_sse

    frame = format_sse("metadata", {"run_id": "abc"})

    assert frame.startswith("event: metadata\n")
    assert "data: " in frame
    payload_line = frame.split("data: ")[1].split("\n")[0]
    assert json.loads(payload_line)["run_id"] == "abc"
def test_format_sse_with_event_id():
    """An explicit event_id is emitted as an ``id:`` field in the frame."""
    from app.gateway.services import format_sse
    frame = format_sse("metadata", {"run_id": "abc"}, event_id="123-0")
    assert "id: 123-0" in frame
def test_format_sse_end_event_null():
    """A None payload serializes as JSON ``null`` in the data field."""
    from app.gateway.services import format_sse
    frame = format_sse("end", None)
    assert "data: null" in frame
def test_format_sse_no_event_id():
    """Without event_id, no ``id:`` field appears in the frame."""
    from app.gateway.services import format_sse
    frame = format_sse("values", {"x": 1})
    assert "id:" not in frame
def test_normalize_stream_modes_none():
    """None falls back to the default ``["values"]`` stream mode."""
    from app.gateway.services import normalize_stream_modes
    assert normalize_stream_modes(None) == ["values"]
def test_normalize_stream_modes_string():
    """A bare string mode is wrapped into a single-element list."""
    from app.gateway.services import normalize_stream_modes
    assert normalize_stream_modes("messages-tuple") == ["messages-tuple"]
def test_normalize_stream_modes_list():
    """A list of modes passes through unchanged."""
    from app.gateway.services import normalize_stream_modes
    assert normalize_stream_modes(["values", "messages-tuple"]) == ["values", "messages-tuple"]
def test_normalize_stream_modes_empty_list():
    """An empty list also falls back to the default ``["values"]``."""
    from app.gateway.services import normalize_stream_modes
    assert normalize_stream_modes([]) == ["values"]
def test_normalize_input_none():
    """None input normalizes to an empty dict."""
    from app.gateway.services import normalize_input
    assert normalize_input(None) == {}
def test_normalize_input_with_messages():
    """Dict messages are coerced into message objects exposing ``.content``."""
    from app.gateway.services import normalize_input
    result = normalize_input({"messages": [{"role": "user", "content": "hi"}]})
    assert len(result["messages"]) == 1
    # Attribute (not key) access proves the dict became a message object.
    assert result["messages"][0].content == "hi"
def test_normalize_input_passthrough():
    """Keys other than ``messages`` pass through untouched."""
    from app.gateway.services import normalize_input
    result = normalize_input({"custom_key": "value"})
    assert result == {"custom_key": "value"}
def test_build_run_config_basic():
    """With no overrides, the config carries thread_id and the default recursion limit."""
    from app.gateway.services import build_run_config
    config = build_run_config("thread-1", None, None)
    assert config["configurable"]["thread_id"] == "thread-1"
    assert config["recursion_limit"] == 100
def test_build_run_config_with_overrides():
    """Request config (configurable, tags) and metadata are merged into the result."""
    from app.gateway.services import build_run_config
    config = build_run_config(
        "thread-1",
        {"configurable": {"model_name": "gpt-4"}, "tags": ["test"]},
        {"user": "alice"},
    )
    assert config["configurable"]["model_name"] == "gpt-4"
    assert config["tags"] == ["test"]
    assert config["metadata"]["user"] == "alice"
# ---------------------------------------------------------------------------
# Regression tests for issue #1644:
# assistant_id not mapped to agent_name → custom agent SOUL.md never loaded
# ---------------------------------------------------------------------------
def test_build_run_config_custom_agent_injects_agent_name():
    """Custom assistant_id must be forwarded as configurable['agent_name']."""
    from app.gateway.services import build_run_config
    config = build_run_config("thread-1", None, None, assistant_id="finalis")
    # Regression for issue #1644: without this mapping a custom agent's
    # SOUL.md was never loaded.
    assert config["configurable"]["agent_name"] == "finalis"
def test_build_run_config_lead_agent_no_agent_name():
    """'lead_agent' assistant_id must NOT inject configurable['agent_name']."""
    from app.gateway.services import build_run_config
    config = build_run_config("thread-1", None, None, assistant_id="lead_agent")
    # The default assistant is the lead agent — no override key is needed.
    assert "agent_name" not in config["configurable"]
def test_build_run_config_none_assistant_id_no_agent_name():
    """None assistant_id must NOT inject configurable['agent_name']."""
    from app.gateway.services import build_run_config
    config = build_run_config("thread-1", None, None, assistant_id=None)
    assert "agent_name" not in config["configurable"]
def test_build_run_config_explicit_agent_name_not_overwritten():
    """An explicit configurable['agent_name'] in the request must take precedence."""
    from app.gateway.services import build_run_config
    config = build_run_config(
        "thread-1",
        {"configurable": {"agent_name": "explicit-agent"}},
        None,
        assistant_id="other-agent",
    )
    # Caller-provided configurable wins over the assistant_id-derived value.
    assert config["configurable"]["agent_name"] == "explicit-agent"
def test_resolve_agent_factory_returns_make_lead_agent():
    """resolve_agent_factory always returns make_lead_agent regardless of assistant_id."""
    from app.gateway.services import resolve_agent_factory
    from deerflow.agents.lead_agent.agent import make_lead_agent

    # Identity check across every representative assistant_id shape.
    for assistant_id in (None, "lead_agent", "finalis", "custom-agent-123"):
        assert resolve_agent_factory(assistant_id) is make_lead_agent
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Regression tests for issue #1699:
# context field in langgraph-compat requests not merged into configurable
# ---------------------------------------------------------------------------
def test_run_create_request_accepts_context():
    """RunCreateRequest must accept the ``context`` field without dropping it."""
    from app.gateway.routers.thread_runs import RunCreateRequest
    # Regression for issue #1699: the context field was silently discarded.
    body = RunCreateRequest(
        input={"messages": [{"role": "user", "content": "hi"}]},
        context={
            "model_name": "deepseek-v3",
            "thinking_enabled": True,
            "is_plan_mode": True,
            "subagent_enabled": True,
            "thread_id": "some-thread-id",
        },
    )
    assert body.context is not None
    assert body.context["model_name"] == "deepseek-v3"
    assert body.context["is_plan_mode"] is True
    assert body.context["subagent_enabled"] is True
def test_run_create_request_context_defaults_to_none():
    """RunCreateRequest without context should default to None (backward compat)."""
    from app.gateway.routers.thread_runs import RunCreateRequest
    body = RunCreateRequest(input=None)
    assert body.context is None
def test_context_merges_into_configurable():
    """Context values must be merged into config['configurable'] by start_run.
    Since start_run is async and requires many dependencies, we test the
    merging logic directly by simulating what start_run does.
    """
    from app.gateway.services import build_run_config
    # Simulate the context merging logic from start_run
    config = build_run_config("thread-1", None, None)
    context = {
        "model_name": "deepseek-v3",
        "mode": "ultra",
        "reasoning_effort": "high",
        "thinking_enabled": True,
        "is_plan_mode": True,
        "subagent_enabled": True,
        "max_concurrent_subagents": 5,
        "thread_id": "should-be-ignored",
    }
    # Allowlist mirrored from start_run: only these context keys may flow
    # into configurable; thread_id is deliberately absent.
    _CONTEXT_CONFIGURABLE_KEYS = {
        "model_name",
        "mode",
        "thinking_enabled",
        "reasoning_effort",
        "is_plan_mode",
        "subagent_enabled",
        "max_concurrent_subagents",
    }
    configurable = config.setdefault("configurable", {})
    for key in _CONTEXT_CONFIGURABLE_KEYS:
        if key in context:
            # setdefault → context never overrides pre-existing configurable values.
            configurable.setdefault(key, context[key])
    assert config["configurable"]["model_name"] == "deepseek-v3"
    assert config["configurable"]["thinking_enabled"] is True
    assert config["configurable"]["is_plan_mode"] is True
    assert config["configurable"]["subagent_enabled"] is True
    assert config["configurable"]["max_concurrent_subagents"] == 5
    assert config["configurable"]["reasoning_effort"] == "high"
    assert config["configurable"]["mode"] == "ultra"
    # thread_id from context should NOT override the one from build_run_config
    assert config["configurable"]["thread_id"] == "thread-1"
    # Non-allowlisted keys should not appear
    # NOTE(review): this assertion inspects only the allowlist itself (it is
    # equivalent to `"thread_id" not in _CONTEXT_CONFIGURABLE_KEYS`); it does
    # not examine config — consider asserting on config directly.
    assert "thread_id" not in {k for k in context if k in _CONTEXT_CONFIGURABLE_KEYS}
def test_context_does_not_override_existing_configurable():
    """Values already in config.configurable must NOT be overridden by context."""
    from app.gateway.services import build_run_config
    config = build_run_config(
        "thread-1",
        {"configurable": {"model_name": "gpt-4", "is_plan_mode": False}},
        None,
    )
    context = {
        "model_name": "deepseek-v3",
        "is_plan_mode": True,
        "subagent_enabled": True,
    }
    # Same allowlist as start_run's context → configurable merge.
    _CONTEXT_CONFIGURABLE_KEYS = {
        "model_name",
        "mode",
        "thinking_enabled",
        "reasoning_effort",
        "is_plan_mode",
        "subagent_enabled",
        "max_concurrent_subagents",
    }
    configurable = config.setdefault("configurable", {})
    for key in _CONTEXT_CONFIGURABLE_KEYS:
        if key in context:
            configurable.setdefault(key, context[key])
    # Existing values must NOT be overridden
    assert config["configurable"]["model_name"] == "gpt-4"
    assert config["configurable"]["is_plan_mode"] is False
    # New values should be added
    assert config["configurable"]["subagent_enabled"] is True
# ---------------------------------------------------------------------------
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
# ---------------------------------------------------------------------------
def test_build_run_config_with_context():
    """When caller sends 'context', prefer it over 'configurable'."""
    from app.gateway.services import build_run_config
    config = build_run_config(
        "thread-1",
        {"context": {"user_id": "u-42", "thread_id": "thread-1"}},
        None,
    )
    # LangGraph >= 0.6.0 style: 'context' replaces 'configurable' entirely.
    assert "context" in config
    assert config["context"]["user_id"] == "u-42"
    assert "configurable" not in config
    assert config["recursion_limit"] == 100
def test_build_run_config_context_plus_configurable_warns(caplog):
    """When caller sends both 'context' and 'configurable', prefer 'context' and log a warning."""
    import logging
    from app.gateway.services import build_run_config
    with caplog.at_level(logging.WARNING, logger="app.gateway.services"):
        config = build_run_config(
            "thread-1",
            {
                "context": {"user_id": "u-42"},
                "configurable": {"model_name": "gpt-4"},
            },
            None,
        )
    assert "context" in config
    assert config["context"]["user_id"] == "u-42"
    assert "configurable" not in config
    # The conflict must be surfaced to operators via a WARNING record.
    assert any("both 'context' and 'configurable'" in r.message for r in caplog.records)
def test_build_run_config_context_passthrough_other_keys():
    """Non-conflicting keys from request_config are still passed through when context is used."""
    from app.gateway.services import build_run_config
    config = build_run_config(
        "thread-1",
        {"context": {"thread_id": "thread-1"}, "tags": ["prod"]},
        None,
    )
    assert config["context"]["thread_id"] == "thread-1"
    assert "configurable" not in config
    assert config["tags"] == ["prod"]
def test_build_run_config_no_request_config():
    """When request_config is None, fall back to basic configurable with thread_id."""
    from app.gateway.services import build_run_config
    config = build_run_config("thread-abc", None, None)
    assert config["configurable"] == {"thread_id": "thread-abc"}
    assert "context" not in config

View File

@ -1,156 +0,0 @@
"""Owner isolation tests for MemoryThreadMetaStore.
Mirrors the SQL-backed tests in test_owner_isolation.py but exercises
the in-memory LangGraph Store backend used when database.backend=memory.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from langgraph.store.memory import InMemoryStore
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.runtime.user_context import reset_current_user, set_current_user
USER_A = SimpleNamespace(id="user-a", email="a@test.local")
USER_B = SimpleNamespace(id="user-b", email="b@test.local")
def _as_user(user):
    """Return a context manager that installs *user* as the current user.

    Entering sets the user_context contextvar; exiting restores the
    previous value via the saved token.
    """

    class _UserScope:
        def __enter__(self):
            self._token = set_current_user(user)
            return user

        def __exit__(self, *exc_info):
            reset_current_user(self._token)

    return _UserScope()
@pytest.fixture
def store():
    """Fresh MemoryThreadMetaStore over an isolated InMemoryStore per test."""
    return MemoryThreadMetaStore(InMemoryStore())
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_search_isolation(store):
    """search() returns only threads owned by the current user."""
    with _as_user(USER_A):
        await store.create("t-alpha", display_name="A's thread")
    with _as_user(USER_B):
        await store.create("t-beta", display_name="B's thread")
    # Each user sees exactly their own thread and nothing else.
    with _as_user(USER_A):
        results = await store.search()
        assert [r["thread_id"] for r in results] == ["t-alpha"]
    with _as_user(USER_B):
        results = await store.search()
        assert [r["thread_id"] for r in results] == ["t-beta"]
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_get_isolation(store):
    """get() returns None for threads owned by another user."""
    with _as_user(USER_A):
        await store.create("t-alpha", display_name="A's thread")
    # B cannot even observe the existence of A's thread.
    with _as_user(USER_B):
        assert await store.get("t-alpha") is None
    with _as_user(USER_A):
        result = await store.get("t-alpha")
        assert result is not None
        assert result["display_name"] == "A's thread"
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_display_name_denied(store):
    """User B cannot rename User A's thread."""
    with _as_user(USER_A):
        await store.create("t-alpha", display_name="original")
    # The foreign update must be a silent no-op, not an error.
    with _as_user(USER_B):
        await store.update_display_name("t-alpha", "hacked")
    with _as_user(USER_A):
        row = await store.get("t-alpha")
        assert row is not None
        assert row["display_name"] == "original"
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_status_denied(store):
    """User B cannot change status of User A's thread."""
    with _as_user(USER_A):
        await store.create("t-alpha")
    with _as_user(USER_B):
        await store.update_status("t-alpha", "error")
    with _as_user(USER_A):
        row = await store.get("t-alpha")
        assert row is not None
        # Default status from create() is preserved.
        assert row["status"] == "idle"
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_metadata_denied(store):
    """User B cannot modify metadata of User A's thread."""
    with _as_user(USER_A):
        await store.create("t-alpha", metadata={"key": "original"})
    with _as_user(USER_B):
        await store.update_metadata("t-alpha", {"key": "hacked"})
    with _as_user(USER_A):
        row = await store.get("t-alpha")
        assert row is not None
        assert row["metadata"]["key"] == "original"
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_delete_denied(store):
    """User B cannot delete User A's thread."""
    with _as_user(USER_A):
        await store.create("t-alpha")
    with _as_user(USER_B):
        await store.delete("t-alpha")
    # The thread survives the foreign delete attempt.
    with _as_user(USER_A):
        row = await store.get("t-alpha")
        assert row is not None
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_no_context_raises(store):
    """Calling methods without user context raises RuntimeError."""
    # Fail-closed: isolation must not silently default to "all users".
    with pytest.raises(RuntimeError, match="no user context is set"):
        await store.search()
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_explicit_none_bypasses_filter(store):
    """user_id=None bypasses isolation (migration/CLI escape hatch)."""
    with _as_user(USER_A):
        await store.create("t-alpha")
    with _as_user(USER_B):
        await store.create("t-beta")
    # Explicit user_id=None (distinct from "no kwarg") returns every row.
    all_rows = await store.search(user_id=None)
    assert {r["thread_id"] for r in all_rows} == {"t-alpha", "t-beta"}
    row = await store.get("t-alpha", user_id=None)
    assert row is not None

View File

@ -1,465 +0,0 @@
"""Cross-user isolation tests — non-negotiable safety gate.
Mirrors TC-API-17..20 from backend/docs/AUTH_TEST_PLAN.md. A failure
here means users can see each other's data; PR must not merge.
Architecture note
-----------------
These tests bypass the HTTP layer and exercise the storage-layer
owner filter directly by switching the ``user_context`` contextvar
between two users. The safety property under test is:
After a repository write with user_id=A, a subsequent read with
user_id=B must not return the row, and vice versa.
The HTTP layer is covered by test_auth_middleware.py, which proves
that a request cookie reaches the ``set_current_user`` call. Together
the two suites prove the full chain:
cookie middleware contextvar repository isolation
Every test in this file opts out of the autouse contextvar fixture
(``@pytest.mark.no_auto_user``) so it can set the contextvar to the
specific users it cares about.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from deerflow.runtime.user_context import (
reset_current_user,
set_current_user,
)
USER_A = SimpleNamespace(id="user-a", email="a@test.local")
USER_B = SimpleNamespace(id="user-b", email="b@test.local")
async def _make_engines(tmp_path):
    """Initialize the shared engine against a per-test SQLite DB.

    Returns a cleanup coroutine the caller should await at the end.
    """
    from deerflow.persistence.engine import close_engine, init_engine
    url = f"sqlite+aiosqlite:///{tmp_path / 'isolation.db'}"
    await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
    # close_engine is returned unawaited so tests can defer teardown to finally.
    return close_engine
def _as_user(user):
    """Context manager-like helper that set/reset the contextvar."""
    class _Ctx:
        def __enter__(self):
            # Save the token so __exit__ can restore the previous user.
            self._token = set_current_user(user)
            return user
        def __exit__(self, *exc):
            reset_current_user(self._token)
    return _Ctx()
# ── TC-API-17 — threads_meta isolation ────────────────────────────────────
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_thread_meta_cross_user_isolation(tmp_path):
    """TC-API-17: threads_meta rows are invisible across users."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository
    cleanup = await _make_engines(tmp_path)
    try:
        repo = ThreadMetaRepository(get_session_factory())
        # User A creates a thread.
        with _as_user(USER_A):
            await repo.create("t-alpha", display_name="A's private thread")
        # User B creates a thread.
        with _as_user(USER_B):
            await repo.create("t-beta", display_name="B's private thread")
        # User A must see only A's thread.
        with _as_user(USER_A):
            a_view = await repo.get("t-alpha")
            assert a_view is not None
            assert a_view["display_name"] == "A's private thread"
            # CRITICAL: User A must NOT see B's thread.
            leaked = await repo.get("t-beta")
            assert leaked is None, f"User A leaked User B's thread: {leaked}"
            # Search should only return A's threads.
            results = await repo.search()
            assert [r["thread_id"] for r in results] == ["t-alpha"]
        # User B must see only B's thread.
        with _as_user(USER_B):
            b_view = await repo.get("t-beta")
            assert b_view is not None
            assert b_view["display_name"] == "B's private thread"
            leaked = await repo.get("t-alpha")
            assert leaked is None, f"User B leaked User A's thread: {leaked}"
            results = await repo.search()
            assert [r["thread_id"] for r in results] == ["t-beta"]
    finally:
        await cleanup()
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_thread_meta_cross_user_mutation_denied(tmp_path):
    """User B cannot update or delete a thread owned by User A."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository
    cleanup = await _make_engines(tmp_path)
    try:
        repo = ThreadMetaRepository(get_session_factory())
        with _as_user(USER_A):
            await repo.create("t-alpha", display_name="original")
        # User B tries to rename A's thread — must be a no-op.
        with _as_user(USER_B):
            await repo.update_display_name("t-alpha", "hacked")
        # Verify the row is unchanged from A's perspective.
        with _as_user(USER_A):
            row = await repo.get("t-alpha")
            assert row is not None
            assert row["display_name"] == "original"
        # User B tries to delete A's thread — must be a no-op.
        with _as_user(USER_B):
            await repo.delete("t-alpha")
        # A's thread still exists.
        with _as_user(USER_A):
            row = await repo.get("t-alpha")
            assert row is not None
    finally:
        await cleanup()
# ── TC-API-18 — runs isolation ────────────────────────────────────────────
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_runs_cross_user_isolation(tmp_path):
    """TC-API-18: run rows (get and list_by_thread) are owner-scoped."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.run import RunRepository
    cleanup = await _make_engines(tmp_path)
    try:
        repo = RunRepository(get_session_factory())
        with _as_user(USER_A):
            await repo.put("run-a1", thread_id="t-alpha")
            await repo.put("run-a2", thread_id="t-alpha")
        with _as_user(USER_B):
            await repo.put("run-b1", thread_id="t-beta")
        # User A must see only A's runs.
        with _as_user(USER_A):
            r = await repo.get("run-a1")
            assert r is not None
            assert r["run_id"] == "run-a1"
            leaked = await repo.get("run-b1")
            assert leaked is None, "User A leaked User B's run"
            a_runs = await repo.list_by_thread("t-alpha")
            assert {r["run_id"] for r in a_runs} == {"run-a1", "run-a2"}
            # Listing B's thread from A's perspective: empty
            empty = await repo.list_by_thread("t-beta")
            assert empty == []
        # User B must see only B's runs.
        with _as_user(USER_B):
            leaked = await repo.get("run-a1")
            assert leaked is None, "User B leaked User A's run"
            b_runs = await repo.list_by_thread("t-beta")
            assert [r["run_id"] for r in b_runs] == ["run-b1"]
    finally:
        await cleanup()
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_runs_cross_user_delete_denied(tmp_path):
    """A foreign delete of another user's run must be a silent no-op."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.run import RunRepository
    cleanup = await _make_engines(tmp_path)
    try:
        repo = RunRepository(get_session_factory())
        with _as_user(USER_A):
            await repo.put("run-a1", thread_id="t-alpha")
        # User B tries to delete A's run — no-op.
        with _as_user(USER_B):
            await repo.delete("run-a1")
        # A's run still exists.
        with _as_user(USER_A):
            row = await repo.get("run-a1")
            assert row is not None
    finally:
        await cleanup()
# ── TC-API-19 — run_events isolation (CRITICAL: content leak) ─────────────
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_run_events_cross_user_isolation(tmp_path):
    """run_events holds raw conversation content — most sensitive leak vector."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.runtime.events.store.db import DbRunEventStore
    cleanup = await _make_engines(tmp_path)
    try:
        store = DbRunEventStore(get_session_factory())
        with _as_user(USER_A):
            await store.put(
                thread_id="t-alpha",
                run_id="run-a1",
                event_type="human_message",
                category="message",
                content="User A private question",
            )
            await store.put(
                thread_id="t-alpha",
                run_id="run-a1",
                event_type="ai_message",
                category="message",
                content="User A private answer",
            )
        with _as_user(USER_B):
            await store.put(
                thread_id="t-beta",
                run_id="run-b1",
                event_type="human_message",
                category="message",
                content="User B private question",
            )
        # User A must see only A's events — CRITICAL.
        with _as_user(USER_A):
            msgs = await store.list_messages("t-alpha")
            contents = [m["content"] for m in msgs]
            assert "User A private question" in contents
            assert "User A private answer" in contents
            # CRITICAL: User B's content must not appear.
            assert "User B private question" not in contents
            # Attempt to read B's thread by guessing thread_id.
            leaked = await store.list_messages("t-beta")
            assert leaked == [], f"User A leaked User B's messages: {leaked}"
            leaked_events = await store.list_events("t-beta", "run-b1")
            assert leaked_events == [], "User A leaked User B's events"
            # count_messages must also be zero for B's thread from A's view.
            count = await store.count_messages("t-beta")
            assert count == 0
        # User B must see only B's events.
        with _as_user(USER_B):
            msgs = await store.list_messages("t-beta")
            contents = [m["content"] for m in msgs]
            assert "User B private question" in contents
            assert "User A private question" not in contents
            assert "User A private answer" not in contents
            count = await store.count_messages("t-alpha")
            assert count == 0
    finally:
        await cleanup()
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_run_events_cross_user_delete_denied(tmp_path):
    """User B cannot delete User A's event stream."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.runtime.events.store.db import DbRunEventStore

    teardown = await _make_engines(tmp_path)
    try:
        events = DbRunEventStore(get_session_factory())
        with _as_user(USER_A):
            await events.put(
                thread_id="t-alpha",
                run_id="run-a1",
                event_type="human_message",
                category="message",
                content="hello",
            )

        # B's attempt to wipe A's thread events must delete nothing.
        with _as_user(USER_B):
            removed = await events.delete_by_thread("t-alpha")
            assert removed == 0, f"User B deleted {removed} of User A's events"

        # A's events survive.
        with _as_user(USER_A):
            assert await events.count_messages("t-alpha") == 1
    finally:
        await teardown()
# ── TC-API-20 — feedback isolation ────────────────────────────────────────
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_feedback_cross_user_isolation(tmp_path):
    """Feedback rows are owner-scoped: cross-user reads return nothing."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.feedback import FeedbackRepository

    teardown = await _make_engines(tmp_path)
    try:
        feedback = FeedbackRepository(get_session_factory())

        # Seed one feedback row per user.
        with _as_user(USER_A):
            created_a = await feedback.create(
                run_id="run-a1",
                thread_id="t-alpha",
                rating=1,
                comment="A liked this",
            )
        with _as_user(USER_B):
            created_b = await feedback.create(
                run_id="run-b1",
                thread_id="t-beta",
                rating=-1,
                comment="B disliked this",
            )

        # User A: own feedback readable, B's invisible even by exact id.
        with _as_user(USER_A):
            mine = await feedback.get(created_a["feedback_id"])
            assert mine is not None
            assert mine["comment"] == "A liked this"
            leaked = await feedback.get(created_b["feedback_id"])
            assert leaked is None, "User A leaked User B's feedback"
            assert await feedback.list_by_run("t-beta", "run-b1") == []

        # User B: symmetric check.
        with _as_user(USER_B):
            leaked = await feedback.get(created_a["feedback_id"])
            assert leaked is None, "User B leaked User A's feedback"
            rows = await feedback.list_by_run("t-beta", "run-b1")
            assert len(rows) == 1
            assert rows[0]["comment"] == "B disliked this"
    finally:
        await teardown()
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_feedback_cross_user_delete_denied(tmp_path):
    """Deleting another user's feedback returns False and leaves the row intact."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.feedback import FeedbackRepository

    teardown = await _make_engines(tmp_path)
    try:
        feedback = FeedbackRepository(get_session_factory())
        with _as_user(USER_A):
            created = await feedback.create(run_id="run-a1", thread_id="t-alpha", rating=1)

        # B's delete attempt must report no-op.
        with _as_user(USER_B):
            deleted = await feedback.delete(created["feedback_id"])
            assert deleted is False, "User B deleted User A's feedback"

        # A can still read the row.
        with _as_user(USER_A):
            assert await feedback.get(created["feedback_id"]) is not None
    finally:
        await teardown()
# ── Regression: AUTO sentinel without contextvar must raise ───────────────
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_repository_without_context_raises(tmp_path):
    """Defense-in-depth: calling repo methods without a user context errors."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    teardown = await _make_engines(tmp_path)
    try:
        threads = ThreadMetaRepository(get_session_factory())
        # The contextvar is explicitly unset under @pytest.mark.no_auto_user.
        with pytest.raises(RuntimeError, match="no user context is set"):
            await threads.get("anything")
    finally:
        await teardown()
# ── Escape hatch: explicit user_id=None bypasses filter (for migration) ──
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_explicit_none_bypasses_filter(tmp_path):
    """Migration scripts pass user_id=None to see all rows regardless of owner."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    teardown = await _make_engines(tmp_path)
    try:
        threads = ThreadMetaRepository(get_session_factory())

        # Seed data as two different users.
        with _as_user(USER_A):
            await threads.create("t-alpha")
        with _as_user(USER_B):
            await threads.create("t-beta")

        # Migration-style read: no contextvar, explicit None bypass.
        everything = await threads.search(user_id=None)
        assert {row["thread_id"] for row in everything} == {"t-alpha", "t-beta"}

        # Point reads with user_id=None skip the owner filter as well.
        assert await threads.get("t-alpha", user_id=None) is not None
        assert await threads.get("t-beta", user_id=None) is not None
    finally:
        await teardown()

View File

@ -1,233 +0,0 @@
"""Tests for the persistence layer scaffolding.
Tests:
1. DatabaseConfig property derivation (paths, URLs)
2. MemoryRunStore CRUD + user_id filtering
3. Base.to_dict() via inspect mixin
4. Engine init/close lifecycle (memory + SQLite)
5. Postgres missing-dep error message
"""
from datetime import UTC, datetime
import pytest
from deerflow.config.database_config import DatabaseConfig
from deerflow.runtime.runs.store.memory import MemoryRunStore
# -- DatabaseConfig --
class TestDatabaseConfig:
    """DatabaseConfig derives filesystem paths and SQLAlchemy URLs per backend."""

    def test_defaults(self):
        cfg = DatabaseConfig()
        assert cfg.backend == "memory"
        assert cfg.pool_size == 5

    def test_sqlite_paths_unified(self):
        cfg = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata")
        assert cfg.sqlite_path.endswith("deerflow.db")
        assert "mydata" in cfg.sqlite_path
        # Backward-compatible aliases resolve to the same unified file.
        assert cfg.sqlite_path == cfg.checkpointer_sqlite_path == cfg.app_sqlite_path

    def test_app_sqlalchemy_url_sqlite(self):
        cfg = DatabaseConfig(backend="sqlite", sqlite_dir="./data")
        assert cfg.app_sqlalchemy_url.startswith("sqlite+aiosqlite:///")
        assert "deerflow.db" in cfg.app_sqlalchemy_url

    def test_app_sqlalchemy_url_postgres(self):
        cfg = DatabaseConfig(
            backend="postgres",
            postgres_url="postgresql://u:p@h:5432/db",
        )
        # Plain postgresql:// URLs get the asyncpg driver suffix added.
        derived = cfg.app_sqlalchemy_url
        assert derived.startswith("postgresql+asyncpg://")
        assert "u:p@h:5432/db" in derived

    def test_app_sqlalchemy_url_postgres_already_asyncpg(self):
        cfg = DatabaseConfig(
            backend="postgres",
            postgres_url="postgresql+asyncpg://u:p@h:5432/db",
        )
        # The driver suffix must not be applied twice.
        assert cfg.app_sqlalchemy_url.count("asyncpg") == 1

    def test_memory_has_no_url(self):
        cfg = DatabaseConfig(backend="memory")
        with pytest.raises(ValueError, match="No SQLAlchemy URL"):
            _ = cfg.app_sqlalchemy_url
# -- MemoryRunStore --
class TestMemoryRunStore:
    """CRUD + filtering contract of the in-memory run store.

    Covers put/get round-trips, status updates, per-thread listing with an
    optional owner filter, deletion semantics, and the pending-run queue
    (including `before` cutoff and FIFO ordering by created_at).
    """

    @pytest.fixture
    def store(self):
        # Fresh store per test — no shared state between cases.
        return MemoryRunStore()

    @pytest.mark.anyio
    async def test_put_and_get(self, store):
        await store.put("r1", thread_id="t1", status="pending")
        row = await store.get("r1")
        assert row is not None
        assert row["run_id"] == "r1"
        assert row["status"] == "pending"

    @pytest.mark.anyio
    async def test_get_missing_returns_none(self, store):
        # Unknown run ids yield None, not an exception.
        assert await store.get("nope") is None

    @pytest.mark.anyio
    async def test_update_status(self, store):
        await store.put("r1", thread_id="t1")
        await store.update_status("r1", "running")
        assert (await store.get("r1"))["status"] == "running"

    @pytest.mark.anyio
    async def test_update_status_with_error(self, store):
        # The optional error message is stored alongside the status.
        await store.put("r1", thread_id="t1")
        await store.update_status("r1", "error", error="boom")
        row = await store.get("r1")
        assert row["status"] == "error"
        assert row["error"] == "boom"

    @pytest.mark.anyio
    async def test_list_by_thread(self, store):
        await store.put("r1", thread_id="t1")
        await store.put("r2", thread_id="t1")
        await store.put("r3", thread_id="t2")
        rows = await store.list_by_thread("t1")
        assert len(rows) == 2
        assert all(r["thread_id"] == "t1" for r in rows)

    @pytest.mark.anyio
    async def test_list_by_thread_owner_filter(self, store):
        # Passing user_id narrows the listing to that owner's runs.
        await store.put("r1", thread_id="t1", user_id="alice")
        await store.put("r2", thread_id="t1", user_id="bob")
        rows = await store.list_by_thread("t1", user_id="alice")
        assert len(rows) == 1
        assert rows[0]["user_id"] == "alice"

    @pytest.mark.anyio
    async def test_owner_none_returns_all(self, store):
        # user_id=None is the explicit "no filter" escape hatch.
        await store.put("r1", thread_id="t1", user_id="alice")
        await store.put("r2", thread_id="t1", user_id="bob")
        rows = await store.list_by_thread("t1", user_id=None)
        assert len(rows) == 2

    @pytest.mark.anyio
    async def test_delete(self, store):
        await store.put("r1", thread_id="t1")
        await store.delete("r1")
        assert await store.get("r1") is None

    @pytest.mark.anyio
    async def test_delete_nonexistent_is_noop(self, store):
        await store.delete("nope")  # should not raise

    @pytest.mark.anyio
    async def test_list_pending(self, store):
        await store.put("r1", thread_id="t1", status="pending")
        await store.put("r2", thread_id="t1", status="running")
        await store.put("r3", thread_id="t2", status="pending")
        pending = await store.list_pending()
        assert len(pending) == 2
        assert all(r["status"] == "pending" for r in pending)

    @pytest.mark.anyio
    async def test_list_pending_respects_before(self, store):
        # Only runs created strictly before the cutoff are returned.
        past = "2020-01-01T00:00:00+00:00"
        future = "2099-01-01T00:00:00+00:00"
        await store.put("r1", thread_id="t1", status="pending", created_at=past)
        await store.put("r2", thread_id="t1", status="pending", created_at=future)
        pending = await store.list_pending(before=datetime.now(UTC).isoformat())
        assert len(pending) == 1
        assert pending[0]["run_id"] == "r1"

    @pytest.mark.anyio
    async def test_list_pending_fifo_order(self, store):
        # Insertion order is irrelevant; results sort by created_at ascending.
        await store.put("r2", thread_id="t1", status="pending", created_at="2024-01-02T00:00:00+00:00")
        await store.put("r1", thread_id="t1", status="pending", created_at="2024-01-01T00:00:00+00:00")
        pending = await store.list_pending()
        assert pending[0]["run_id"] == "r1"
# -- Base.to_dict mixin --
class TestBaseToDictMixin:
    """Base.to_dict() serializes mapped columns via the inspect mixin."""

    @pytest.mark.anyio
    async def test_to_dict_and_exclude(self, tmp_path):
        """Create a temp SQLite DB with a minimal model, verify to_dict."""
        from sqlalchemy import String
        from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
        from sqlalchemy.orm import Mapped, mapped_column

        from deerflow.persistence.base import Base

        # Minimal two-column model registered on the shared Base metadata.
        class _Tmp(Base):
            __tablename__ = "_tmp_test"
            id: Mapped[str] = mapped_column(String(64), primary_key=True)
            name: Mapped[str] = mapped_column(String(128))

        engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}")
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        sf = async_sessionmaker(engine, expire_on_commit=False)
        async with sf() as session:
            session.add(_Tmp(id="1", name="hello"))
            await session.commit()
            obj = await session.get(_Tmp, "1")
            # Full dump, column exclusion, and repr all come from the mixin.
            assert obj.to_dict() == {"id": "1", "name": "hello"}
            assert obj.to_dict(exclude={"name"}) == {"id": "1"}
            assert "_Tmp" in repr(obj)
        await engine.dispose()
# -- Engine lifecycle --
class TestEngineLifecycle:
    """init_engine/close_engine manage a process-global session factory."""

    @pytest.mark.anyio
    async def test_memory_is_noop(self):
        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine

        # The memory backend never creates a SQLAlchemy engine.
        await init_engine("memory")
        assert get_session_factory() is None
        await close_engine()

    @pytest.mark.anyio
    async def test_sqlite_creates_engine(self, tmp_path):
        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
        try:
            sf = get_session_factory()
            assert sf is not None
            async with sf() as session:
                assert session is not None
        finally:
            # Always release the global engine, even if an assert above fails,
            # so later tests that call init_engine start from a clean state.
            await close_engine()
        # close_engine resets the global factory.
        assert get_session_factory() is None

    @pytest.mark.anyio
    async def test_postgres_without_asyncpg_gives_actionable_error(self):
        """If asyncpg is not installed, error message tells user what to do."""
        import importlib.util

        from deerflow.persistence.engine import init_engine

        # Probe availability without importing — replaces the former
        # try/import/except-pass pattern and its `# noqa: S110` suppression.
        if importlib.util.find_spec("asyncpg") is not None:
            pytest.skip("asyncpg is installed -- cannot test missing-dep path")
        with pytest.raises(ImportError, match="uv sync --extra postgres"):
            await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x")

View File

@ -1,500 +0,0 @@
"""Tests for RunEventStore contract across all backends.
Uses a helper to create the store for each backend type.
Memory tests run directly; DB and JSONL tests create stores inside each test.
"""
import pytest
from deerflow.runtime.events.store.memory import MemoryRunEventStore
@pytest.fixture
def store():
    """Provide a fresh in-memory event store for each test."""
    event_store = MemoryRunEventStore()
    return event_store
# -- Basic write and query --
class TestPutAndSeq:
    """put() echoes the stored record and assigns a per-thread seq counter."""

    @pytest.mark.anyio
    async def test_put_returns_dict_with_seq(self, store):
        # The returned record carries every input field plus seq/created_at.
        record = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hello")
        assert "seq" in record
        assert record["seq"] == 1
        assert record["thread_id"] == "t1"
        assert record["run_id"] == "r1"
        assert record["event_type"] == "human_message"
        assert record["category"] == "message"
        assert record["content"] == "hello"
        assert "created_at" in record

    @pytest.mark.anyio
    async def test_seq_strictly_increasing_same_thread(self, store):
        # seq is 1-based and increments on every write within one thread.
        r1 = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        r2 = await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
        r3 = await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
        assert r1["seq"] == 1
        assert r2["seq"] == 2
        assert r3["seq"] == 3

    @pytest.mark.anyio
    async def test_seq_independent_across_threads(self, store):
        # Each thread keeps its own counter starting at 1.
        r1 = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        r2 = await store.put(thread_id="t2", run_id="r2", event_type="human_message", category="message")
        assert r1["seq"] == 1
        assert r2["seq"] == 1

    @pytest.mark.anyio
    async def test_put_respects_provided_created_at(self, store):
        # Caller-supplied timestamps are stored verbatim, not overwritten.
        ts = "2024-06-01T12:00:00+00:00"
        record = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", created_at=ts)
        assert record["created_at"] == ts

    @pytest.mark.anyio
    async def test_put_metadata_preserved(self, store):
        # Arbitrary metadata dicts round-trip unchanged.
        meta = {"model": "gpt-4", "tokens": 100}
        record = await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", metadata=meta)
        assert record["metadata"] == meta
# -- list_messages --
class TestListMessages:
    """list_messages: category filter, seq ordering, and cursor pagination."""

    @pytest.mark.anyio
    async def test_only_returns_message_category(self, store):
        # trace/lifecycle events are excluded from the message view.
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
        await store.put(thread_id="t1", run_id="r1", event_type="run_start", category="lifecycle")
        messages = await store.list_messages("t1")
        assert len(messages) == 1
        assert messages[0]["category"] == "message"

    @pytest.mark.anyio
    async def test_ascending_seq_order(self, store):
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="first")
        await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="second")
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="third")
        messages = await store.list_messages("t1")
        seqs = [m["seq"] for m in messages]
        assert seqs == sorted(seqs)

    @pytest.mark.anyio
    async def test_before_seq_pagination(self, store):
        # before_seq is exclusive; with limit=3 we get the window just below it.
        for i in range(10):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
        messages = await store.list_messages("t1", before_seq=6, limit=3)
        assert len(messages) == 3
        assert [m["seq"] for m in messages] == [3, 4, 5]

    @pytest.mark.anyio
    async def test_after_seq_pagination(self, store):
        # after_seq is exclusive; results continue upward from the cursor.
        for i in range(10):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
        messages = await store.list_messages("t1", after_seq=7, limit=3)
        assert len(messages) == 3
        assert [m["seq"] for m in messages] == [8, 9, 10]

    @pytest.mark.anyio
    async def test_limit_restricts_count(self, store):
        for _ in range(20):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        messages = await store.list_messages("t1", limit=5)
        assert len(messages) == 5

    @pytest.mark.anyio
    async def test_cross_run_unified_ordering(self, store):
        # The thread-level view interleaves runs by seq, not grouped by run.
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
        await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
        await store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message")
        messages = await store.list_messages("t1")
        assert [m["seq"] for m in messages] == [1, 2, 3, 4]
        assert messages[0]["run_id"] == "r1"
        assert messages[2]["run_id"] == "r2"

    @pytest.mark.anyio
    async def test_default_returns_latest(self, store):
        # Without cursors, a limited listing returns the most recent window.
        for _ in range(10):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        messages = await store.list_messages("t1", limit=3)
        assert [m["seq"] for m in messages] == [8, 9, 10]
# -- list_events --
class TestListEvents:
    """list_events returns a single run's events across all categories."""

    @pytest.mark.anyio
    async def test_returns_all_categories_for_run(self, store):
        for kind, cat in (
            ("human_message", "message"),
            ("llm_end", "trace"),
            ("run_start", "lifecycle"),
        ):
            await store.put(thread_id="t1", run_id="r1", event_type=kind, category=cat)
        assert len(await store.list_events("t1", "r1")) == 3

    @pytest.mark.anyio
    async def test_event_types_filter(self, store):
        for kind in ("llm_start", "llm_end", "tool_start"):
            await store.put(thread_id="t1", run_id="r1", event_type=kind, category="trace")
        filtered = await store.list_events("t1", "r1", event_types=["llm_end"])
        assert len(filtered) == 1
        assert filtered[0]["event_type"] == "llm_end"

    @pytest.mark.anyio
    async def test_only_returns_specified_run(self, store):
        await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
        await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
        only_r1 = await store.list_events("t1", "r1")
        assert len(only_r1) == 1
        assert only_r1[0]["run_id"] == "r1"
# -- list_messages_by_run --
class TestListMessagesByRun:
    """list_messages_by_run scopes the message view to a single run."""

    @pytest.mark.anyio
    async def test_only_messages_for_specified_run(self, store):
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
        await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
        scoped = await store.list_messages_by_run("t1", "r1")
        # Exactly one row: r1's message — r1's trace and r2's message excluded.
        assert [(m["run_id"], m["category"]) for m in scoped] == [("r1", "message")]
# -- count_messages --
class TestCountMessages:
    """count_messages tallies only category='message' rows."""

    @pytest.mark.anyio
    async def test_counts_only_message_category(self, store):
        for kind, cat in (
            ("human_message", "message"),
            ("ai_message", "message"),
            ("llm_end", "trace"),
        ):
            await store.put(thread_id="t1", run_id="r1", event_type=kind, category=cat)
        assert await store.count_messages("t1") == 2
# -- put_batch --
class TestPutBatch:
    """put_batch writes multiple events and assigns increasing seq values."""

    @pytest.mark.anyio
    async def test_batch_assigns_seq(self, store):
        batch = [
            {"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message", "content": "a"},
            {"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message", "content": "b"},
            {"thread_id": "t1", "run_id": "r1", "event_type": "llm_end", "category": "trace"},
        ]
        written = await store.put_batch(batch)
        assert len(written) == 3
        assert all("seq" in rec for rec in written)

    @pytest.mark.anyio
    async def test_batch_seq_strictly_increasing(self, store):
        batch = [
            {"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message"},
            {"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message"},
        ]
        written = await store.put_batch(batch)
        assert [rec["seq"] for rec in written] == [1, 2]
# -- delete --
class TestDelete:
    """delete_by_thread/delete_by_run return the removed-row count."""

    @pytest.mark.anyio
    async def test_delete_by_thread(self, store):
        # All three events (both runs, both categories) are removed together.
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
        await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
        count = await store.delete_by_thread("t1")
        assert count == 3
        assert await store.list_messages("t1") == []
        assert await store.count_messages("t1") == 0

    @pytest.mark.anyio
    async def test_delete_by_run(self, store):
        # Only r2's two events go; r1's message survives.
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
        await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
        count = await store.delete_by_run("t1", "r2")
        assert count == 2
        messages = await store.list_messages("t1")
        assert len(messages) == 1
        assert messages[0]["run_id"] == "r1"

    @pytest.mark.anyio
    async def test_delete_nonexistent_thread_returns_zero(self, store):
        assert await store.delete_by_thread("nope") == 0

    @pytest.mark.anyio
    async def test_delete_nonexistent_run_returns_zero(self, store):
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        assert await store.delete_by_run("t1", "nope") == 0

    @pytest.mark.anyio
    async def test_delete_nonexistent_thread_for_run_returns_zero(self, store):
        assert await store.delete_by_run("nope", "r1") == 0
# -- Edge cases --
class TestEdgeCases:
    """Queries against unknown threads/runs return empty results, never raise."""

    @pytest.mark.anyio
    async def test_empty_thread_list_messages(self, store):
        result = await store.list_messages("empty")
        assert result == []

    @pytest.mark.anyio
    async def test_empty_run_list_events(self, store):
        result = await store.list_events("empty", "r1")
        assert result == []

    @pytest.mark.anyio
    async def test_empty_thread_count_messages(self, store):
        result = await store.count_messages("empty")
        assert result == 0
# -- DB-specific tests --
class TestDbRunEventStore:
    """Tests for DbRunEventStore backed by a throwaway SQLite engine.

    Each test wraps its body in try/finally so the process-global engine is
    closed even when an assertion fails mid-test; previously a failed assert
    skipped close_engine() and leaked the engine into subsequent tests.
    """

    @staticmethod
    async def _open_store(tmp_path, **store_kwargs):
        """Initialise a temp SQLite engine and return a store bound to it."""
        from deerflow.persistence.engine import get_session_factory, init_engine
        from deerflow.runtime.events.store.db import DbRunEventStore

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
        return DbRunEventStore(get_session_factory(), **store_kwargs)

    @pytest.mark.anyio
    async def test_basic_crud(self, tmp_path):
        from deerflow.persistence.engine import close_engine

        s = await self._open_store(tmp_path)
        try:
            r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
            assert r["seq"] == 1
            r2 = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="hello")
            assert r2["seq"] == 2
            messages = await s.list_messages("t1")
            assert len(messages) == 2
            count = await s.count_messages("t1")
            assert count == 2
        finally:
            await close_engine()

    @pytest.mark.anyio
    async def test_trace_content_truncation(self, tmp_path):
        from deerflow.persistence.engine import close_engine

        s = await self._open_store(tmp_path, max_trace_content=100)
        try:
            long = "x" * 200
            # Trace content is clipped to max_trace_content and flagged.
            r = await s.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", content=long)
            assert len(r["content"]) == 100
            assert r["metadata"].get("content_truncated") is True
            # message content NOT truncated
            m = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=long)
            assert len(m["content"]) == 200
        finally:
            await close_engine()

    @pytest.mark.anyio
    async def test_pagination(self, tmp_path):
        from deerflow.persistence.engine import close_engine

        s = await self._open_store(tmp_path)
        try:
            for i in range(10):
                await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
            # before_seq
            msgs = await s.list_messages("t1", before_seq=6, limit=3)
            assert [m["seq"] for m in msgs] == [3, 4, 5]
            # after_seq
            msgs = await s.list_messages("t1", after_seq=7, limit=3)
            assert [m["seq"] for m in msgs] == [8, 9, 10]
            # default (latest)
            msgs = await s.list_messages("t1", limit=3)
            assert [m["seq"] for m in msgs] == [8, 9, 10]
        finally:
            await close_engine()

    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        from deerflow.persistence.engine import close_engine

        s = await self._open_store(tmp_path)
        try:
            await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
            await s.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message")
            c = await s.delete_by_run("t1", "r2")
            assert c == 1
            assert await s.count_messages("t1") == 1
            c = await s.delete_by_thread("t1")
            assert c == 1
            assert await s.count_messages("t1") == 0
        finally:
            await close_engine()

    @pytest.mark.anyio
    async def test_put_batch_seq_continuity(self, tmp_path):
        """Batch write produces continuous seq values with no gaps."""
        from deerflow.persistence.engine import close_engine

        s = await self._open_store(tmp_path)
        try:
            events = [{"thread_id": "t1", "run_id": "r1", "event_type": "trace", "category": "trace"} for _ in range(50)]
            results = await s.put_batch(events)
            seqs = [r["seq"] for r in results]
            assert seqs == list(range(1, 51))
        finally:
            await close_engine()
# -- Factory tests --
class TestMakeRunEventStore:
    """Tests for the make_run_event_store factory function.

    Engine-backed cases use try/finally so close_engine() runs even when an
    assertion fails; previously a failed assert leaked the global engine into
    subsequent tests.
    """

    @pytest.mark.anyio
    async def test_memory_backend_default(self):
        from deerflow.runtime.events.store import make_run_event_store

        # No config at all defaults to the in-memory backend.
        store = make_run_event_store(None)
        assert type(store).__name__ == "MemoryRunEventStore"

    @pytest.mark.anyio
    async def test_memory_backend_explicit(self):
        from unittest.mock import MagicMock

        from deerflow.runtime.events.store import make_run_event_store

        config = MagicMock()
        config.backend = "memory"
        store = make_run_event_store(config)
        assert type(store).__name__ == "MemoryRunEventStore"

    @pytest.mark.anyio
    async def test_db_backend_with_engine(self, tmp_path):
        from unittest.mock import MagicMock

        from deerflow.persistence.engine import close_engine, init_engine
        from deerflow.runtime.events.store import make_run_event_store

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
        try:
            config = MagicMock()
            config.backend = "db"
            config.max_trace_content = 10240
            store = make_run_event_store(config)
            assert type(store).__name__ == "DbRunEventStore"
        finally:
            # Always release the global engine, even on assertion failure.
            await close_engine()

    @pytest.mark.anyio
    async def test_db_backend_no_engine_falls_back(self):
        """db backend without engine falls back to memory."""
        from unittest.mock import MagicMock

        from deerflow.persistence.engine import close_engine, init_engine
        from deerflow.runtime.events.store import make_run_event_store

        await init_engine("memory")  # no engine created
        try:
            config = MagicMock()
            config.backend = "db"
            store = make_run_event_store(config)
            assert type(store).__name__ == "MemoryRunEventStore"
        finally:
            await close_engine()

    @pytest.mark.anyio
    async def test_jsonl_backend(self):
        from unittest.mock import MagicMock

        from deerflow.runtime.events.store import make_run_event_store

        config = MagicMock()
        config.backend = "jsonl"
        store = make_run_event_store(config)
        assert type(store).__name__ == "JsonlRunEventStore"

    @pytest.mark.anyio
    async def test_unknown_backend_raises(self):
        from unittest.mock import MagicMock

        from deerflow.runtime.events.store import make_run_event_store

        config = MagicMock()
        config.backend = "redis"
        with pytest.raises(ValueError, match="Unknown"):
            make_run_event_store(config)
# -- JSONL-specific tests --
class TestJsonlRunEventStore:
    """JSONL backend: one file per run under threads/<tid>/runs/<rid>.jsonl."""

    @pytest.mark.anyio
    async def test_basic_crud(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        store = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        rec = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
        assert rec["seq"] == 1
        assert len(await store.list_messages("t1")) == 1

    @pytest.mark.anyio
    async def test_file_at_correct_path(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        store = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        expected = tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r1.jsonl"
        assert expected.exists()

    @pytest.mark.anyio
    async def test_cross_run_messages(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        store = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        for rid in ("r1", "r2"):
            await store.put(thread_id="t1", run_id=rid, event_type="human_message", category="message")
        msgs = await store.list_messages("t1")
        # Two messages, seq shared across runs within the thread.
        assert [m["seq"] for m in msgs] == [1, 2]

    @pytest.mark.anyio
    async def test_delete_by_run(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        store = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        for rid in ("r1", "r2"):
            await store.put(thread_id="t1", run_id=rid, event_type="human_message", category="message")
        assert await store.delete_by_run("t1", "r2") == 1
        assert not (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r2.jsonl").exists()
        assert await store.count_messages("t1") == 1

View File

@ -1,107 +0,0 @@
"""Tests for paginated list_messages_by_run across all RunEventStore backends."""
import pytest
from deerflow.runtime.events.store.memory import MemoryRunEventStore
@pytest.fixture
def base_store():
    """Fresh in-memory RunEventStore; each test gets an isolated instance."""
    return MemoryRunEventStore()
@pytest.mark.anyio
async def test_list_messages_by_run_default_returns_all(base_store):
    """Without pagination params, every message of the run is returned."""
    for i in range(7):
        kind = "human_message" if i % 2 == 0 else "ai_message"
        await base_store.put(
            thread_id="t1", run_id="run-a",
            event_type=kind, category="message", content=f"msg-a-{i}",
        )
    for i in range(3):
        await base_store.put(
            thread_id="t1", run_id="run-b",
            event_type="human_message", category="message", content=f"msg-b-{i}",
        )
    # A trace-category event on the same run must not appear in the message list.
    await base_store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace")
    rows = await base_store.list_messages_by_run("t1", "run-a")
    assert len(rows) == 7
    assert all(row["category"] == "message" for row in rows)
    assert all(row["run_id"] == "run-a" for row in rows)
@pytest.mark.anyio
async def test_list_messages_by_run_with_limit(base_store):
    """limit truncates the page while keeping ascending seq order."""
    for i in range(7):
        kind = "human_message" if i % 2 == 0 else "ai_message"
        await base_store.put(
            thread_id="t1", run_id="run-a",
            event_type=kind, category="message", content=f"msg-a-{i}",
        )
    rows = await base_store.list_messages_by_run("t1", "run-a", limit=3)
    assert len(rows) == 3
    ordered = [row["seq"] for row in rows]
    assert ordered == sorted(ordered)
@pytest.mark.anyio
async def test_list_messages_by_run_after_seq(base_store):
    """after_seq is an exclusive lower cursor: only newer messages return."""
    for i in range(7):
        kind = "human_message" if i % 2 == 0 else "ai_message"
        await base_store.put(
            thread_id="t1", run_id="run-a",
            event_type=kind, category="message", content=f"msg-a-{i}",
        )
    everything = await base_store.list_messages_by_run("t1", "run-a")
    cursor = everything[2]["seq"]
    rows = await base_store.list_messages_by_run("t1", "run-a", after_seq=cursor, limit=50)
    assert all(row["seq"] > cursor for row in rows)
    # 7 messages total, cursor at index 2 -> 4 remain after it.
    assert len(rows) == 4
@pytest.mark.anyio
async def test_list_messages_by_run_before_seq(base_store):
    """before_seq is an exclusive upper cursor: only older messages return."""
    for i in range(7):
        kind = "human_message" if i % 2 == 0 else "ai_message"
        await base_store.put(
            thread_id="t1", run_id="run-a",
            event_type=kind, category="message", content=f"msg-a-{i}",
        )
    everything = await base_store.list_messages_by_run("t1", "run-a")
    cursor = everything[4]["seq"]
    rows = await base_store.list_messages_by_run("t1", "run-a", before_seq=cursor, limit=50)
    assert all(row["seq"] < cursor for row in rows)
    # Cursor at index 4 -> exactly the 4 earlier messages come back.
    assert len(rows) == 4
@pytest.mark.anyio
async def test_list_messages_by_run_does_not_include_other_run(base_store):
    """Results are scoped strictly to the requested run id."""
    for i in range(7):
        await base_store.put(
            thread_id="t1", run_id="run-a",
            event_type="human_message", category="message", content=f"msg-a-{i}",
        )
    for i in range(3):
        await base_store.put(
            thread_id="t1", run_id="run-b",
            event_type="human_message", category="message", content=f"msg-b-{i}",
        )
    rows = await base_store.list_messages_by_run("t1", "run-b")
    assert len(rows) == 3
    assert all(row["run_id"] == "run-b" for row in rows)
@pytest.mark.anyio
async def test_list_messages_by_run_empty_run(base_store):
    """An unknown run id yields an empty page rather than an error."""
    rows = await base_store.list_messages_by_run("t1", "nonexistent")
    assert rows == []

View File

@ -1,385 +0,0 @@
"""Tests for RunJournal callback handler.
Uses MemoryRunEventStore as the backend for direct event inspection.
"""
import asyncio
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from deerflow.runtime.events.store.memory import MemoryRunEventStore
from deerflow.runtime.journal import RunJournal
@pytest.fixture
def journal_setup():
    """Return a (RunJournal, MemoryRunEventStore) pair wired together.

    flush_threshold=100 keeps events buffered so each test controls
    persistence explicitly via ``await journal.flush()``.
    """
    store = MemoryRunEventStore()
    j = RunJournal("r1", "t1", store, flush_threshold=100)
    return j, store
def _make_llm_response(content="Hello", usage=None, tool_calls=None, additional_kwargs=None):
"""Create a mock LLM response with a message.
model_dump() returns checkpoint-aligned format matching real AIMessage.
"""
msg = MagicMock()
msg.type = "ai"
msg.content = content
msg.id = f"msg-{id(msg)}"
msg.tool_calls = tool_calls or []
msg.invalid_tool_calls = []
msg.response_metadata = {"model_name": "test-model"}
msg.usage_metadata = usage
msg.additional_kwargs = additional_kwargs or {}
msg.name = None
# model_dump returns checkpoint-aligned format
msg.model_dump.return_value = {
"content": content,
"additional_kwargs": additional_kwargs or {},
"response_metadata": {"model_name": "test-model"},
"type": "ai",
"name": None,
"id": msg.id,
"tool_calls": tool_calls or [],
"invalid_tool_calls": [],
"usage_metadata": usage,
}
gen = MagicMock()
gen.message = msg
response = MagicMock()
response.generations = [[gen]]
return response
class TestLlmCallbacks:
    """on_llm_start/on_llm_end handling: event emission and token accounting."""

    @pytest.mark.anyio
    async def test_on_llm_end_produces_trace_event(self, journal_setup):
        # A completed LLM call is journaled as a single llm.ai.response event.
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
        await j.flush()
        events = await store.list_events("t1", "r1")
        trace_events = [e for e in events if e["event_type"] == "llm.ai.response"]
        assert len(trace_events) == 1
        assert trace_events[0]["category"] == "message"

    @pytest.mark.anyio
    async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup):
        # Lead-agent responses show up in the thread's message list.
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
        await j.flush()
        messages = await store.list_messages("t1")
        assert len(messages) == 1
        assert messages[0]["event_type"] == "llm.ai.response"
        # Content is checkpoint-aligned model_dump format
        assert messages[0]["content"]["type"] == "ai"
        assert messages[0]["content"]["content"] == "Answer"

    @pytest.mark.anyio
    async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup):
        """LLM response with pending tool_calls emits llm.ai.response with tool_calls in content."""
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_end(
            _make_llm_response("Let me search", tool_calls=[{"id": "call_1", "name": "search", "args": {}}]),
            run_id=run_id,
            parent_run_id=None,
            tags=["lead_agent"],
        )
        await j.flush()
        messages = await store.list_messages("t1")
        assert len(messages) == 1
        assert messages[0]["event_type"] == "llm.ai.response"
        assert len(messages[0]["content"]["tool_calls"]) == 1

    @pytest.mark.anyio
    async def test_on_llm_end_subagent_no_ai_message(self, journal_setup):
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"])
        j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, parent_run_id=None, tags=["subagent:research"])
        await j.flush()
        messages = await store.list_messages("t1")
        # subagent responses still emit llm.ai.response with category="message"
        assert len(messages) == 1

    @pytest.mark.anyio
    async def test_token_accumulation(self, journal_setup):
        # Usage from successive calls is summed onto the journal's counters.
        j, store = journal_setup
        usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
        j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        assert j._total_input_tokens == 30
        assert j._total_output_tokens == 15
        assert j._total_tokens == 45
        assert j._llm_call_count == 2

    @pytest.mark.anyio
    async def test_total_tokens_computed_from_input_output(self, journal_setup):
        """If total_tokens is 0, it should be computed from input + output."""
        j, store = journal_setup
        j.on_llm_end(
            _make_llm_response("Hi", usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}),
            run_id=uuid4(),
            parent_run_id=None,
            tags=["lead_agent"],
        )
        assert j._total_tokens == 150

    @pytest.mark.anyio
    async def test_caller_token_classification(self, journal_setup):
        # One call each from lead agent, a subagent, and a middleware tag.
        j, store = journal_setup
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
        j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:summarization"])
        # token tracking not broken by caller type
        assert j._total_tokens == 45
        assert j._llm_call_count == 3

    @pytest.mark.anyio
    async def test_usage_metadata_none_no_crash(self, journal_setup):
        # Missing usage metadata must not raise during accounting or flush.
        j, store = journal_setup
        j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        await j.flush()

    @pytest.mark.anyio
    async def test_latency_tracking(self, journal_setup):
        # Pairing start/end on the same run_id yields a latency_ms metadata field.
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
        await j.flush()
        events = await store.list_events("t1", "r1")
        llm_resp = [e for e in events if e["event_type"] == "llm.ai.response"][0]
        assert "latency_ms" in llm_resp["metadata"]
        assert llm_resp["metadata"]["latency_ms"] is not None
class TestLifecycleCallbacks:
    """Chain start/end callbacks map onto run lifecycle events."""

    @pytest.mark.anyio
    async def test_chain_start_end_produce_trace_events(self, journal_setup):
        j, store = journal_setup
        j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
        j.on_chain_end({}, run_id=uuid4())
        # Lifecycle events may be written via a scheduled task — give it a beat.
        await asyncio.sleep(0.05)
        await j.flush()
        events = await store.list_events("t1", "r1")
        types = {e["event_type"] for e in events}
        assert "run.start" in types
        assert "run.end" in types

    @pytest.mark.anyio
    async def test_nested_chain_no_run_start(self, journal_setup):
        """Nested chains (parent_run_id set) should NOT produce run.start."""
        j, store = journal_setup
        parent_id = uuid4()
        j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
        j.on_chain_end({}, run_id=uuid4())
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert not any(e["event_type"] == "run.start" for e in events)
class TestToolCallbacks:
    """on_tool_end / on_tool_error handling on the journal."""

    @pytest.mark.anyio
    async def test_tool_end_with_tool_message(self, journal_setup):
        """on_tool_end with a ToolMessage stores it as llm.tool.result."""
        from langchain_core.messages import ToolMessage
        j, store = journal_setup
        tool_msg = ToolMessage(content="results", tool_call_id="call_1", name="web_search")
        j.on_tool_end(tool_msg, run_id=uuid4())
        await j.flush()
        messages = await store.list_messages("t1")
        assert len(messages) == 1
        assert messages[0]["event_type"] == "llm.tool.result"
        assert messages[0]["content"]["type"] == "tool"

    @pytest.mark.anyio
    async def test_tool_end_with_command_unwraps_tool_message(self, journal_setup):
        """on_tool_end with Command(update={'messages':[ToolMessage]}) unwraps inner message."""
        from langchain_core.messages import ToolMessage
        from langgraph.types import Command
        j, store = journal_setup
        inner = ToolMessage(content="file list", tool_call_id="call_2", name="present_files")
        cmd = Command(update={"messages": [inner]})
        j.on_tool_end(cmd, run_id=uuid4())
        await j.flush()
        messages = await store.list_messages("t1")
        assert len(messages) == 1
        assert messages[0]["event_type"] == "llm.tool.result"
        # The unwrapped inner ToolMessage content is what gets journaled.
        assert messages[0]["content"]["content"] == "file list"

    @pytest.mark.anyio
    async def test_on_tool_error_no_crash(self, journal_setup):
        """on_tool_error should not crash (no event emitted by default)."""
        j, store = journal_setup
        j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
        await j.flush()
        # Base implementation does not emit tool_error — just verify no crash
        events = await store.list_events("t1", "r1")
        assert isinstance(events, list)
class TestCustomEvents:
    """Custom callback events are tolerated without emitting anything."""

    @pytest.mark.anyio
    async def test_on_custom_event_not_implemented(self, journal_setup):
        """RunJournal does not implement on_custom_event — no crash expected."""
        j, store = journal_setup
        # BaseCallbackHandler.on_custom_event is a no-op by default
        j.on_custom_event("task_running", {"task_id": "t1"}, run_id=uuid4())
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert isinstance(events, list)
class TestBufferFlush:
    """Buffering and flush-threshold semantics of the journal's event buffer."""

    @pytest.mark.anyio
    async def test_flush_threshold(self, journal_setup):
        j, store = journal_setup
        j._flush_threshold = 2
        # Each on_llm_end emits 1 event
        j.on_llm_end(_make_llm_response("A"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        assert len(j._buffer) == 1
        j.on_llm_end(_make_llm_response("B"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        # At threshold the buffer should have been flushed asynchronously
        await asyncio.sleep(0.1)
        events = await store.list_events("t1", "r1")
        assert len(events) >= 2

    @pytest.mark.anyio
    async def test_events_retained_when_no_loop(self, journal_setup):
        """Events buffered in a sync (no-loop) context should survive
        until the async flush() in the finally block."""
        j, store = journal_setup
        j._flush_threshold = 1
        # Simulate "no running event loop" by monkeypatching the lookup;
        # restored in the finally block so later tests are unaffected.
        original = asyncio.get_running_loop
        def no_loop():
            raise RuntimeError("no running event loop")
        asyncio.get_running_loop = no_loop
        try:
            j._put(event_type="llm.ai.response", category="message", content="test")
        finally:
            asyncio.get_running_loop = original
        # The event must still be sitting in the buffer, not dropped.
        assert len(j._buffer) == 1
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert any(e["event_type"] == "llm.ai.response" for e in events)
class TestIdentifyCaller:
    """_identify_caller maps callback tags onto a caller label."""

    def test_lead_agent_tag(self, journal_setup):
        journal, _store = journal_setup
        assert journal._identify_caller(["lead_agent"]) == "lead_agent"

    def test_subagent_tag(self, journal_setup):
        journal, _store = journal_setup
        assert journal._identify_caller(["subagent:research"]) == "subagent:research"

    def test_middleware_tag(self, journal_setup):
        journal, _store = journal_setup
        assert journal._identify_caller(["middleware:summarization"]) == "middleware:summarization"

    def test_no_tags_returns_lead_agent(self, journal_setup):
        # Both an empty tag list and None fall back to the lead agent.
        journal, _store = journal_setup
        assert journal._identify_caller([]) == "lead_agent"
        assert journal._identify_caller(None) == "lead_agent"
class TestChainErrorCallback:
    """on_chain_error writes a run.error event carrying the exception details."""

    @pytest.mark.anyio
    async def test_on_chain_error_writes_run_error(self, journal_setup):
        j, store = journal_setup
        j.on_chain_error(ValueError("boom"), run_id=uuid4())
        # Error events may be written via a scheduled task — give it a beat.
        await asyncio.sleep(0.05)
        await j.flush()
        events = await store.list_events("t1", "r1")
        error_events = [e for e in events if e["event_type"] == "run.error"]
        assert len(error_events) == 1
        assert "boom" in error_events[0]["content"]
        assert error_events[0]["metadata"]["error_type"] == "ValueError"
class TestTokenTrackingDisabled:
    """track_token_usage=False disables all token/call accounting."""

    @pytest.mark.anyio
    async def test_track_token_usage_false(self):
        store = MemoryRunEventStore()
        j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
        j.on_llm_end(
            _make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}),
            run_id=uuid4(),
            parent_run_id=None,
            tags=["lead_agent"],
        )
        data = j.get_completion_data()
        # Usage was supplied but must be ignored when tracking is off.
        assert data["total_tokens"] == 0
        assert data["llm_call_count"] == 0
class TestConvenienceFields:
    """Convenience fields surfaced through get_completion_data()."""

    @pytest.mark.anyio
    async def test_first_human_message_via_set(self, journal_setup):
        journal, _store = journal_setup
        journal.set_first_human_message("What is AI?")
        snapshot = journal.get_completion_data()
        assert snapshot["first_human_message"] == "What is AI?"

    @pytest.mark.anyio
    async def test_get_completion_data(self, journal_setup):
        journal, _store = journal_setup
        # Seed internal counters directly; get_completion_data() just reports them.
        journal._total_tokens = 100
        journal._msg_count = 5
        snapshot = journal.get_completion_data()
        assert snapshot["total_tokens"] == 100
        assert snapshot["message_count"] == 5
class TestMiddlewareEvents:
    """record_middleware emits middleware-category events keyed by tag."""

    @pytest.mark.anyio
    async def test_record_middleware_uses_middleware_category(self, journal_setup):
        j, store = journal_setup
        j.record_middleware(
            "title",
            name="TitleMiddleware",
            hook="after_model",
            action="generate_title",
            changes={"title": "Test Title", "thread_id": "t1"},
        )
        await j.flush()
        events = await store.list_events("t1", "r1")
        # The tag becomes the event_type suffix: "middleware:<tag>".
        mw_events = [e for e in events if e["event_type"] == "middleware:title"]
        assert len(mw_events) == 1
        assert mw_events[0]["category"] == "middleware"
        assert mw_events[0]["content"]["name"] == "TitleMiddleware"
        assert mw_events[0]["content"]["hook"] == "after_model"
        assert mw_events[0]["content"]["action"] == "generate_title"
        assert mw_events[0]["content"]["changes"]["title"] == "Test Title"

    @pytest.mark.anyio
    async def test_middleware_tag_variants(self, journal_setup):
        """Different middleware tags produce distinct event_types."""
        j, store = journal_setup
        j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={})
        j.record_middleware("guardrail", name="GuardrailMiddleware", hook="before_tool", action="deny", changes={})
        await j.flush()
        events = await store.list_events("t1", "r1")
        event_types = {e["event_type"] for e in events}
        assert "middleware:title" in event_types
        assert "middleware:guardrail" in event_types

View File

@ -1,196 +0,0 @@
"""Tests for RunRepository (SQLAlchemy-backed RunStore).
Uses a temp SQLite DB to test ORM-backed CRUD operations.
"""
import pytest
from deerflow.persistence.run import RunRepository
async def _make_repo(tmp_path):
    """Create a RunRepository backed by a throwaway SQLite file under tmp_path."""
    from deerflow.persistence.engine import get_session_factory, init_engine
    url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
    await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
    return RunRepository(get_session_factory())
async def _cleanup():
    """Dispose the global engine so the next test can re-init it cleanly."""
    from deerflow.persistence.engine import close_engine
    await close_engine()
class TestRunRepository:
    """CRUD and completion-update behavior of the SQLAlchemy-backed run store.

    NOTE(review): each test calls ``await _cleanup()`` as its last statement,
    so a failing assert skips engine disposal and may leak state into the
    next test — a fixture with teardown (or try/finally) would be safer.
    """

    @pytest.mark.anyio
    async def test_put_and_get(self, tmp_path):
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1", status="pending")
        row = await repo.get("r1")
        assert row is not None
        assert row["run_id"] == "r1"
        assert row["thread_id"] == "t1"
        assert row["status"] == "pending"
        await _cleanup()

    @pytest.mark.anyio
    async def test_get_missing_returns_none(self, tmp_path):
        repo = await _make_repo(tmp_path)
        assert await repo.get("nope") is None
        await _cleanup()

    @pytest.mark.anyio
    async def test_update_status(self, tmp_path):
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1")
        await repo.update_status("r1", "running")
        row = await repo.get("r1")
        assert row["status"] == "running"
        await _cleanup()

    @pytest.mark.anyio
    async def test_update_status_with_error(self, tmp_path):
        # The optional error string is persisted alongside the status.
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1")
        await repo.update_status("r1", "error", error="boom")
        row = await repo.get("r1")
        assert row["status"] == "error"
        assert row["error"] == "boom"
        await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread(self, tmp_path):
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1")
        await repo.put("r2", thread_id="t1")
        await repo.put("r3", thread_id="t2")
        rows = await repo.list_by_thread("t1")
        assert len(rows) == 2
        assert all(r["thread_id"] == "t1" for r in rows)
        await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread_owner_filter(self, tmp_path):
        # user_id narrows the listing to a single owner's runs.
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1", user_id="alice")
        await repo.put("r2", thread_id="t1", user_id="bob")
        rows = await repo.list_by_thread("t1", user_id="alice")
        assert len(rows) == 1
        assert rows[0]["user_id"] == "alice"
        await _cleanup()

    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1")
        await repo.delete("r1")
        assert await repo.get("r1") is None
        await _cleanup()

    @pytest.mark.anyio
    async def test_delete_nonexistent_is_noop(self, tmp_path):
        repo = await _make_repo(tmp_path)
        await repo.delete("nope")  # should not raise
        await _cleanup()

    @pytest.mark.anyio
    async def test_list_pending(self, tmp_path):
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1", status="pending")
        await repo.put("r2", thread_id="t1", status="running")
        await repo.put("r3", thread_id="t2", status="pending")
        pending = await repo.list_pending()
        # list_pending spans all threads, filtering on status only.
        assert len(pending) == 2
        assert all(r["status"] == "pending" for r in pending)
        await _cleanup()

    @pytest.mark.anyio
    async def test_update_run_completion(self, tmp_path):
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1", status="running")
        await repo.update_run_completion(
            "r1",
            status="success",
            total_input_tokens=100,
            total_output_tokens=50,
            total_tokens=150,
            llm_call_count=2,
            lead_agent_tokens=120,
            subagent_tokens=20,
            middleware_tokens=10,
            message_count=3,
            last_ai_message="The answer is 42",
            first_human_message="What is the meaning?",
        )
        row = await repo.get("r1")
        assert row["status"] == "success"
        assert row["total_tokens"] == 150
        assert row["llm_call_count"] == 2
        assert row["lead_agent_tokens"] == 120
        assert row["message_count"] == 3
        assert row["last_ai_message"] == "The answer is 42"
        assert row["first_human_message"] == "What is the meaning?"
        await _cleanup()

    @pytest.mark.anyio
    async def test_metadata_preserved(self, tmp_path):
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1", metadata={"key": "value"})
        row = await repo.get("r1")
        assert row["metadata"] == {"key": "value"}
        await _cleanup()

    @pytest.mark.anyio
    async def test_kwargs_with_non_serializable(self, tmp_path):
        """kwargs containing non-JSON-serializable objects should be safely handled."""
        repo = await _make_repo(tmp_path)
        class Dummy:
            pass
        await repo.put("r1", thread_id="t1", kwargs={"obj": Dummy()})
        row = await repo.get("r1")
        # The key survives round-tripping even if the value had to be coerced.
        assert "obj" in row["kwargs"]
        await _cleanup()

    @pytest.mark.anyio
    async def test_update_run_completion_preserves_existing_fields(self, tmp_path):
        """update_run_completion does not overwrite thread_id or assistant_id."""
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1", assistant_id="agent1", status="running")
        await repo.update_run_completion("r1", status="success", total_tokens=100)
        row = await repo.get("r1")
        assert row["thread_id"] == "t1"
        assert row["assistant_id"] == "agent1"
        assert row["total_tokens"] == 100
        await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread_ordered_desc(self, tmp_path):
        """list_by_thread returns newest first."""
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1", created_at="2024-01-01T00:00:00+00:00")
        await repo.put("r2", thread_id="t1", created_at="2024-01-02T00:00:00+00:00")
        rows = await repo.list_by_thread("t1")
        assert rows[0]["run_id"] == "r2"
        assert rows[1]["run_id"] == "r1"
        await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread_limit(self, tmp_path):
        repo = await _make_repo(tmp_path)
        for i in range(5):
            await repo.put(f"r{i}", thread_id="t1")
        rows = await repo.list_by_thread("t1", limit=2)
        assert len(rows) == 2
        await _cleanup()

    @pytest.mark.anyio
    async def test_owner_none_returns_all(self, tmp_path):
        # user_id=None disables the owner filter entirely.
        repo = await _make_repo(tmp_path)
        await repo.put("r1", thread_id="t1", user_id="alice")
        await repo.put("r2", thread_id="t1", user_id="bob")
        rows = await repo.list_by_thread("t1", user_id=None)
        assert len(rows) == 2
        await _cleanup()

View File

@ -1,214 +0,0 @@
from unittest.mock import AsyncMock, call
import pytest
from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint
class FakeCheckpointer:
    """Checkpointer test double recording adelete_thread / aput / aput_writes calls."""

    def __init__(self, *, put_result):
        # aput resolves to the config dict a real checkpointer would return.
        self.aput = AsyncMock(return_value=put_result)
        self.aput_writes = AsyncMock()
        self.adelete_thread = AsyncMock()
@pytest.mark.anyio
async def test_rollback_restores_snapshot_without_deleting_thread():
    """Happy path: with a snapshot, restore via aput and replay pending writes."""
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    await _rollback_to_pre_run_checkpoint(
        checkpointer=checkpointer,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="ckpt-1",
        pre_run_snapshot={
            "checkpoint_ns": "",
            "checkpoint": {
                "id": "ckpt-1",
                "channel_versions": {"messages": 3},
                "channel_values": {"messages": ["before"]},
            },
            "metadata": {"source": "input"},
            "pending_writes": [
                ("task-a", "messages", {"content": "first"}),
                ("task-a", "status", "done"),
                ("task-b", "events", {"type": "tool"}),
            ],
        },
        snapshot_capture_failed=False,
    )
    # The thread must not be wiped when a snapshot can be restored.
    checkpointer.adelete_thread.assert_not_awaited()
    checkpointer.aput.assert_awaited_once_with(
        {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
        {
            "id": "ckpt-1",
            "channel_versions": {"messages": 3},
            "channel_values": {"messages": ["before"]},
        },
        {"source": "input"},
        {"messages": 3},
    )
    # Pending writes are regrouped per task_id and replayed against the
    # restored-checkpoint config that aput returned.
    assert checkpointer.aput_writes.await_args_list == [
        call(
            {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
            [("messages", {"content": "first"}), ("status", "done")],
            task_id="task-a",
        ),
        call(
            {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
            [("events", {"type": "tool"})],
            task_id="task-b",
        ),
    ]
@pytest.mark.anyio
async def test_rollback_deletes_thread_when_no_snapshot_exists():
    """Without a pre-run snapshot, the whole thread is deleted instead of restored."""
    checkpointer = FakeCheckpointer(put_result=None)
    await _rollback_to_pre_run_checkpoint(
        checkpointer=checkpointer,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id=None,
        pre_run_snapshot=None,
        snapshot_capture_failed=False,
    )
    checkpointer.adelete_thread.assert_awaited_once_with("thread-1")
    # No restore and no write replay happen on the delete path.
    checkpointer.aput.assert_not_awaited()
    checkpointer.aput_writes.assert_not_awaited()
@pytest.mark.anyio
async def test_rollback_raises_when_restore_config_has_no_checkpoint_id():
    """aput returning a config without checkpoint_id aborts the rollback."""
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}})
    with pytest.raises(RuntimeError, match="did not return checkpoint_id"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=checkpointer,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot={
                "checkpoint_ns": "",
                "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
                "metadata": {},
                "pending_writes": [("task-a", "messages", "value")],
            },
            snapshot_capture_failed=False,
        )
    # Restore was attempted, but nothing was deleted and no writes replayed.
    checkpointer.adelete_thread.assert_not_awaited()
    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_not_awaited()
@pytest.mark.anyio
async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
    """A None checkpoint_ns in the snapshot is passed to aput as "" (root ns)."""
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    await _rollback_to_pre_run_checkpoint(
        checkpointer=checkpointer,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="ckpt-1",
        pre_run_snapshot={
            "checkpoint_ns": None,
            "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
            "metadata": {},
            "pending_writes": [],
        },
        snapshot_capture_failed=False,
    )
    checkpointer.aput.assert_awaited_once_with(
        {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
        {"id": "ckpt-1", "channel_versions": {}},
        {},
        {},
    )
@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_not_a_tuple():
    """pending_writes containing a non-3-tuple item should raise RuntimeError."""
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    with pytest.raises(RuntimeError, match="rollback failed: pending_write is not a 3-tuple"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=checkpointer,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot={
                "checkpoint_ns": "",
                "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
                "metadata": {},
                "pending_writes": [
                    ("task-a", "messages", "valid"),  # valid
                    ["only", "two"],  # malformed: only 2 elements
                ],
            },
            snapshot_capture_failed=False,
        )
    # aput succeeded but aput_writes should not be called due to malformed data
    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_not_awaited()
@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_non_string_channel():
    """pending_writes containing a non-string channel should raise RuntimeError."""
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    with pytest.raises(RuntimeError, match="rollback failed: pending_write has non-string channel"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=checkpointer,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot={
                "checkpoint_ns": "",
                "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
                "metadata": {},
                "pending_writes": [
                    ("task-a", 123, "value"),  # malformed: channel is not a string
                ],
            },
            snapshot_capture_failed=False,
        )
    # Validation happens after the restore but before any write replay.
    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_not_awaited()
@pytest.mark.anyio
async def test_rollback_propagates_aput_writes_failure():
    """If aput_writes fails, the exception should propagate (not be swallowed)."""
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    # Simulate aput_writes failure
    checkpointer.aput_writes.side_effect = RuntimeError("Database connection lost")
    with pytest.raises(RuntimeError, match="Database connection lost"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=checkpointer,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot={
                "checkpoint_ns": "",
                "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
                "metadata": {},
                "pending_writes": [
                    ("task-a", "messages", "value"),
                ],
            },
            snapshot_capture_failed=False,
        )
    # aput succeeded, aput_writes was called but failed
    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_awaited_once()

View File

@ -1,243 +0,0 @@
"""Tests for GET /api/runs/{run_id}/messages and GET /api/runs/{run_id}/feedback endpoints."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from _router_auth_helpers import make_authed_test_app
from fastapi.testclient import TestClient
from app.gateway.routers import runs
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_app(run_store=None, event_store=None, feedback_repo=None):
    """Build a test FastAPI app with stub auth and mocked state."""
    app = make_authed_test_app()
    app.include_router(runs.router)
    # Install only the state objects the caller provided; note the event
    # store lives on app.state under the name "run_event_store".
    provided = {
        "run_store": run_store,
        "run_event_store": event_store,
        "feedback_repo": feedback_repo,
    }
    for attr, value in provided.items():
        if value is not None:
            setattr(app.state, attr, value)
    return app
def _make_run_store(run_record: dict | None):
"""Return an AsyncMock run store whose get() returns run_record."""
store = MagicMock()
store.get = AsyncMock(return_value=run_record)
return store
def _make_event_store(rows: list[dict]):
"""Return an AsyncMock event store whose list_messages_by_run() returns rows."""
store = MagicMock()
store.list_messages_by_run = AsyncMock(return_value=rows)
return store
def _make_message(seq: int) -> dict:
return {"seq": seq, "event_type": "on_chat_model_stream", "category": "message", "content": f"msg-{seq}"}
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_run_messages_returns_envelope():
    """GET /api/runs/{run_id}/messages returns {data: [...], has_more: bool}."""
    run_record = {"run_id": "run-1", "thread_id": "thread-1"}
    messages = [_make_message(seq) for seq in (1, 2, 3)]
    app = _make_app(
        run_store=_make_run_store(run_record),
        event_store=_make_event_store(messages),
    )
    with TestClient(app) as client:
        resp = client.get("/api/runs/run-1/messages")
    assert resp.status_code == 200
    payload = resp.json()
    assert "data" in payload
    assert "has_more" in payload
    # Only 3 rows exist, well under the page size — no further page.
    assert payload["has_more"] is False
    assert len(payload["data"]) == 3
def test_run_messages_404_when_run_not_found():
    """Returns 404 when the run store returns None."""
    app = _make_app(
        run_store=_make_run_store(None),
        event_store=_make_event_store([]),
    )
    with TestClient(app) as client:
        resp = client.get("/api/runs/missing-run/messages")
    assert resp.status_code == 404
    # The missing run id is echoed back in the error detail.
    assert "missing-run" in resp.json()["detail"]
def test_run_messages_has_more_true_when_extra_row_returned():
    """has_more=True when event store returns limit+1 rows."""
    # Default page size is 50, so a 51st row signals another page exists.
    extra_rows = [_make_message(seq) for seq in range(1, 52)]
    app = _make_app(
        run_store=_make_run_store({"run_id": "run-2", "thread_id": "thread-2"}),
        event_store=_make_event_store(extra_rows),
    )
    with TestClient(app) as client:
        resp = client.get("/api/runs/run-2/messages")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["has_more"] is True
    assert len(payload["data"]) == 50  # the sentinel row is trimmed off
def test_run_messages_passes_after_seq_to_event_store():
    """The after_seq query param is forwarded verbatim to the event store."""
    event_store = _make_event_store([_make_message(10)])
    app = _make_app(
        run_store=_make_run_store({"run_id": "run-3", "thread_id": "thread-3"}),
        event_store=event_store,
    )
    with TestClient(app) as client:
        resp = client.get("/api/runs/run-3/messages?after_seq=5")
    assert resp.status_code == 200
    event_store.list_messages_by_run.assert_awaited_once_with(
        "thread-3",
        "run-3",
        limit=51,  # default limit (50) plus one sentinel row
        before_seq=None,
        after_seq=5,
    )
def test_run_messages_respects_custom_limit():
    """A caller-supplied limit is honored (and capped at 200 by the route)."""
    event_store = _make_event_store([_make_message(seq) for seq in range(1, 6)])
    app = _make_app(
        run_store=_make_run_store({"run_id": "run-4", "thread_id": "thread-4"}),
        event_store=event_store,
    )
    with TestClient(app) as client:
        resp = client.get("/api/runs/run-4/messages?limit=10")
    assert resp.status_code == 200
    event_store.list_messages_by_run.assert_awaited_once_with(
        "thread-4",
        "run-4",
        limit=11,  # requested 10 plus one sentinel row
        before_seq=None,
        after_seq=None,
    )
def test_run_messages_passes_before_seq_to_event_store():
    """The before_seq query param is forwarded verbatim to the event store."""
    event_store = _make_event_store([_make_message(3)])
    app = _make_app(
        run_store=_make_run_store({"run_id": "run-5", "thread_id": "thread-5"}),
        event_store=event_store,
    )
    with TestClient(app) as client:
        resp = client.get("/api/runs/run-5/messages?before_seq=10")
    assert resp.status_code == 200
    event_store.list_messages_by_run.assert_awaited_once_with(
        "thread-5",
        "run-5",
        limit=51,
        before_seq=10,
        after_seq=None,
    )
def test_run_messages_empty_data():
    """No stored messages -> empty data list and has_more False."""
    app = _make_app(
        run_store=_make_run_store({"run_id": "run-6", "thread_id": "thread-6"}),
        event_store=_make_event_store([]),
    )
    with TestClient(app) as client:
        resp = client.get("/api/runs/run-6/messages")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["data"] == []
    assert payload["has_more"] is False
def _make_feedback_repo(rows: list[dict]):
"""Return an AsyncMock feedback repo whose list_by_run() returns rows."""
repo = MagicMock()
repo.list_by_run = AsyncMock(return_value=rows)
return repo
def _make_feedback(run_id: str, idx: int) -> dict:
return {"id": f"fb-{idx}", "run_id": run_id, "thread_id": "thread-x", "value": "up"}
# ---------------------------------------------------------------------------
# TestRunFeedback
# ---------------------------------------------------------------------------
class TestRunFeedback:
    """Router tests for GET /api/runs/{run_id}/feedback."""

    def test_returns_list_of_feedback_dicts(self):
        """Feedback rows come back as a plain JSON list."""
        feedback_rows = [_make_feedback("run-fb-1", i) for i in range(3)]
        app = _make_app(
            run_store=_make_run_store({"run_id": "run-fb-1", "thread_id": "thread-fb-1"}),
            feedback_repo=_make_feedback_repo(feedback_rows),
        )
        with TestClient(app) as client:
            resp = client.get("/api/runs/run-fb-1/feedback")
        assert resp.status_code == 200
        payload = resp.json()
        assert isinstance(payload, list)
        assert len(payload) == 3

    def test_404_when_run_not_found(self):
        """An unknown run id surfaces as HTTP 404 with the id in the detail."""
        app = _make_app(
            run_store=_make_run_store(None),
            feedback_repo=_make_feedback_repo([]),
        )
        with TestClient(app) as client:
            resp = client.get("/api/runs/missing-run/feedback")
        assert resp.status_code == 404
        assert "missing-run" in resp.json()["detail"]

    def test_empty_list_when_no_feedback(self):
        """A run with no feedback rows yields an empty JSON list."""
        app = _make_app(
            run_store=_make_run_store({"run_id": "run-fb-2", "thread_id": "thread-fb-2"}),
            feedback_repo=_make_feedback_repo([]),
        )
        with TestClient(app) as client:
            resp = client.get("/api/runs/run-fb-2/feedback")
        assert resp.status_code == 200
        assert resp.json() == []

    def test_503_when_feedback_repo_not_configured(self):
        """With no repo wired up (no DB configured), the route answers 503."""
        app = _make_app(
            run_store=_make_run_store({"run_id": "run-fb-3", "thread_id": "thread-fb-3"}),
        )
        # Simulate a deployment without a database by clearing the repo.
        app.state.feedback_repo = None
        with TestClient(app) as client:
            resp = client.get("/api/runs/run-fb-3/feedback")
        assert resp.status_code == 503

View File

@ -1,178 +0,0 @@
"""Tests for ThreadMetaRepository (SQLAlchemy-backed)."""
import pytest
from deerflow.persistence.thread_meta import ThreadMetaRepository
async def _make_repo(tmp_path):
    """Spin up a fresh SQLite-backed engine under *tmp_path* and return a repo."""
    from deerflow.persistence.engine import get_session_factory, init_engine

    db_url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
    await init_engine("sqlite", url=db_url, sqlite_dir=str(tmp_path))
    return ThreadMetaRepository(get_session_factory())
async def _cleanup():
    """Tear down the module-global engine so the next test starts clean."""
    from deerflow.persistence.engine import close_engine

    await close_engine()
class TestThreadMetaRepository:
    """CRUD + access-control tests against a throwaway SQLite database."""

    @pytest.mark.anyio
    async def test_create_and_get(self, tmp_path):
        store = await _make_repo(tmp_path)
        row = await store.create("t1")
        assert row["thread_id"] == "t1"
        assert row["status"] == "idle"
        assert "created_at" in row
        fetched = await store.get("t1")
        assert fetched is not None
        assert fetched["thread_id"] == "t1"
        await _cleanup()

    @pytest.mark.anyio
    async def test_create_with_assistant_id(self, tmp_path):
        store = await _make_repo(tmp_path)
        row = await store.create("t1", assistant_id="agent1")
        assert row["assistant_id"] == "agent1"
        await _cleanup()

    @pytest.mark.anyio
    async def test_create_with_owner_and_display_name(self, tmp_path):
        store = await _make_repo(tmp_path)
        row = await store.create("t1", user_id="user1", display_name="My Thread")
        assert row["user_id"] == "user1"
        assert row["display_name"] == "My Thread"
        await _cleanup()

    @pytest.mark.anyio
    async def test_create_with_metadata(self, tmp_path):
        store = await _make_repo(tmp_path)
        row = await store.create("t1", metadata={"key": "value"})
        assert row["metadata"] == {"key": "value"}
        await _cleanup()

    @pytest.mark.anyio
    async def test_get_nonexistent(self, tmp_path):
        store = await _make_repo(tmp_path)
        assert await store.get("nonexistent") is None
        await _cleanup()

    @pytest.mark.anyio
    async def test_check_access_no_record_allows(self, tmp_path):
        store = await _make_repo(tmp_path)
        assert await store.check_access("unknown", "user1") is True
        await _cleanup()

    @pytest.mark.anyio
    async def test_check_access_owner_matches(self, tmp_path):
        store = await _make_repo(tmp_path)
        await store.create("t1", user_id="user1")
        assert await store.check_access("t1", "user1") is True
        await _cleanup()

    @pytest.mark.anyio
    async def test_check_access_owner_mismatch(self, tmp_path):
        store = await _make_repo(tmp_path)
        await store.create("t1", user_id="user1")
        assert await store.check_access("t1", "user2") is False
        await _cleanup()

    @pytest.mark.anyio
    async def test_check_access_no_owner_allows_all(self, tmp_path):
        store = await _make_repo(tmp_path)
        # Pass user_id=None explicitly so the AUTO default does not pick up
        # the test user installed by the autouse contextvar fixture.
        await store.create("t1", user_id=None)
        assert await store.check_access("t1", "anyone") is True
        await _cleanup()

    @pytest.mark.anyio
    async def test_check_access_strict_missing_row_denied(self, tmp_path):
        """require_existing=True turns the missing-row case into a denial.

        This closes the delete-idempotence cross-user gap: once a thread is
        deleted its row is gone, and the permissive default would let any
        caller "claim" it as untracked. Strict mode insists on a row.
        """
        store = await _make_repo(tmp_path)
        assert await store.check_access("never-existed", "user1", require_existing=True) is False
        await _cleanup()

    @pytest.mark.anyio
    async def test_check_access_strict_owner_match_allowed(self, tmp_path):
        store = await _make_repo(tmp_path)
        await store.create("t1", user_id="user1")
        assert await store.check_access("t1", "user1", require_existing=True) is True
        await _cleanup()

    @pytest.mark.anyio
    async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
        store = await _make_repo(tmp_path)
        await store.create("t1", user_id="user1")
        assert await store.check_access("t1", "user2", require_existing=True) is False
        await _cleanup()

    @pytest.mark.anyio
    async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
        """A row with NULL user_id stays shared even under strict mode.

        The strict flag only tightens the *missing row* case, not the
        *shared row* case — legacy pre-auth rows that migrated cleanly
        without an owner remain accessible to everyone.
        """
        store = await _make_repo(tmp_path)
        await store.create("t1", user_id=None)
        assert await store.check_access("t1", "anyone", require_existing=True) is True
        await _cleanup()

    @pytest.mark.anyio
    async def test_update_status(self, tmp_path):
        store = await _make_repo(tmp_path)
        await store.create("t1")
        await store.update_status("t1", "busy")
        row = await store.get("t1")
        assert row["status"] == "busy"
        await _cleanup()

    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        store = await _make_repo(tmp_path)
        await store.create("t1")
        await store.delete("t1")
        assert await store.get("t1") is None
        await _cleanup()

    @pytest.mark.anyio
    async def test_delete_nonexistent_is_noop(self, tmp_path):
        store = await _make_repo(tmp_path)
        await store.delete("nonexistent")  # must not raise
        await _cleanup()

    @pytest.mark.anyio
    async def test_update_metadata_merges(self, tmp_path):
        store = await _make_repo(tmp_path)
        await store.create("t1", metadata={"a": 1, "b": 2})
        await store.update_metadata("t1", {"b": 99, "c": 3})
        row = await store.get("t1")
        # Untouched key survives, overlapping key is overwritten, new key lands.
        assert row["metadata"] == {"a": 1, "b": 99, "c": 3}
        await _cleanup()

    @pytest.mark.anyio
    async def test_update_metadata_on_empty(self, tmp_path):
        store = await _make_repo(tmp_path)
        await store.create("t1")
        await store.update_metadata("t1", {"k": "v"})
        row = await store.get("t1")
        assert row["metadata"] == {"k": "v"}
        await _cleanup()

    @pytest.mark.anyio
    async def test_update_metadata_nonexistent_is_noop(self, tmp_path):
        store = await _make_repo(tmp_path)
        await store.update_metadata("nonexistent", {"k": "v"})  # must not raise
        await _cleanup()

View File

@ -1,110 +0,0 @@
"""Tests for runtime.user_context — contextvar three-state semantics.
These tests opt out of the autouse contextvar fixture (added in
commit 6) because they explicitly test the cases where the contextvar
is set or unset.
"""
from types import SimpleNamespace
import pytest
from deerflow.runtime.user_context import (
CurrentUser,
DEFAULT_USER_ID,
get_current_user,
get_effective_user_id,
require_current_user,
reset_current_user,
set_current_user,
)
@pytest.mark.no_auto_user
def test_default_is_none():
    """The contextvar starts out unset, so the getter yields None."""
    current = get_current_user()
    assert current is None
@pytest.mark.no_auto_user
def test_set_and_reset_roundtrip():
    """Setting the contextvar hands back a token that restores the prior state."""
    stub_user = SimpleNamespace(id="user-1")
    token = set_current_user(stub_user)
    try:
        assert get_current_user() is stub_user
    finally:
        reset_current_user(token)
    assert get_current_user() is None
@pytest.mark.no_auto_user
def test_require_current_user_raises_when_unset():
    """With no user in context, require_current_user fails with RuntimeError."""
    assert get_current_user() is None
    with pytest.raises(RuntimeError, match="without user context"):
        require_current_user()
@pytest.mark.no_auto_user
def test_require_current_user_returns_user_when_set():
    """Once the contextvar is populated, require_current_user hands it back."""
    stub_user = SimpleNamespace(id="user-2")
    token = set_current_user(stub_user)
    try:
        assert require_current_user() is stub_user
    finally:
        reset_current_user(token)
@pytest.mark.no_auto_user
def test_protocol_accepts_duck_typed():
    """Any object exposing .id satisfies the runtime-checkable CurrentUser Protocol."""
    duck = SimpleNamespace(id="user-3")
    assert isinstance(duck, CurrentUser)
@pytest.mark.no_auto_user
def test_protocol_rejects_no_id():
    """An object lacking .id fails the CurrentUser structural check."""
    impostor = SimpleNamespace(email="no-id@example.com")
    assert not isinstance(impostor, CurrentUser)
# ---------------------------------------------------------------------------
# get_effective_user_id / DEFAULT_USER_ID tests
# ---------------------------------------------------------------------------
def test_default_user_id_is_default():
    """The fallback user id constant is the literal string 'default'."""
    assert DEFAULT_USER_ID == "default"
@pytest.mark.no_auto_user
def test_effective_user_id_returns_default_when_no_user():
    """With no user in context, the effective id falls back to DEFAULT_USER_ID."""
    assert get_effective_user_id() == "default"
@pytest.mark.no_auto_user
def test_effective_user_id_returns_user_id_when_set():
    """A populated contextvar makes the effective id equal to the user's id."""
    token = set_current_user(SimpleNamespace(id="u-abc-123"))
    try:
        assert get_effective_user_id() == "u-abc-123"
    finally:
        reset_current_user(token)
@pytest.mark.no_auto_user
def test_effective_user_id_coerces_to_str():
    """A non-str id (e.g. a UUID object) is stringified by get_effective_user_id."""
    import uuid

    raw_id = uuid.uuid4()
    token = set_current_user(SimpleNamespace(id=raw_id))
    try:
        assert get_effective_user_id() == str(raw_id)
    finally:
        reset_current_user(token)

View File

@ -1,12 +1,9 @@
"""Helpers for router-level tests that need a stubbed auth context.
"""Helpers for router-level tests that need an authenticated request.
The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
ahead of every router, plus ``@require_permission(owner_check=True)``
decorators that read ``request.state.auth`` and call
``thread_store.check_access``. Router-level unit tests construct
**bare** FastAPI apps that include only one router they have neither
the auth middleware nor a real thread_store, so the decorators raise
401 (TestClient path) or ValueError (direct-call path).
The production gateway stamps ``request.user`` / ``request.auth`` in the
auth middleware, then route decorators read that authenticated context.
Router-level unit tests build very small FastAPI apps that include only
one router, so they need a lightweight stand-in for that middleware.
This module provides two surfaces:
@ -15,10 +12,9 @@ This module provides two surfaces:
request, plus a permissive ``thread_store`` mock on
``app.state``. Use from TestClient-based router tests.
2. :func:`call_unwrapped` invokes the underlying function bypassing
the ``@require_permission`` decorator chain by walking ``__wrapped__``.
Use from direct-call tests that previously imported the route
function and called it positionally.
2. :func:`call_unwrapped` invokes the underlying function by walking
``__wrapped__``. Use from direct-call tests that want to bypass the
route decorators entirely.
Both helpers are deliberately permissive: they never deny a request.
Tests that want to verify the *auth boundary itself* (e.g.
@ -37,8 +33,8 @@ from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from app.gateway.auth.models import User
from app.gateway.authz import AuthContext, Permissions
from app.plugins.auth.domain.models import User
from app.plugins.auth.authorization import AuthContext, Permissions
# Default permission set granted to the stub user. Mirrors `_ALL_PERMISSIONS`
# in authz.py — kept inline so the tests don't import a private symbol.
@ -63,12 +59,7 @@ def _make_stub_user() -> User:
class _StubAuthMiddleware(BaseHTTPMiddleware):
"""Stamp a fake user / AuthContext onto every request.
Mirrors what production ``AuthMiddleware`` does after the JWT decode
+ DB lookup short-circuit, so ``@require_permission`` finds an
authenticated context and skips its own re-authentication path.
"""
"""Stamp a fake user / AuthContext onto every request."""
def __init__(self, app: ASGIApp, user_factory: Callable[[], User]) -> None:
super().__init__(app)
@ -76,8 +67,11 @@ class _StubAuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable) -> Response:
user = self._user_factory()
auth_context = AuthContext(user=user, permissions=list(_STUB_PERMISSIONS))
request.scope["user"] = user
request.scope["auth"] = auth_context
request.state.user = user
request.state.auth = AuthContext(user=user, permissions=list(_STUB_PERMISSIONS))
request.state.auth = auth_context
return await call_next(request)
@ -93,9 +87,8 @@ def make_authed_test_app(
populated :class:`User`. Useful for cross-user isolation tests
that need a stable id across requests.
owner_check_passes: When True (default), ``thread_store.check_access``
returns True for every call so ``@require_permission(owner_check=True)``
never blocks the route under test. Pass False to verify that
permission failures surface correctly.
returns True for every call so owner-gated routes do not block
the handler under test. Pass False to verify denial paths.
Returns:
A ``FastAPI`` app with the stub middleware installed and
@ -121,12 +114,9 @@ def call_unwrapped(decorated: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.
"""Invoke the underlying function of a ``@require_permission``-decorated route.
``functools.wraps`` sets ``__wrapped__`` on each layer; we walk all
the way down to the original handler, bypassing every authz +
require_auth wrapper. Use from tests that need to call route
functions directly (without TestClient) and don't want to construct
a fake ``Request`` just to satisfy the decorator. The ``ParamSpec``
propagates the wrapped route's signature so call sites still get
parameter checking despite the unwrapping.
the way down to the original handler. Use from tests that call route
functions directly and do not want to build a full request/middleware
stack.
"""
fn: Callable = decorated
while hasattr(fn, "__wrapped__"):

View File

@ -1,22 +1,22 @@
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""
"""Tests for authentication module: JWT, password hashing, and auth context behavior."""
from datetime import timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
from app.gateway.auth.models import User
from app.gateway.authz import (
from app.plugins.auth.authorization import (
AuthContext,
Permissions,
get_auth_context,
require_auth,
require_permission,
)
from app.plugins.auth.authorization.hooks import build_authz_hooks
from app.plugins.auth.domain import create_access_token, decode_token, hash_password, verify_password
from app.plugins.auth.domain.models import User
from store.persistence import MappedBase
# ── Password Hashing ────────────────────────────────────────────────────────
@ -67,7 +67,7 @@ def test_create_and_decode_token():
def test_decode_token_expired():
"""Expired token returns TokenError.EXPIRED."""
from app.gateway.auth.errors import TokenError
from app.plugins.auth.domain.errors import TokenError
user_id = str(uuid4())
# Create token that expires immediately
@ -78,7 +78,7 @@ def test_decode_token_expired():
def test_decode_token_invalid():
"""Invalid token returns TokenError."""
from app.gateway.auth.errors import TokenError
from app.plugins.auth.domain.errors import TokenError
assert isinstance(decode_token("not.a.valid.token"), TokenError)
assert isinstance(decode_token(""), TokenError)
@ -101,6 +101,8 @@ def test_auth_context_unauthenticated():
"""AuthContext with no user."""
ctx = AuthContext(user=None, permissions=[])
assert ctx.is_authenticated is False
assert ctx.principal_id is None
assert ctx.capabilities == ()
assert ctx.has_permission("threads", "read") is False
@ -109,6 +111,8 @@ def test_auth_context_authenticated_no_perms():
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
ctx = AuthContext(user=user, permissions=[])
assert ctx.is_authenticated is True
assert ctx.principal_id == str(user.id)
assert ctx.capabilities == ()
assert ctx.has_permission("threads", "read") is False
@ -117,6 +121,7 @@ def test_auth_context_has_permission():
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE]
ctx = AuthContext(user=user, permissions=perms)
assert ctx.capabilities == tuple(perms)
assert ctx.has_permission("threads", "read") is True
assert ctx.has_permission("threads", "write") is True
assert ctx.has_permission("threads", "delete") is False
@ -162,79 +167,12 @@ def test_get_auth_context_set():
assert get_auth_context(mock_request) == ctx
# ── require_auth decorator ──────────────────────────────────────────────────
def test_register_app_sets_default_authz_hooks():
from app.gateway.registrar import register_app
app = register_app()
def test_require_auth_sets_auth_context():
"""require_auth sets auth context on request from cookie."""
from fastapi import Request
app = FastAPI()
@app.get("/test")
@require_auth
async def endpoint(request: Request):
ctx = get_auth_context(request)
return {"authenticated": ctx.is_authenticated}
with TestClient(app) as client:
# No cookie → anonymous
response = client.get("/test")
assert response.status_code == 200
assert response.json()["authenticated"] is False
def test_require_auth_requires_request_param():
"""require_auth raises ValueError if request parameter is missing."""
import asyncio
@require_auth
async def bad_endpoint(): # Missing `request` parameter
pass
with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"):
asyncio.run(bad_endpoint())
# ── require_permission decorator ─────────────────────────────────────────────
def test_require_permission_requires_auth():
"""require_permission raises 401 when not authenticated."""
from fastapi import Request
app = FastAPI()
@app.get("/test")
@require_permission("threads", "read")
async def endpoint(request: Request):
return {"ok": True}
with TestClient(app) as client:
response = client.get("/test")
assert response.status_code == 401
assert "Authentication required" in response.json()["detail"]
def test_require_permission_denies_wrong_permission():
"""User without required permission gets 403."""
from fastapi import Request
app = FastAPI()
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
@app.get("/test")
@require_permission("threads", "delete")
async def endpoint(request: Request):
return {"ok": True}
mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
with TestClient(app) as client:
response = client.get("/test")
assert response.status_code == 403
assert "Permission denied" in response.json()["detail"]
assert app.state.authz_hooks == build_authz_hooks()
# ── Weak JWT secret warning ──────────────────────────────────────────────────
@ -271,45 +209,55 @@ def test_sqlite_round_trip_new_fields():
import asyncio
import tempfile
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from app.plugins.auth.storage import DbUserRepository, UserCreate
async def _run() -> None:
from deerflow.persistence.engine import (
close_engine,
get_session_factory,
init_engine,
)
with tempfile.TemporaryDirectory() as tmpdir:
url = f"sqlite+aiosqlite:///{tmpdir}/scratch.db"
await init_engine("sqlite", url=url, sqlite_dir=tmpdir)
engine = create_async_engine(f"sqlite+aiosqlite:///{tmpdir}/scratch.db", future=True)
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
try:
repo = SQLiteUserRepository(get_session_factory())
user = User(
email="setup@test.com",
password_hash="fakehash",
system_role="admin",
needs_setup=True,
token_version=3,
)
created = await repo.create_user(user)
async with session_factory() as session:
repo = DbUserRepository(session)
created = await repo.create_user(
UserCreate(
email="setup@test.com",
password_hash="fakehash",
system_role="admin",
needs_setup=True,
token_version=3,
)
)
await session.commit()
assert created.needs_setup is True
assert created.token_version == 3
fetched = await repo.get_user_by_email("setup@test.com")
async with session_factory() as session:
repo = DbUserRepository(session)
fetched = await repo.get_user_by_email("setup@test.com")
assert fetched is not None
assert fetched.needs_setup is True
assert fetched.token_version == 3
fetched.needs_setup = False
fetched.token_version = 4
await repo.update_user(fetched)
refetched = await repo.get_user_by_id(str(fetched.id))
updated = fetched.model_copy(update={"needs_setup": False, "token_version": 4})
async with session_factory() as session:
repo = DbUserRepository(session)
await repo.update_user(updated)
await session.commit()
async with session_factory() as session:
repo = DbUserRepository(session)
refetched = await repo.get_user_by_id(fetched.id)
assert refetched is not None
assert refetched.needs_setup is False
assert refetched.token_version == 4
finally:
await close_engine()
await engine.dispose()
asyncio.run(_run())
@ -320,49 +268,54 @@ def test_update_user_raises_when_row_concurrently_deleted(tmp_path):
Earlier the SQLite repo returned the input unchanged when the row was
missing, making a phantom success path that admin password reset
callers (`reset_admin`, `_ensure_admin_user`) would happily log as
'password reset'. The new contract: raise ``UserNotFoundError`` so
'password reset'. The new contract: raise ``LookupError`` so
a vanished row never looks like a successful update.
"""
import asyncio
import tempfile
from app.gateway.auth.repositories.base import UserNotFoundError
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from app.plugins.auth.storage import DbUserRepository, UserCreate
async def _run() -> None:
from deerflow.persistence.engine import (
close_engine,
get_session_factory,
init_engine,
)
from deerflow.persistence.user.model import UserRow
from app.plugins.auth.storage.models import User as UserModel
with tempfile.TemporaryDirectory() as d:
url = f"sqlite+aiosqlite:///{d}/scratch.db"
await init_engine("sqlite", url=url, sqlite_dir=d)
engine = create_async_engine(f"sqlite+aiosqlite:///{d}/scratch.db", future=True)
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
sf = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
try:
sf = get_session_factory()
repo = SQLiteUserRepository(sf)
user = User(
email="ghost@test.com",
password_hash="fakehash",
system_role="user",
)
created = await repo.create_user(user)
async with sf() as session:
repo = DbUserRepository(session)
created = await repo.create_user(
UserCreate(
email="ghost@test.com",
password_hash="fakehash",
system_role="user",
)
)
await session.commit()
# Simulate "row vanished underneath us" by deleting the row
# via the raw ORM session, then attempt to update.
async with sf() as session:
row = await session.get(UserRow, str(created.id))
row = await session.get(UserModel, created.id)
assert row is not None
await session.delete(row)
await session.commit()
created.needs_setup = True
with pytest.raises(UserNotFoundError):
await repo.update_user(created)
updated = created.model_copy(update={"needs_setup": True})
async with sf() as session:
repo = DbUserRepository(session)
with pytest.raises(LookupError):
await repo.update_user(updated)
finally:
await close_engine()
await engine.dispose()
asyncio.run(_run())
@ -374,7 +327,7 @@ def test_jwt_encodes_ver():
"""JWT payload includes ver field."""
import os
from app.gateway.auth.errors import TokenError
from app.plugins.auth.domain.errors import TokenError
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
token = create_access_token(str(uuid4()), token_version=3)
@ -387,7 +340,7 @@ def test_jwt_default_ver_zero():
"""JWT ver defaults to 0."""
import os
from app.gateway.auth.errors import TokenError
from app.plugins.auth.domain.errors import TokenError
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
token = create_access_token(str(uuid4()))
@ -398,30 +351,34 @@ def test_jwt_default_ver_zero():
def test_token_version_mismatch_rejects():
"""Token with stale ver is rejected by get_current_user_from_request."""
import asyncio
import os
from types import SimpleNamespace
from app.plugins.auth.security.dependencies import get_current_user_from_request
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
user_id = str(uuid4())
token = create_access_token(user_id, token_version=0)
request = SimpleNamespace(
cookies={"access_token": token},
state=SimpleNamespace(
_auth_session=MagicMock(),
),
)
stale_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
request.state._auth_session.__aenter__ = AsyncMock(return_value=request.state._auth_session)
request.state._auth_session.__aexit__ = AsyncMock(return_value=None)
mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
mock_request = MagicMock()
mock_request.cookies = {"access_token": token}
with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
mock_provider = MagicMock()
mock_provider.get_user = AsyncMock(return_value=mock_user)
mock_provider_fn.return_value = mock_provider
from app.gateway.deps import get_current_user_from_request
with patch(
"app.plugins.auth.security.dependencies.DbUserRepository.get_user_by_id",
new=AsyncMock(return_value=stale_user),
):
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
assert "revoked" in str(exc_info.value.detail).lower()
import asyncio
asyncio.run(get_current_user_from_request(request))
assert exc_info.value.status_code == 401
assert "revoked" in str(exc_info.value.detail).lower()
# ── change-password extension ──────────────────────────────────────────────
@ -429,7 +386,7 @@ def test_token_version_mismatch_rejects():
def test_change_password_request_accepts_new_email():
"""ChangePasswordRequest model accepts optional new_email."""
from app.gateway.routers.auth import ChangePasswordRequest
from app.plugins.auth.api.schemas import ChangePasswordRequest
req = ChangePasswordRequest(
current_password="old",
@ -441,7 +398,7 @@ def test_change_password_request_accepts_new_email():
def test_change_password_request_new_email_optional():
"""ChangePasswordRequest model works without new_email."""
from app.gateway.routers.auth import ChangePasswordRequest
from app.plugins.auth.api.schemas import ChangePasswordRequest
req = ChangePasswordRequest(current_password="old", new_password="newpassword")
assert req.new_email is None
@ -449,7 +406,7 @@ def test_change_password_request_new_email_optional():
def test_login_response_includes_needs_setup():
"""LoginResponse includes needs_setup field."""
from app.gateway.routers.auth import LoginResponse
from app.plugins.auth.api.schemas import LoginResponse
resp = LoginResponse(expires_in=3600, needs_setup=True)
assert resp.needs_setup is True
@ -462,7 +419,7 @@ def test_login_response_includes_needs_setup():
def test_rate_limiter_allows_under_limit():
"""Requests under the limit are allowed."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts
from app.plugins.auth.api.schemas import _check_rate_limit, _login_attempts
_login_attempts.clear()
_check_rate_limit("192.168.1.1") # Should not raise
@ -470,7 +427,7 @@ def test_rate_limiter_allows_under_limit():
def test_rate_limiter_blocks_after_max_failures():
"""IP is blocked after 5 consecutive failures."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure
from app.plugins.auth.api.schemas import _check_rate_limit, _login_attempts, _record_login_failure
_login_attempts.clear()
ip = "10.0.0.1"
@ -483,7 +440,7 @@ def test_rate_limiter_blocks_after_max_failures():
def test_rate_limiter_resets_on_success():
"""Successful login clears the failure counter."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
from app.plugins.auth.api.schemas import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
_login_attempts.clear()
ip = "10.0.0.2"
@ -499,7 +456,7 @@ def test_rate_limiter_resets_on_success():
def test_get_client_ip_direct_connection_no_proxy(monkeypatch):
"""Direct mode (no AUTH_TRUSTED_PROXIES): use TCP peer regardless of X-Real-IP."""
monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
from app.gateway.routers.auth import _get_client_ip
from app.plugins.auth.api.schemas import _get_client_ip
req = MagicMock()
req.client.host = "203.0.113.42"
@ -514,7 +471,7 @@ def test_get_client_ip_x_real_ip_ignored_when_no_trusted_proxy(monkeypatch):
request to dodge per-IP rate limits in dev / direct mode.
"""
monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
from app.gateway.routers.auth import _get_client_ip
from app.plugins.auth.api.schemas import _get_client_ip
req = MagicMock()
req.client.host = "127.0.0.1"
@ -525,7 +482,7 @@ def test_get_client_ip_x_real_ip_ignored_when_no_trusted_proxy(monkeypatch):
def test_get_client_ip_x_real_ip_honored_from_trusted_proxy(monkeypatch):
"""X-Real-IP is honored when the TCP peer matches AUTH_TRUSTED_PROXIES."""
monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
from app.gateway.routers.auth import _get_client_ip
from app.plugins.auth.api.schemas import _get_client_ip
req = MagicMock()
req.client.host = "10.5.6.7" # in trusted CIDR
@ -536,7 +493,7 @@ def test_get_client_ip_x_real_ip_honored_from_trusted_proxy(monkeypatch):
def test_get_client_ip_x_real_ip_rejected_from_untrusted_peer(monkeypatch):
"""X-Real-IP is rejected when the TCP peer is NOT in the trusted list."""
monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
from app.gateway.routers.auth import _get_client_ip
from app.plugins.auth.api.schemas import _get_client_ip
req = MagicMock()
req.client.host = "8.8.8.8" # NOT in trusted CIDR
@ -547,7 +504,7 @@ def test_get_client_ip_x_real_ip_rejected_from_untrusted_peer(monkeypatch):
def test_get_client_ip_xff_never_honored(monkeypatch):
"""X-Forwarded-For is never used; only X-Real-IP from a trusted peer."""
monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
from app.gateway.routers.auth import _get_client_ip
from app.plugins.auth.api.schemas import _get_client_ip
req = MagicMock()
req.client.host = "10.0.0.1"
@ -558,7 +515,7 @@ def test_get_client_ip_xff_never_honored(monkeypatch):
def test_get_client_ip_invalid_trusted_proxy_entry_skipped(monkeypatch, caplog):
"""Garbage entries in AUTH_TRUSTED_PROXIES are warned and skipped."""
monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "not-an-ip,10.0.0.0/8")
from app.gateway.routers.auth import _get_client_ip
from app.plugins.auth.api.schemas import _get_client_ip
req = MagicMock()
req.client.host = "10.5.6.7"
@ -569,7 +526,7 @@ def test_get_client_ip_invalid_trusted_proxy_entry_skipped(monkeypatch, caplog):
def test_get_client_ip_no_client_returns_unknown(monkeypatch):
"""No request.client → 'unknown' marker (no crash)."""
monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
from app.gateway.routers.auth import _get_client_ip
from app.plugins.auth.api.schemas import _get_client_ip
req = MagicMock()
req.client = None
@ -584,7 +541,7 @@ def test_register_rejects_literal_password():
"""Pydantic validator rejects 'password' as a registration password."""
from pydantic import ValidationError
from app.gateway.routers.auth import RegisterRequest
from app.plugins.auth.api.schemas import RegisterRequest
with pytest.raises(ValidationError) as exc:
RegisterRequest(email="x@example.com", password="password")
@ -595,7 +552,7 @@ def test_register_rejects_common_password_case_insensitive():
"""Case variants of common passwords are also rejected."""
from pydantic import ValidationError
from app.gateway.routers.auth import RegisterRequest
from app.plugins.auth.api.schemas import RegisterRequest
for variant in ["PASSWORD", "Password1", "qwerty123", "letmein1"]:
with pytest.raises(ValidationError):
@ -604,7 +561,7 @@ def test_register_rejects_common_password_case_insensitive():
def test_register_accepts_strong_password():
"""A non-blocklisted password of length >=8 is accepted."""
from app.gateway.routers.auth import RegisterRequest
from app.plugins.auth.api.schemas import RegisterRequest
req = RegisterRequest(email="x@example.com", password="Tr0ub4dor&3-Horse")
assert req.password == "Tr0ub4dor&3-Horse"
@ -614,7 +571,7 @@ def test_change_password_rejects_common_password():
"""The same blocklist applies to change-password."""
from pydantic import ValidationError
from app.gateway.routers.auth import ChangePasswordRequest
from app.plugins.auth.api.schemas import ChangePasswordRequest
with pytest.raises(ValidationError):
ChangePasswordRequest(current_password="anything", new_password="iloveyou")
@ -624,7 +581,7 @@ def test_password_blocklist_keeps_short_passwords_for_length_check():
"""Short passwords still fail the min_length check (not the blocklist)."""
from pydantic import ValidationError
from app.gateway.routers.auth import RegisterRequest
from app.plugins.auth.api.schemas import RegisterRequest
with pytest.raises(ValidationError) as exc:
RegisterRequest(email="x@example.com", password="abc")
@ -639,16 +596,17 @@ def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
"""get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset."""
import logging
import app.gateway.auth.config as config_module
import app.plugins.auth.runtime.config_state as config_module
from app.plugins.auth.runtime.config_state import reset_auth_config
config_module._auth_config = None
monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)
with caplog.at_level(logging.WARNING):
reset_auth_config()
config = config_module.get_auth_config()
assert config.jwt_secret # non-empty ephemeral secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
# Cleanup
config_module._auth_config = None
reset_auth_config()

View File

@ -5,7 +5,8 @@ from unittest.mock import patch
import pytest
from app.gateway.auth.config import AuthConfig
from app.plugins.auth.domain.config import AuthConfig
from app.plugins.auth.runtime.config_state import reset_auth_config
def test_auth_config_defaults():
@ -25,30 +26,28 @@ def test_auth_config_token_expiry_range():
def test_auth_config_from_env():
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
with patch.dict(os.environ, env, clear=False):
import app.gateway.auth.config as cfg
import app.plugins.auth.runtime.config_state as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
reset_auth_config()
config = cfg.get_auth_config()
assert config.jwt_secret == "test-jwt-secret-from-env"
finally:
cfg._auth_config = old
reset_auth_config()
def test_auth_config_missing_secret_generates_ephemeral(caplog):
import logging
import app.gateway.auth.config as cfg
import app.plugins.auth.runtime.config_state as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with caplog.at_level(logging.WARNING):
reset_auth_config()
config = cfg.get_auth_config()
assert config.jwt_secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
finally:
cfg._auth_config = old
reset_auth_config()

View File

@ -0,0 +1,157 @@
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
import pytest
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.plugins.auth.domain.config import AuthConfig
from app.plugins.auth.security.dependencies import (
get_current_user_from_request,
get_current_user_id,
get_optional_user_from_request,
)
from app.plugins.auth.domain.jwt import create_access_token
from app.plugins.auth.runtime.config_state import set_auth_config
from app.plugins.auth.storage import DbUserRepository, UserCreate
from store.persistence import MappedBase
from app.plugins.auth.storage.models import User as UserModel # noqa: F401
_TEST_SECRET = "test-secret-auth-dependencies-min-32"
@pytest.fixture(autouse=True)
def _setup_auth_config():
    """Pin a deterministic AuthConfig for every test in this module.

    The known test secret is re-applied during teardown so a test that
    mutates the process-wide auth config cannot leak into its neighbours.
    """
    set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
    try:
        yield
    finally:
        set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
async def _make_request(tmp_path, *, cookie: str | None = None, users: list[UserCreate] | None = None):
    """Build a fake request object backed by a fresh per-test SQLite database.

    Optionally seeds the user table and attaches an ``access_token`` cookie.
    Returns a ``(request, session, engine)`` triple; the caller is responsible
    for closing the session and disposing the engine.
    """
    db_engine = create_async_engine(
        f"sqlite+aiosqlite:///{tmp_path / 'auth-deps.db'}",
        future=True,
    )
    async with db_engine.begin() as connection:
        await connection.run_sync(MappedBase.metadata.create_all)

    make_session = async_sessionmaker(
        bind=db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    db_session = make_session()

    if users:
        repository = DbUserRepository(db_session)
        for account in users:
            await repository.create_user(account)
        await db_session.commit()

    cookie_jar = {} if cookie is None else {"access_token": cookie}
    fake_request = SimpleNamespace(
        cookies=cookie_jar,
        state=SimpleNamespace(_auth_session=db_session),
    )
    return fake_request, db_session, db_engine
class TestAuthDependencies:
    """Unit tests for the cookie-JWT auth dependencies against a real SQLite DB.

    Each test builds an isolated database via ``_make_request`` and tears down
    the session/engine in ``finally`` before asserting on the outcome.
    """

    @pytest.mark.anyio
    async def test_no_cookie_returns_401(self, tmp_path):
        """A request without an access_token cookie → 401 not_authenticated."""
        request, session, engine = await _make_request(tmp_path)
        try:
            with pytest.raises(HTTPException) as exc_info:
                await get_current_user_from_request(request)
        finally:
            await session.close()
            await engine.dispose()
        assert exc_info.value.status_code == 401
        assert exc_info.value.detail["code"] == "not_authenticated"

    @pytest.mark.anyio
    async def test_invalid_token_returns_401(self, tmp_path):
        """A cookie that is not a decodable JWT → 401 token_invalid."""
        request, session, engine = await _make_request(tmp_path, cookie="garbage")
        try:
            with pytest.raises(HTTPException) as exc_info:
                await get_current_user_from_request(request)
        finally:
            await session.close()
            await engine.dispose()
        assert exc_info.value.status_code == 401
        assert exc_info.value.detail["code"] == "token_invalid"

    @pytest.mark.anyio
    async def test_missing_user_returns_401(self, tmp_path):
        """A valid token whose subject has no user row → 401 user_not_found."""
        token = create_access_token("missing-user", token_version=0)
        request, session, engine = await _make_request(tmp_path, cookie=token)
        try:
            with pytest.raises(HTTPException) as exc_info:
                await get_current_user_from_request(request)
        finally:
            await session.close()
            await engine.dispose()
        assert exc_info.value.status_code == 401
        assert exc_info.value.detail["code"] == "user_not_found"

    @pytest.mark.anyio
    async def test_token_version_mismatch_returns_401(self, tmp_path):
        """A token minted at version 0 against a user at version 2 → token_invalid."""
        token = create_access_token("user-1", token_version=0)
        request, session, engine = await _make_request(
            tmp_path,
            cookie=token,
            users=[
                UserCreate(
                    id="user-1",
                    email="user1@example.com",
                    # DB row is ahead of the token; dependency must reject it.
                    token_version=2,
                )
            ],
        )
        try:
            with pytest.raises(HTTPException) as exc_info:
                await get_current_user_from_request(request)
        finally:
            await session.close()
            await engine.dispose()
        assert exc_info.value.status_code == 401
        assert exc_info.value.detail["code"] == "token_invalid"

    @pytest.mark.anyio
    async def test_valid_token_returns_user(self, tmp_path):
        """Matching token/user versions resolve both the user object and the id."""
        token = create_access_token("user-2", token_version=3)
        request, session, engine = await _make_request(
            tmp_path,
            cookie=token,
            users=[
                UserCreate(
                    id="user-2",
                    email="user2@example.com",
                    token_version=3,
                )
            ],
        )
        try:
            user = await get_current_user_from_request(request)
            user_id = await get_current_user_id(request)
        finally:
            await session.close()
            await engine.dispose()
        assert user.id == "user-2"
        assert user.email == "user2@example.com"
        assert user_id == "user-2"

    @pytest.mark.anyio
    async def test_optional_user_returns_none_on_failure(self, tmp_path):
        """The optional variant swallows auth failures and yields None."""
        request, session, engine = await _make_request(tmp_path, cookie="bad-token")
        try:
            user = await get_optional_user_from_request(request)
        finally:
            await session.close()
            await engine.dispose()
        assert user is None

View File

@ -4,9 +4,10 @@ from datetime import UTC, datetime, timedelta
import jwt as pyjwt
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import create_access_token, decode_token
from app.plugins.auth.domain.config import AuthConfig
from app.plugins.auth.domain.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.plugins.auth.domain.jwt import create_access_token, decode_token
from app.plugins.auth.runtime.config_state import set_auth_config
def test_auth_error_code_values():
@ -56,7 +57,7 @@ def test_decode_token_returns_token_error_on_expired():
def test_decode_token_returns_token_error_on_bad_signature():
_setup_config()
payload = {"sub": "user-1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
token = pyjwt.encode(payload, "wrong-secret-key-for-tests-minimum-32", algorithm="HS256")
result = decode_token(token)
assert result == TokenError.INVALID_SIGNATURE

View File

@ -1,9 +1,11 @@
"""Tests for the global AuthMiddleware (fail-closed safety net)."""
from types import SimpleNamespace
import pytest
from starlette.testclient import TestClient
from app.gateway.auth_middleware import AuthMiddleware, _is_public
from app.plugins.auth.security.middleware import AuthMiddleware, _is_public
# ── _is_public unit tests ─────────────────────────────────────────────────
@ -165,7 +167,8 @@ def test_protected_path_with_junk_cookie_rejected(client):
"""Junk cookie → 401. Middleware strictly validates the JWT now
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
tokens through to the route handler."""
res = client.get("/api/models", cookies={"access_token": "some-token"})
client.cookies.set("access_token", "some-token")
res = client.get("/api/models")
assert res.status_code == 401
@ -220,3 +223,44 @@ def test_unknown_endpoint_with_junk_cookie_rejected(client):
client.cookies.set("access_token", "tok")
res = client.get("/api/future-endpoint")
assert res.status_code == 401
def test_middleware_populates_request_user_and_auth(monkeypatch):
    """AuthMiddleware exposes the resolved user on request.user / state / auth."""
    from fastapi import FastAPI, Request

    from app.plugins.auth.security import middleware as middleware_module

    stub_user = SimpleNamespace(id="user-123", email="test@example.com")

    async def _fake_get_current_user_from_request(request):
        # Bypass real JWT/DB resolution; only the middleware wiring is under test.
        return stub_user

    monkeypatch.setattr(
        middleware_module,
        "get_current_user_from_request",
        _fake_get_current_user_from_request,
    )

    app = FastAPI()
    app.add_middleware(AuthMiddleware)

    @app.get("/api/models")
    async def models_get(request: Request):
        return {
            "request_user_id": request.user.id,
            "state_user_id": request.state.user.id,
            "request_auth_user_id": request.auth.user.id,
            "state_auth_user_id": request.state.auth.user.id,
        }

    http_client = TestClient(app)
    http_client.cookies.set("access_token", "valid")
    response = http_client.get("/api/models")

    assert response.status_code == 200
    expected = {
        "request_user_id": "user-123",
        "state_user_id": "user-123",
        "request_auth_user_id": "user-123",
        "state_auth_user_id": "user-123",
    }
    assert response.json() == expected

View File

@ -0,0 +1,97 @@
from __future__ import annotations
from types import SimpleNamespace
from uuid import uuid4
import pytest
from starlette.requests import Request
from unittest.mock import AsyncMock
from app.plugins.auth.authorization import AuthContext, Permissions
from app.plugins.auth.authorization.policies import require_thread_owner
from app.plugins.auth.domain.models import User
def _make_auth_context() -> AuthContext:
    """Return an AuthContext for a fresh user holding thread/run read permissions."""
    account = User(id=uuid4(), email="user@example.com", password_hash="hash")
    granted = [Permissions.THREADS_READ, Permissions.RUNS_READ]
    return AuthContext(user=account, permissions=granted)
def _make_request(*, thread_repo, run_repo=None, checkpointer=None) -> Request:
    """Build a minimal Starlette request whose ``app.state`` carries the given stores.

    The scope targets ``GET /api/threads/thread-1/runs`` with ``thread_id``
    already resolved as a path parameter.
    """
    fake_state = SimpleNamespace(
        thread_meta_repo=thread_repo,
        run_store=run_repo,
        checkpointer=checkpointer,
    )
    fake_app = SimpleNamespace(state=fake_state)
    scope = {
        "type": "http",
        "method": "GET",
        "path": "/api/threads/thread-1/runs",
        "headers": [],
        "app": fake_app,
        "route": SimpleNamespace(path="/api/threads/{thread_id}/runs"),
        "path_params": {"thread_id": "thread-1"},
    }
    return Request(scope)
@pytest.mark.anyio
async def test_require_thread_owner_uses_thread_row_user_id() -> None:
    """Ownership is taken from the thread row's user_id column, not its metadata."""
    auth = _make_auth_context()
    thread_repo = SimpleNamespace(
        get_thread_meta=AsyncMock(
            return_value=SimpleNamespace(
                user_id=str(auth.user.id),
                # Conflicting metadata must be ignored in favour of the column.
                metadata={"user_id": "someone-else"},
            )
        )
    )
    request = _make_request(thread_repo=thread_repo)
    # Must not raise: the column owner matches the authenticated user.
    await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True)
@pytest.mark.anyio
async def test_require_thread_owner_falls_back_to_user_owned_runs() -> None:
    """With no thread row, ownership falls back to runs owned by the user."""
    auth = _make_auth_context()
    thread_repo = SimpleNamespace(get_thread_meta=AsyncMock(return_value=None))
    run_repo = SimpleNamespace(
        list_by_thread=AsyncMock(return_value=[{"run_id": "run-1", "thread_id": "thread-1"}])
    )
    request = _make_request(thread_repo=thread_repo, run_repo=run_repo)
    await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True)
    # The run store must be queried scoped to the authenticated user's id.
    run_repo.list_by_thread.assert_awaited_once_with("thread-1", limit=1, user_id=str(auth.user.id))
@pytest.mark.anyio
async def test_require_thread_owner_falls_back_to_checkpoint_threads() -> None:
    """With no thread row and no runs, existence is probed via the checkpointer."""
    auth = _make_auth_context()
    thread_repo = SimpleNamespace(get_thread_meta=AsyncMock(return_value=None))
    run_repo = SimpleNamespace(list_by_thread=AsyncMock(return_value=[]))
    # Any non-None checkpoint tuple counts as an existing thread.
    checkpointer = SimpleNamespace(aget_tuple=AsyncMock(return_value=object()))
    request = _make_request(thread_repo=thread_repo, run_repo=run_repo, checkpointer=checkpointer)
    await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True)
    checkpointer.aget_tuple.assert_awaited_once_with(
        {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    )
@pytest.mark.anyio
async def test_require_thread_owner_denies_missing_thread() -> None:
    """With no thread row, no runs, and no checkpoint, the policy raises 404.

    Catches the concrete ``HTTPException`` (Starlette base, which FastAPI's
    subclasses) instead of bare ``Exception`` so that an unrelated failure
    such as an ``AttributeError`` inside the policy cannot pass the test.
    """
    # Local import: the module's import block does not bring HTTPException in.
    from starlette.exceptions import HTTPException

    auth = _make_auth_context()
    thread_repo = SimpleNamespace(get_thread_meta=AsyncMock(return_value=None))
    run_repo = SimpleNamespace(list_by_thread=AsyncMock(return_value=[]))
    checkpointer = SimpleNamespace(aget_tuple=AsyncMock(return_value=None))
    request = _make_request(thread_repo=thread_repo, run_repo=run_repo, checkpointer=checkpointer)
    with pytest.raises(HTTPException) as exc_info:
        await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True)
    # Direct attribute access: a typed HTTPException always carries these.
    assert exc_info.value.status_code == 404
    assert exc_info.value.detail == "Thread thread-1 not found"

View File

@ -0,0 +1,146 @@
from __future__ import annotations
from types import SimpleNamespace
from uuid import uuid4
import pytest
from fastapi import APIRouter, FastAPI
from starlette.requests import Request
from app.plugins.auth.authorization import AuthContext
from app.plugins.auth.domain.models import User
from app.plugins.auth.injection import load_route_policy_registry, validate_route_policy_registry
from app.plugins.auth.injection.registry_loader import RoutePolicyRegistry, RoutePolicySpec
from app.plugins.auth.injection.route_guard import enforce_route_policy
from app.plugins.auth.injection.route_injector import install_route_guards
def test_load_route_policy_registry_flattens_yaml_sections() -> None:
    """The YAML policy sections flatten into method+path addressable specs."""
    registry = load_route_policy_registry()

    # Public auth endpoint carries an explicit public flag.
    public_spec = registry.get("POST", "/api/v1/auth/login/local")
    assert public_spec is not None
    assert public_spec.public is True

    # The run-stream route requires runs:read plus run ownership.
    run_stream_spec = registry.get("GET", "/api/threads/{thread_id}/runs/{run_id}/stream")
    assert run_stream_spec is not None
    assert run_stream_spec.capability == "runs:read"
    assert run_stream_spec.policies == ("owner:run",)

    # GET and POST on the stream path resolve to the same spec entry.
    post_stream_spec = registry.get("POST", "/api/threads/{thread_id}/runs/{run_id}/stream")
    assert post_stream_spec == run_stream_spec
def test_validate_route_policy_registry_rejects_missing_entry() -> None:
    """Validation fails fast when an app route has no registry entry."""
    app = FastAPI()
    router = APIRouter()

    @router.get("/api/needs-policy")
    async def needs_policy() -> dict[str, bool]:
        return {"ok": True}

    app.include_router(router)
    # An intentionally empty registry cannot cover the route above.
    registry = RoutePolicyRegistry([])
    with pytest.raises(RuntimeError, match="Missing route policy entries"):
        validate_route_policy_registry(app, registry)
def test_install_route_guards_appends_route_dependency() -> None:
    """install_route_guards appends exactly one enforce_route_policy dependency."""
    app = FastAPI()
    router = APIRouter()

    @router.get("/api/demo")
    async def demo() -> dict[str, bool]:
        return {"ok": True}

    app.include_router(router)
    route = next(route for route in app.routes if getattr(route, "path", None) == "/api/demo")
    before = len(route.dependencies)
    install_route_guards(app)
    # One new dependency, and it is the guard callable itself (identity check).
    assert len(route.dependencies) == before + 1
    assert route.dependencies[-1].dependency is enforce_route_policy
@pytest.mark.anyio
async def test_enforce_route_policy_denies_missing_capability() -> None:
    """A route requiring threads:delete rejects a user holding only threads:read.

    Catches the concrete ``HTTPException`` (Starlette base, which FastAPI's
    subclasses) instead of bare ``Exception`` so an unrelated error inside the
    guard cannot masquerade as the expected 403.
    """
    # Local import: the module's import block does not bring HTTPException in.
    from starlette.exceptions import HTTPException

    user = User(id=uuid4(), email="user@example.com", password_hash="hash")
    auth = AuthContext(user=user, permissions=["threads:read"])
    registry = RoutePolicyRegistry(
        [
            SimpleNamespace(
                method="GET",
                path="/api/threads/{thread_id}/uploads/list",
                # Requires a capability the user does not have.
                spec=RoutePolicySpec(capability="threads:delete"),
                matches_request=lambda *_args, **_kwargs: True,
            )
        ]
    )
    app = SimpleNamespace(state=SimpleNamespace(auth_route_policy_registry=registry))
    scope = {
        "type": "http",
        "method": "GET",
        "path": "/api/threads/thread-1/uploads/list",
        "headers": [],
        "app": app,
        "route": SimpleNamespace(path="/api/threads/{thread_id}/uploads/list"),
        "path_params": {"thread_id": "thread-1"},
        "auth": auth,
    }
    request = Request(scope)
    request.state.auth = auth
    with pytest.raises(HTTPException) as exc_info:
        await enforce_route_policy(request)
    # Direct attribute access: a typed HTTPException always carries this.
    assert exc_info.value.status_code == 403
@pytest.mark.anyio
async def test_enforce_route_policy_runs_owner_policy(monkeypatch: pytest.MonkeyPatch) -> None:
    """An owner:thread policy routes through require_thread_owner with the right args."""
    user = User(id=uuid4(), email="user@example.com", password_hash="hash")
    auth = AuthContext(user=user, permissions=["threads:read"])
    registry = RoutePolicyRegistry(
        [
            SimpleNamespace(
                method="GET",
                path="/api/threads/{thread_id}/state",
                spec=RoutePolicySpec(capability="threads:read", policies=("owner:thread",)),
                matches_request=lambda *_args, **_kwargs: True,
            )
        ]
    )
    # Captures the arguments the guard forwards to the ownership policy.
    called: dict[str, object] = {}

    async def fake_owner_check(request: Request, auth_context: AuthContext, *, thread_id: str, require_existing: bool) -> None:
        called["request"] = request
        called["auth"] = auth_context
        called["thread_id"] = thread_id
        called["require_existing"] = require_existing

    monkeypatch.setattr("app.plugins.auth.injection.route_guard.require_thread_owner", fake_owner_check)
    app = SimpleNamespace(state=SimpleNamespace(auth_route_policy_registry=registry))
    scope = {
        "type": "http",
        "method": "GET",
        "path": "/api/threads/thread-1/state",
        "headers": [],
        "app": app,
        "route": SimpleNamespace(path="/api/threads/{thread_id}/state"),
        "path_params": {"thread_id": "thread-1"},
        "auth": auth,
    }
    request = Request(scope)
    request.state.auth = auth
    await enforce_route_policy(request)
    # The guard must forward the resolved path param and the same auth context.
    assert called["thread_id"] == "thread-1"
    assert called["auth"] is auth
    assert called["require_existing"] is True

View File

@ -0,0 +1,86 @@
from __future__ import annotations
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.plugins.auth.domain.service import AuthService, AuthServiceError
from app.plugins.auth.storage.models import User as UserModel # noqa: F401
from store.persistence import MappedBase
async def _make_service(tmp_path):
    """Create an AuthService wired to a fresh SQLite database under tmp_path.

    Returns ``(engine, service)``; the caller disposes the engine.
    """
    db_engine = create_async_engine(
        f"sqlite+aiosqlite:///{tmp_path / 'auth-service.db'}",
        future=True,
    )
    async with db_engine.begin() as connection:
        await connection.run_sync(MappedBase.metadata.create_all)
    make_session = async_sessionmaker(
        bind=db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    return db_engine, AuthService(make_session)
class TestAuthService:
    """Behavioural tests for AuthService against a real per-test SQLite DB."""

    @pytest.mark.anyio
    async def test_register_and_login_local(self, tmp_path):
        """A registered user can log back in and resolves to the same id."""
        engine, service = await _make_service(tmp_path)
        try:
            created = await service.register("user@example.com", "Str0ng!Pass99")
            logged_in = await service.login_local("user@example.com", "Str0ng!Pass99")
        finally:
            await engine.dispose()
        assert created.email == "user@example.com"
        # The stored credential is a hash, never the plaintext password.
        assert created.password_hash is not None
        assert logged_in.id == created.id

    @pytest.mark.anyio
    async def test_register_duplicate_email_raises(self, tmp_path):
        """Registering the same email twice raises email_already_exists."""
        engine, service = await _make_service(tmp_path)
        try:
            await service.register("dupe@example.com", "Str0ng!Pass99")
            with pytest.raises(AuthServiceError) as exc_info:
                await service.register("dupe@example.com", "An0ther!Pass99")
        finally:
            await engine.dispose()
        assert exc_info.value.code.value == "email_already_exists"

    @pytest.mark.anyio
    async def test_initialize_admin_only_once(self, tmp_path):
        """Admin bootstrap succeeds once; a second attempt is rejected."""
        engine, service = await _make_service(tmp_path)
        try:
            admin = await service.initialize_admin("admin@example.com", "Str0ng!Pass99")
            with pytest.raises(AuthServiceError) as exc_info:
                await service.initialize_admin("other@example.com", "An0ther!Pass99")
        finally:
            await engine.dispose()
        assert admin.system_role == "admin"
        assert admin.needs_setup is False
        assert exc_info.value.code.value == "system_already_initialized"

    @pytest.mark.anyio
    async def test_change_password_updates_token_version_and_clears_setup(self, tmp_path):
        """Password change bumps token_version, clears needs_setup, applies new email."""
        engine, service = await _make_service(tmp_path)
        try:
            user = await service.register("setup@example.com", "Str0ng!Pass99")
            # Simulate a first-login account still in forced-setup mode.
            user.needs_setup = True
            updated = await service.change_password(
                user,
                current_password="Str0ng!Pass99",
                new_password="N3wer!Pass99",
                new_email="final@example.com",
            )
            relogged = await service.login_local("final@example.com", "N3wer!Pass99")
        finally:
            await engine.dispose()
        assert updated.email == "final@example.com"
        assert updated.needs_setup is False
        # Version bump invalidates any JWT minted before the change.
        assert updated.token_version == 1
        assert relogged.id == updated.id

View File

@ -16,10 +16,11 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import ValidationError
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.csrf_middleware import (
from app.plugins.auth.domain.config import AuthConfig
from app.plugins.auth.domain.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.plugins.auth.domain.jwt import decode_token
from app.plugins.auth.runtime.config_state import set_auth_config
from app.plugins.auth.security.csrf import (
CSRF_COOKIE_NAME,
CSRF_HEADER_NAME,
CSRFMiddleware,
@ -34,28 +35,8 @@ _TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
@pytest.fixture(autouse=True)
def _persistence_engine(tmp_path):
"""Initialise a per-test SQLite engine + reset cached provider singletons.
The auth tests call real HTTP handlers that go through
``SQLiteUserRepository`` ``get_session_factory``. Each test gets
a fresh DB plus a clean ``deps._cached_*`` so the cached provider
does not hold a dangling reference to the previous test's engine.
"""
import asyncio
from app.gateway import deps
from deerflow.persistence.engine import close_engine, init_engine
url = f"sqlite+aiosqlite:///{tmp_path}/auth_types.db"
asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
deps._cached_local_provider = None
deps._cached_repo = None
try:
yield
finally:
deps._cached_local_provider = None
deps._cached_repo = None
asyncio.run(close_engine())
"""Per-test auth config fixture placeholder."""
yield
def _setup_config():
@ -174,7 +155,7 @@ def test_decode_token_invalid_sig_maps_to_token_invalid_code():
import jwt as pyjwt
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
token = pyjwt.encode(payload, "wrong-key-for-tests-minimum-32-chars", algorithm="HS256")
result = decode_token(token)
assert result == TokenError.INVALID_SIGNATURE
@ -197,7 +178,7 @@ def test_decode_token_malformed_maps_to_token_invalid_code():
def test_login_response_model_has_no_access_token():
"""LoginResponse should NOT contain access_token field (RFC-001)."""
from app.gateway.routers.auth import LoginResponse
from app.plugins.auth.api.schemas import LoginResponse
resp = LoginResponse(expires_in=604800)
d = resp.model_dump()
@ -208,7 +189,7 @@ def test_login_response_model_has_no_access_token():
def test_login_response_model_fields():
"""LoginResponse has expires_in and needs_setup."""
from app.gateway.routers.auth import LoginResponse
from app.plugins.auth.api.schemas import LoginResponse
fields = set(LoginResponse.model_fields.keys())
assert fields == {"expires_in", "needs_setup"}
@ -219,7 +200,7 @@ def test_login_response_model_fields():
def test_auth_config_token_expiry_used_in_login_response():
"""LoginResponse.expires_in should come from config.token_expiry_days."""
from app.gateway.routers.auth import LoginResponse
from app.plugins.auth.api.schemas import LoginResponse
expected_seconds = 14 * 24 * 3600
resp = LoginResponse(expires_in=expected_seconds)
@ -231,7 +212,7 @@ def test_auth_config_token_expiry_used_in_login_response():
def test_user_response_system_role_literal():
"""UserResponse.system_role should only accept 'admin' or 'user'."""
from app.gateway.auth.models import UserResponse
from app.plugins.auth.domain.models import UserResponse
# Valid roles
resp = UserResponse(id="1", email="a@b.com", system_role="admin")
@ -243,7 +224,7 @@ def test_user_response_system_role_literal():
def test_user_response_rejects_invalid_role():
"""UserResponse should reject invalid system_role values."""
from app.gateway.auth.models import UserResponse
from app.plugins.auth.domain.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1", email="a@b.com", system_role="superadmin")
@ -263,7 +244,7 @@ def test_get_current_user_no_cookie_returns_not_authenticated():
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
from app.plugins.auth.security.dependencies import get_current_user_from_request
mock_request = type("MockRequest", (), {"cookies": {}})()
with pytest.raises(HTTPException) as exc_info:
@ -279,7 +260,7 @@ def test_get_current_user_expired_token_returns_token_expired():
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
from app.plugins.auth.security.dependencies import get_current_user_from_request
_setup_config()
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
@ -299,11 +280,11 @@ def test_get_current_user_invalid_token_returns_token_invalid():
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
from app.plugins.auth.security.dependencies import get_current_user_from_request
_setup_config()
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
token = pyjwt.encode(payload, "wrong-secret-key-for-tests-minimum-32", algorithm="HS256")
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
with pytest.raises(HTTPException) as exc_info:
@ -319,7 +300,7 @@ def test_get_current_user_malformed_token_returns_token_invalid():
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
from app.plugins.auth.security.dependencies import get_current_user_from_request
_setup_config()
mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})()
@ -380,19 +361,19 @@ def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
"""get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset."""
import logging
import app.gateway.auth.config as cfg
import app.plugins.auth.runtime.config_state as cfg
from app.plugins.auth.runtime.config_state import reset_auth_config
old = cfg._auth_config
cfg._auth_config = None
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with caplog.at_level(logging.WARNING):
reset_auth_config()
config = cfg.get_auth_config()
assert config.jwt_secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
finally:
cfg._auth_config = old
reset_auth_config()
# ── CSRF middleware integration (unhappy paths) ──────────────────────
@ -485,7 +466,7 @@ def test_csrf_middleware_sets_cookie_on_auth_endpoint():
def test_user_response_missing_required_fields():
"""UserResponse with missing fields → ValidationError."""
from app.gateway.auth.models import UserResponse
from app.plugins.auth.domain.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1") # missing email, system_role
@ -496,7 +477,7 @@ def test_user_response_missing_required_fields():
def test_user_response_empty_string_role_rejected():
"""Empty string is not a valid role."""
from app.gateway.auth.models import UserResponse
from app.plugins.auth.domain.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1", email="a@b.com", system_role="")
@ -514,20 +495,15 @@ def _make_auth_app():
return create_app()
def _get_auth_client():
"""Get TestClient for auth API contract tests."""
return TestClient(_make_auth_app())
def test_api_auth_me_no_cookie_returns_structured_401():
"""/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}."""
_setup_config()
client = _get_auth_client()
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "not_authenticated"
assert "message" in body["detail"]
with TestClient(_make_auth_app()) as client:
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "not_authenticated"
assert "message" in body["detail"]
def test_api_auth_me_expired_token_returns_structured_401():
@ -536,75 +512,70 @@ def test_api_auth_me_expired_token_returns_structured_401():
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
client = _get_auth_client()
client.cookies.set("access_token", token)
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "token_expired"
with TestClient(_make_auth_app()) as client:
client.cookies.set("access_token", token)
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "token_expired"
def test_api_auth_me_invalid_sig_returns_structured_401():
"""/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}."""
_setup_config()
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
token = pyjwt.encode(payload, "wrong-key-for-tests-minimum-32-chars", algorithm="HS256")
client = _get_auth_client()
client.cookies.set("access_token", token)
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "token_invalid"
with TestClient(_make_auth_app()) as client:
client.cookies.set("access_token", token)
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "token_invalid"
def test_api_login_bad_credentials_returns_structured_401():
"""Login with wrong password → 401 with {code: 'invalid_credentials'}."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/login/local",
data={"username": "nonexistent@test.com", "password": "wrongpassword"},
)
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "invalid_credentials"
with TestClient(_make_auth_app()) as client:
resp = client.post(
"/api/v1/auth/login/local",
data={"username": "nonexistent@test.com", "password": "wrongpassword"},
)
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "invalid_credentials"
def test_api_login_success_no_token_in_body():
"""Successful login → response body has expires_in but NOT access_token."""
_setup_config()
client = _get_auth_client()
# Register first
client.post(
"/api/v1/auth/register",
json={"email": "contract-test@test.com", "password": "securepassword123"},
)
# Login
resp = client.post(
"/api/v1/auth/login/local",
data={"username": "contract-test@test.com", "password": "securepassword123"},
)
assert resp.status_code == 200
body = resp.json()
assert "expires_in" in body
assert "access_token" not in body
# Token should be in cookie, not body
assert "access_token" in resp.cookies
with TestClient(_make_auth_app()) as client:
client.post(
"/api/v1/auth/register",
json={"email": "contract-test@test.com", "password": "securepassword123"},
)
resp = client.post(
"/api/v1/auth/login/local",
data={"username": "contract-test@test.com", "password": "securepassword123"},
)
assert resp.status_code == 200
body = resp.json()
assert "expires_in" in body
assert "access_token" not in body
assert "access_token" in resp.cookies
def test_api_register_duplicate_returns_structured_400():
"""Register with duplicate email → 400 with {code: 'email_already_exists'}."""
_setup_config()
client = _get_auth_client()
email = "dup-contract-test@test.com"
# First register
client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"})
# Duplicate
resp = client.post("/api/v1/auth/register", json={"email": email, "password": "AnotherStr0ngPwd!"})
assert resp.status_code == 400
body = resp.json()
assert body["detail"]["code"] == "email_already_exists"
with TestClient(_make_auth_app()) as client:
email = "dup-contract-test@test.com"
client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"})
resp = client.post("/api/v1/auth/register", json={"email": email, "password": "AnotherStr0ngPwd!"})
assert resp.status_code == 400
body = resp.json()
assert body["detail"]["code"] == "email_already_exists"
# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────
@ -622,80 +593,80 @@ def _get_set_cookie_headers(resp) -> list[str]:
def test_register_http_cookie_httponly_true_secure_false():
"""HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("http-cookie"), "password": "Tr0ub4dor3a"},
)
assert resp.status_code == 201
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" not in cookie_header.lower().replace("samesite", "")
with TestClient(_make_auth_app()) as client:
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("http-cookie"), "password": "Tr0ub4dor3a"},
)
assert resp.status_code == 201
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" not in cookie_header.lower().replace("samesite", "")
def test_register_https_cookie_httponly_true_secure_true():
"""HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("https-cookie"), "password": "Tr0ub4dor3a"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 201
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" in cookie_header.lower()
assert "max-age" in cookie_header.lower()
with TestClient(_make_auth_app()) as client:
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("https-cookie"), "password": "Tr0ub4dor3a"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 201
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" in cookie_header.lower()
assert "max-age" in cookie_header.lower()
def test_login_https_sets_secure_cookie():
"""HTTPS login → access_token cookie has secure flag."""
_setup_config()
client = _get_auth_client()
email = _unique_email("https-login")
client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"})
resp = client.post(
"/api/v1/auth/login/local",
data={"username": email, "password": "Tr0ub4dor3a"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 200
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" in cookie_header.lower()
with TestClient(_make_auth_app()) as client:
email = _unique_email("https-login")
client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"})
resp = client.post(
"/api/v1/auth/login/local",
data={"username": email, "password": "Tr0ub4dor3a"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 200
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" in cookie_header.lower()
def test_csrf_cookie_secure_on_https():
"""HTTPS register → csrf_token cookie has secure flag but NOT httponly."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("csrf-https"), "password": "Tr0ub4dor3a"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 201
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
csrf_header = csrf_cookies[0]
assert "secure" in csrf_header.lower()
assert "httponly" not in csrf_header.lower()
with TestClient(_make_auth_app()) as client:
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("csrf-https"), "password": "Tr0ub4dor3a"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 201
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
csrf_header = csrf_cookies[0]
assert "secure" in csrf_header.lower()
assert "httponly" not in csrf_header.lower()
def test_csrf_cookie_not_secure_on_http():
"""HTTP register → csrf_token cookie does NOT have secure flag."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("csrf-http"), "password": "Tr0ub4dor3a"},
)
assert resp.status_code == 201
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
assert csrf_cookies, "csrf_token cookie not set on HTTP register"
csrf_header = csrf_cookies[0]
assert "secure" not in csrf_header.lower().replace("samesite", "")
with TestClient(_make_auth_app()) as client:
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("csrf-http"), "password": "Tr0ub4dor3a"},
)
assert resp.status_code == 201
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
assert csrf_cookies, "csrf_token cookie not set on HTTP register"
csrf_header = csrf_cookies[0]
assert "secure" not in csrf_header.lower().replace("samesite", "")

View File

@ -38,6 +38,8 @@ def mock_app_config():
config = MagicMock()
config.models = [model]
config.database = None
config.checkpointer = None
return config
@ -86,10 +88,12 @@ class TestClientInit:
def test_custom_config_path(self, mock_app_config):
with (
patch("deerflow.client.reload_app_config") as mock_reload,
patch("store.config.app_config.reload_app_config") as mock_storage_reload,
patch("deerflow.client.get_app_config", return_value=mock_app_config),
):
DeerFlowClient(config_path="/tmp/custom.yaml")
mock_reload.assert_called_once_with("/tmp/custom.yaml")
mock_storage_reload.assert_called_once_with("/tmp/custom.yaml")
def test_checkpointer_stored(self, mock_app_config):
cp = MagicMock()
@ -97,6 +101,59 @@ class TestClientInit:
c = DeerFlowClient(checkpointer=cp)
assert c._checkpointer is cp
def test_resolve_checkpointer_config_reads_storage_sqlite(self, mock_app_config):
storage_app_config = MagicMock()
storage = MagicMock()
storage.driver = "sqlite"
storage.sqlite_storage_path = "/tmp/deerflow.db"
storage_app_config.storage = storage
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
c = DeerFlowClient()
with patch("store.config.app_config.get_app_config", return_value=storage_app_config):
resolved = c._resolve_checkpointer_config()
assert resolved == {
"backend": "sqlite",
"connection_string": "/tmp/deerflow.db",
}
def test_resolve_checkpointer_config_reads_storage_postgres(self, mock_app_config):
storage_app_config = MagicMock()
storage = MagicMock()
storage.driver = "postgres"
storage.username = "user"
storage.password = "pass"
storage.host = "localhost"
storage.port = 5432
storage.db_name = "deerflow"
storage_app_config.storage = storage
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
c = DeerFlowClient()
with patch("store.config.app_config.get_app_config", return_value=storage_app_config):
resolved = c._resolve_checkpointer_config()
assert resolved == {
"backend": "postgres",
"connection_string": "postgresql://user:pass@localhost:5432/deerflow",
}
def test_resolve_checkpointer_config_rejects_mysql(self, mock_app_config):
storage_app_config = MagicMock()
storage = MagicMock()
storage.driver = "mysql"
storage_app_config.storage = storage
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
c = DeerFlowClient()
with patch("store.config.app_config.get_app_config", return_value=storage_app_config):
with pytest.raises(ValueError, match="does not support a MySQL checkpointer"):
c._resolve_checkpointer_config()
# ---------------------------------------------------------------------------
# list_models / list_skills / get_memory
@ -817,7 +874,7 @@ class TestEnsureAgent:
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
patch.object(client, "_get_active_checkpointer", return_value=MagicMock()),
):
client._agent_name = "custom-agent"
client._available_skills = {"test_skill"}
@ -842,7 +899,7 @@ class TestEnsureAgent:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer),
patch.object(client, "_get_active_checkpointer", return_value=mock_checkpointer),
):
client._ensure_agent(config)
@ -867,7 +924,7 @@ class TestEnsureAgent:
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
patch.object(client, "_get_active_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
@ -886,7 +943,7 @@ class TestEnsureAgent:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=None),
patch.object(client, "_get_active_checkpointer", return_value=None),
):
client._ensure_agent(config)
@ -1015,7 +1072,7 @@ class TestThreadQueries:
mock_checkpointer = MagicMock()
mock_checkpointer.list.return_value = []
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
with patch.object(client, "_get_active_checkpointer", return_value=mock_checkpointer):
# No internal checkpointer, should fetch from provider
result = client.list_threads()
@ -1069,7 +1126,7 @@ class TestThreadQueries:
mock_checkpointer = MagicMock()
mock_checkpointer.list.return_value = []
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
with patch.object(client, "_get_active_checkpointer", return_value=mock_checkpointer):
result = client.get_thread("t99")
assert result["thread_id"] == "t99"
@ -1490,7 +1547,7 @@ class TestUploads:
class TestArtifacts:
def test_get_artifact(self, client):
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
with tempfile.TemporaryDirectory() as tmp:
paths = Paths(base_dir=tmp)
@ -1506,7 +1563,7 @@ class TestArtifacts:
assert "text" in mime
def test_get_artifact_not_found(self, client):
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
with tempfile.TemporaryDirectory() as tmp:
paths = Paths(base_dir=tmp)
@ -1522,7 +1579,7 @@ class TestArtifacts:
client.get_artifact("t1", "bad/path/file.txt")
def test_get_artifact_path_traversal(self, client):
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
with tempfile.TemporaryDirectory() as tmp:
paths = Paths(base_dir=tmp)
@ -1711,7 +1768,7 @@ class TestScenarioFileLifecycle:
def test_upload_then_read_artifact(self, client):
"""Upload a file, simulate agent producing artifact, read it back."""
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
@ -1859,7 +1916,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
patch.object(client, "_get_active_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config_a)
first_agent = client._agent
@ -1887,7 +1944,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
patch.object(client, "_get_active_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
client._ensure_agent(config)
@ -1912,7 +1969,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
patch.object(client, "_get_active_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
client.reset_agent()
@ -1970,7 +2027,7 @@ class TestScenarioThreadIsolation:
def test_artifacts_isolated_per_thread(self, client):
"""Artifacts in thread-A are not accessible from thread-B."""
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
with tempfile.TemporaryDirectory() as tmp:
paths = Paths(base_dir=tmp)
@ -2882,7 +2939,7 @@ class TestUploadDeleteSymlink:
class TestArtifactHardening:
def test_artifact_directory_rejected(self, client):
"""get_artifact rejects paths that resolve to a directory."""
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
with tempfile.TemporaryDirectory() as tmp:
paths = Paths(base_dir=tmp)
@ -2896,7 +2953,7 @@ class TestArtifactHardening:
def test_artifact_leading_slash_stripped(self, client):
"""Paths with leading slash are handled correctly."""
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
with tempfile.TemporaryDirectory() as tmp:
paths = Paths(base_dir=tmp)
@ -3015,7 +3072,7 @@ class TestBugArtifactPrefixMatchTooLoose:
def test_exact_prefix_without_subpath_accepted(self, client):
"""Bare 'mnt/user-data' is accepted (will later fail as directory, not at prefix)."""
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
with tempfile.TemporaryDirectory() as tmp:
paths = Paths(base_dir=tmp)

View File

@ -262,7 +262,7 @@ class TestFileUploadIntegration:
# Physically exists
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
assert (get_paths().sandbox_uploads_dir(tid, user_id=get_effective_user_id()) / "readme.txt").exists()
@ -473,7 +473,7 @@ class TestArtifactAccess:
def test_get_artifact_happy_path(self, e2e_env):
"""Write a file to outputs, then read it back via get_artifact()."""
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
tid = str(uuid.uuid4())
@ -490,7 +490,7 @@ class TestArtifactAccess:
def test_get_artifact_nested_path(self, e2e_env):
"""Artifacts in subdirectories are accessible."""
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
tid = str(uuid.uuid4())

View File

@ -8,6 +8,7 @@ They are skipped in CI and must be run explicitly:
import json
import os
import warnings
from pathlib import Path
import pytest
@ -16,6 +17,12 @@ from deerflow.client import DeerFlowClient, StreamEvent
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.uploads.manager import PathTraversalError
warnings.filterwarnings(
"ignore",
message=r"Pydantic serializer warnings:.*field_name='context'.*",
category=UserWarning,
)
# Skip entire module in CI or when no config.yaml exists
_skip_reason = None
if os.environ.get("CI"):

View File

@ -0,0 +1,392 @@
"""Tests for current feedback storage adapters and follow-up association."""
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.infra.storage import AppRunEventStore, FeedbackStoreAdapter, RunStoreAdapter
from store.persistence import MappedBase
async def _make_feedback_repo(tmp_path):
    """Spin up a throwaway SQLite schema and return (engine, session_factory, repo).

    The returned *repo* is a thin shim exposing the legacy FeedbackRepository
    surface while delegating every call to FeedbackStoreAdapter, so the old
    test bodies keep working against the current storage layer.
    """
    engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}", future=True)
    async with engine.begin() as conn:
        await conn.run_sync(MappedBase.metadata.create_all)

    make_session = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    class _FeedbackRepoCompat:
        """Legacy-API facade over FeedbackStoreAdapter."""

        def __init__(self, factory):
            self._repo = FeedbackStoreAdapter(factory)

        async def create(self, **kwargs):
            # Required fields are indexed directly so a missing key raises
            # KeyError, matching the old contract; optional fields default to None.
            optional = {k: kwargs.get(k) for k in ("owner_id", "user_id", "message_id", "comment")}
            return await self._repo.create(
                run_id=kwargs["run_id"],
                thread_id=kwargs["thread_id"],
                rating=kwargs["rating"],
                **optional,
            )

        async def get(self, feedback_id):
            return await self._repo.get(feedback_id)

        async def list_by_run(self, thread_id, run_id, user_id=None, limit=100):
            return await self._repo.list_by_run(thread_id, run_id, user_id=user_id, limit=limit)

        async def list_by_thread(self, thread_id, limit=100):
            return await self._repo.list_by_thread(thread_id, limit=limit)

        async def delete(self, feedback_id):
            return await self._repo.delete(feedback_id)

        async def aggregate_by_run(self, thread_id, run_id):
            return await self._repo.aggregate_by_run(thread_id, run_id)

        async def upsert(self, **kwargs):
            optional = {k: kwargs.get(k) for k in ("user_id", "comment")}
            return await self._repo.upsert(
                run_id=kwargs["run_id"],
                thread_id=kwargs["thread_id"],
                rating=kwargs["rating"],
                **optional,
            )

        async def delete_by_run(self, *, thread_id, run_id, user_id):
            return await self._repo.delete_by_run(thread_id=thread_id, run_id=run_id, user_id=user_id)

        async def list_by_thread_grouped(self, thread_id, user_id):
            return await self._repo.list_by_thread_grouped(thread_id, user_id=user_id)

    return engine, make_session, _FeedbackRepoCompat(make_session)
# -- FeedbackRepository --
class TestFeedbackRepository:
    @pytest.mark.anyio
    async def test_create_positive(self, tmp_path):
        """create() with rating=1 returns a fully populated feedback record."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1)
        assert record["feedback_id"]
        assert record["rating"] == 1
        assert record["run_id"] == "r1"
        assert record["thread_id"] == "t1"
        assert "created_at" in record
        await engine.dispose()
    @pytest.mark.anyio
    async def test_create_negative_with_comment(self, tmp_path):
        """A rating of -1 may carry an explanatory free-text comment."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(
            run_id="r1",
            thread_id="t1",
            rating=-1,
            comment="Response was inaccurate",
        )
        assert record["rating"] == -1
        assert record["comment"] == "Response was inaccurate"
        await engine.dispose()
    @pytest.mark.anyio
    async def test_create_with_message_id(self, tmp_path):
        """An optional message_id is stored and echoed back on the record."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1, message_id="msg-42")
        assert record["message_id"] == "msg-42"
        await engine.dispose()
    @pytest.mark.anyio
    async def test_create_with_owner(self, tmp_path):
        """Passing user_id associates the feedback with that user."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
        assert record["user_id"] == "user-1"
        await engine.dispose()
    @pytest.mark.anyio
    async def test_create_uses_owner_id_fallback(self, tmp_path):
        """When only owner_id is given, it populates both owner_id and user_id."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="owner-1")
        assert record["user_id"] == "owner-1"
        assert record["owner_id"] == "owner-1"
        await engine.dispose()
    @pytest.mark.anyio
    async def test_create_invalid_rating_zero(self, tmp_path):
        """rating=0 is outside the accepted set and raises ValueError."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        with pytest.raises(ValueError):
            await repo.create(run_id="r1", thread_id="t1", rating=0)
        await engine.dispose()
    @pytest.mark.anyio
    async def test_create_invalid_rating_five(self, tmp_path):
        """rating=5 is outside the accepted set and raises ValueError."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        with pytest.raises(ValueError):
            await repo.create(run_id="r1", thread_id="t1", rating=5)
        await engine.dispose()
    @pytest.mark.anyio
    async def test_get(self, tmp_path):
        """get() retrieves a previously created record by its feedback_id."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        created = await repo.create(run_id="r1", thread_id="t1", rating=1)
        fetched = await repo.get(created["feedback_id"])
        assert fetched is not None
        assert fetched["feedback_id"] == created["feedback_id"]
        assert fetched["rating"] == 1
        await engine.dispose()
    @pytest.mark.anyio
    async def test_get_nonexistent(self, tmp_path):
        """get() on an unknown id returns None rather than raising."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        assert await repo.get("nonexistent") is None
        await engine.dispose()
    @pytest.mark.anyio
    async def test_list_by_run(self, tmp_path):
        """list_by_run() returns only feedback belonging to the requested run."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
        await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2")
        await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1")
        results = await repo.list_by_run("t1", "r1", user_id=None)
        assert len(results) == 2
        assert all(r["run_id"] == "r1" for r in results)
        await engine.dispose()
    @pytest.mark.anyio
    async def test_list_by_run_filters_thread_even_with_same_run_id(self, tmp_path):
        """Identical run_id in a different thread must not leak into the results."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
        await repo.create(run_id="r1", thread_id="t2", rating=-1, user_id="user-2")
        results = await repo.list_by_run("t1", "r1", user_id=None)
        assert len(results) == 1
        assert results[0]["thread_id"] == "t1"
        await engine.dispose()
    @pytest.mark.anyio
    async def test_list_by_run_respects_limit(self, tmp_path):
        """list_by_run() truncates its result set at the given limit."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="u1")
        await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="u2")
        await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="u3")
        results = await repo.list_by_run("t1", "r1", user_id=None, limit=2)
        assert len(results) == 2
        await engine.dispose()
    @pytest.mark.anyio
    async def test_list_by_thread(self, tmp_path):
        """list_by_thread() returns all runs' feedback for one thread only."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1)
        await repo.create(run_id="r2", thread_id="t1", rating=-1)
        await repo.create(run_id="r3", thread_id="t2", rating=1)
        results = await repo.list_by_thread("t1")
        assert len(results) == 2
        assert all(r["thread_id"] == "t1" for r in results)
        await engine.dispose()
    @pytest.mark.anyio
    async def test_list_by_thread_respects_limit(self, tmp_path):
        """list_by_thread() truncates its result set at the given limit."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1)
        await repo.create(run_id="r2", thread_id="t1", rating=-1)
        await repo.create(run_id="r3", thread_id="t1", rating=1)
        results = await repo.list_by_thread("t1", limit=2)
        assert len(results) == 2
        await engine.dispose()
    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        """delete() returns True and the record is no longer retrievable."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        created = await repo.create(run_id="r1", thread_id="t1", rating=1)
        deleted = await repo.delete(created["feedback_id"])
        assert deleted is True
        assert await repo.get(created["feedback_id"]) is None
        await engine.dispose()
    @pytest.mark.anyio
    async def test_delete_nonexistent(self, tmp_path):
        """delete() on an unknown id reports False instead of raising."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        deleted = await repo.delete("nonexistent")
        assert deleted is False
        await engine.dispose()
    @pytest.mark.anyio
    async def test_aggregate_by_run(self, tmp_path):
        """aggregate_by_run() tallies total / positive / negative counts per run."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
        await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2")
        await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3")
        stats = await repo.aggregate_by_run("t1", "r1")
        assert stats["total"] == 3
        assert stats["positive"] == 2
        assert stats["negative"] == 1
        assert stats["run_id"] == "r1"
        await engine.dispose()
    @pytest.mark.anyio
    async def test_aggregate_empty(self, tmp_path):
        """Aggregating a run with no feedback yields all-zero counters."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        stats = await repo.aggregate_by_run("t1", "r1")
        assert stats["total"] == 0
        assert stats["positive"] == 0
        assert stats["negative"] == 0
        await engine.dispose()
    @pytest.mark.anyio
    async def test_upsert_creates_new(self, tmp_path):
        """upsert() with no prior row for (run, user) inserts a fresh record."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
        assert record["rating"] == 1
        assert record["feedback_id"]
        assert record["user_id"] == "u1"
        await engine.dispose()
    @pytest.mark.anyio
    async def test_upsert_updates_existing(self, tmp_path):
        """A second upsert for the same (run, user) mutates the row in place."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
        second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind")
        # Same feedback_id proves update-in-place rather than a second insert.
        assert second["feedback_id"] == first["feedback_id"]
        assert second["rating"] == -1
        assert second["comment"] == "changed my mind"
        await engine.dispose()
    @pytest.mark.anyio
    async def test_upsert_different_users_separate(self, tmp_path):
        """Upserts from distinct users on the same run keep separate rows."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
        r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2")
        assert r1["feedback_id"] != r2["feedback_id"]
        assert r1["rating"] == 1
        assert r2["rating"] == -1
        await engine.dispose()
    @pytest.mark.anyio
    async def test_upsert_invalid_rating(self, tmp_path):
        """upsert() validates rating the same way create() does."""
        engine, _, repo = await _make_feedback_repo(tmp_path)
        with pytest.raises(ValueError):
            await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1")
        await engine.dispose()
@pytest.mark.anyio
async def test_delete_by_run(self, tmp_path):
    """delete_by_run removes the user's feedback and reports True."""
    engine, _, repo = await _make_feedback_repo(tmp_path)
    try:
        await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
        deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
        assert deleted is True
        results = await repo.list_by_run("t1", "r1", user_id="u1")
        assert len(results) == 0
    finally:
        # Always dispose, even when an assertion above fails.
        await engine.dispose()
@pytest.mark.anyio
async def test_delete_by_run_nonexistent(self, tmp_path):
    """Deleting feedback that was never written reports False."""
    engine, _, repo = await _make_feedback_repo(tmp_path)
    try:
        deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
        assert deleted is False
    finally:
        # Always dispose, even when an assertion above fails.
        await engine.dispose()
@pytest.mark.anyio
async def test_list_by_thread_grouped(self, tmp_path):
    """list_by_thread_grouped keys results by run_id and scopes to the thread."""
    engine, _, repo = await _make_feedback_repo(tmp_path)
    try:
        await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
        await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1")
        await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1")
        grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
        # r3 belongs to thread t2 and must not appear in the t1 grouping.
        assert "r1" in grouped
        assert "r2" in grouped
        assert "r3" not in grouped
        assert grouped["r1"]["rating"] == 1
        assert grouped["r2"]["rating"] == -1
    finally:
        await engine.dispose()
@pytest.mark.anyio
async def test_list_by_thread_grouped_filters_by_user_when_same_run_id_exists(self, tmp_path):
    """When two users rated the same run, grouping returns only the caller's row."""
    engine, _, repo = await _make_feedback_repo(tmp_path)
    try:
        await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="u1", comment="mine")
        await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="u2", comment="other")
        grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
        assert grouped["r1"]["user_id"] == "u1"
        assert grouped["r1"]["comment"] == "mine"
    finally:
        # Always dispose, even when an assertion above fails.
        await engine.dispose()
@pytest.mark.anyio
async def test_list_by_thread_grouped_empty(self, tmp_path):
    """Grouping a thread with no feedback yields an empty dict."""
    engine, _, repo = await _make_feedback_repo(tmp_path)
    try:
        grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
        assert grouped == {}
    finally:
        # Always dispose, even when an assertion above fails.
        await engine.dispose()
# -- Follow-up association --
class TestFollowUpAssociation:
    """Covers persistence of ``follow_up_to_run_id`` across the app-layer stores.

    Each test builds a throwaway in-memory SQLite engine; disposal is moved
    into ``finally`` so a failing assertion cannot leak the engine (matching
    the try/finally convention used by the cross-user isolation tests).
    """

    @pytest.mark.anyio
    async def test_run_records_follow_up_via_memory_store(self):
        """RunStoreAdapter persists follow_up_to_run_id as a first-class field."""
        engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True)
        try:
            async with engine.begin() as conn:
                await conn.run_sync(MappedBase.metadata.create_all)
            session_factory = async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False, autoflush=False)
            store = RunStoreAdapter(session_factory)
            await store.create("r1", thread_id="t1", status="success")
            await store.create("r2", thread_id="t1", follow_up_to_run_id="r1")
            run = await store.get("r2")
            assert run is not None
            assert run["follow_up_to_run_id"] == "r1"
        finally:
            await engine.dispose()

    @pytest.mark.anyio
    async def test_human_message_has_follow_up_metadata(self):
        """AppRunEventStore preserves follow_up_to_run_id in message metadata."""
        engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True)
        try:
            async with engine.begin() as conn:
                await conn.run_sync(MappedBase.metadata.create_all)
            session_factory = async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False, autoflush=False)
            event_store = AppRunEventStore(session_factory)
            await event_store.put_batch([
                {
                    "thread_id": "t1",
                    "run_id": "r2",
                    "event_type": "human_message",
                    "category": "message",
                    "content": "Tell me more about that",
                    "metadata": {"follow_up_to_run_id": "r1"},
                }
            ])
            messages = await event_store.list_messages("t1")
            assert messages[0]["metadata"]["follow_up_to_run_id"] == "r1"
        finally:
            await engine.dispose()

    @pytest.mark.anyio
    async def test_follow_up_auto_detection_logic(self):
        """Simulate the auto-detection: latest successful run becomes follow_up_to."""
        engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True)
        try:
            async with engine.begin() as conn:
                await conn.run_sync(MappedBase.metadata.create_all)
            session_factory = async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False, autoflush=False)
            store = RunStoreAdapter(session_factory)
            await store.create("r1", thread_id="t1", status="success")
            await store.create("r2", thread_id="t1", status="error")
            # Auto-detect: list_by_thread returns newest first.
            recent = await store.list_by_thread("t1", limit=1)
            follow_up = None
            if recent and recent[0].get("status") == "success":
                follow_up = recent[0]["run_id"]
            # r2 (error) is newest, so no follow_up detected.
            assert follow_up is None
            # Now add a successful run; it becomes the newest and is detected.
            await store.create("r3", thread_id="t1", status="success")
            recent = await store.list_by_thread("t1", limit=1)
            follow_up = None
            if recent and recent[0].get("status") == "success":
                follow_up = recent[0]["run_id"]
            assert follow_up == "r3"
        finally:
            await engine.dispose()

View File

@ -0,0 +1,281 @@
"""Tests for the current runs service modules."""
from __future__ import annotations
import json
from app.gateway.routers.langgraph.runs import RunCreateRequest, format_sse
from app.gateway.services.runs.facade_factory import resolve_agent_factory
from app.gateway.services.runs.input.request_adapter import (
adapt_create_run_request,
adapt_create_stream_request,
adapt_create_wait_request,
adapt_join_stream_request,
adapt_join_wait_request,
)
from app.gateway.services.runs.input.spec_builder import RunSpecBuilder
def _builder() -> RunSpecBuilder:
    """Return a brand-new RunSpecBuilder for each call."""
    builder = RunSpecBuilder()
    return builder
def _build_runnable_config(
    thread_id: str,
    request_config: dict | None,
    metadata: dict | None,
    *,
    assistant_id: str | None = None,
    context: dict | None = None,
):
    """Invoke the private config builder on a fresh RunSpecBuilder.

    Thin shim so tests can exercise ``_build_runnable_config`` without
    repeating builder construction at every call site.
    """
    spec_builder = _builder()
    # noqa: SLF001 - private-method access is the point of these unit tests
    return spec_builder._build_runnable_config(
        thread_id=thread_id,
        request_config=request_config,
        metadata=metadata,
        assistant_id=assistant_id,
        context=context,
    )
def test_format_sse_basic():
    """format_sse emits an `event:` line followed by a JSON `data:` payload."""
    frame = format_sse("metadata", {"run_id": "abc"})
    assert frame.startswith("event: metadata\n")
    assert "data: " in frame
    # Round-trip the data line to confirm it is valid JSON.
    parsed = json.loads(frame.split("data: ")[1].split("\n")[0])
    assert parsed["run_id"] == "abc"


def test_format_sse_with_event_id():
    """An explicit event_id is rendered as an `id:` line in the frame."""
    frame = format_sse("metadata", {"run_id": "abc"}, event_id="123-0")
    assert "id: 123-0" in frame


def test_format_sse_end_event_null():
    """A None payload is serialized as the JSON literal `null`."""
    frame = format_sse("end", None)
    assert "data: null" in frame


def test_format_sse_no_event_id():
    """Without an event_id, no `id:` line appears at all."""
    frame = format_sse("values", {"x": 1})
    assert "id:" not in frame
def test_normalize_stream_modes_none():
    """None falls back to the default ["values", "messages"] mode pair."""
    assert _builder()._normalize_stream_modes(None) == ["values", "messages"]  # noqa: SLF001


def test_normalize_stream_modes_string():
    """A bare string is wrapped and canonicalised ("messages-tuple" -> "messages")."""
    assert _builder()._normalize_stream_modes("messages-tuple") == ["messages"]  # noqa: SLF001


def test_normalize_stream_modes_list():
    """List entries are canonicalised element-wise."""
    assert _builder()._normalize_stream_modes(["values", "messages-tuple"]) == ["values", "messages"]  # noqa: SLF001


def test_normalize_stream_modes_empty_list():
    """An explicitly empty list stays empty (no default substitution)."""
    assert _builder()._normalize_stream_modes([]) == []  # noqa: SLF001


def test_normalize_input_none():
    """None input passes through unchanged."""
    assert _builder()._normalize_input(None) is None  # noqa: SLF001


def test_normalize_input_with_messages():
    """Dict messages are converted to message objects exposing `.content`."""
    result = _builder()._normalize_input({"messages": [{"role": "user", "content": "hi"}]})  # noqa: SLF001
    assert len(result["messages"]) == 1
    assert result["messages"][0].content == "hi"


def test_normalize_input_passthrough():
    """Inputs without a `messages` key are returned untouched."""
    result = _builder()._normalize_input({"custom_key": "value"})  # noqa: SLF001
    assert result == {"custom_key": "value"}
def test_build_runnable_config_basic():
    """Minimal config carries the thread_id and a recursion_limit of 100."""
    config = _build_runnable_config("thread-1", None, None)
    assert config["configurable"]["thread_id"] == "thread-1"
    assert config["recursion_limit"] == 100


def test_build_runnable_config_with_overrides():
    """Request-level configurable/tags/metadata are merged into the config."""
    config = _build_runnable_config(
        "thread-1",
        {"configurable": {"model_name": "gpt-4"}, "tags": ["test"]},
        {"user": "alice"},
    )
    assert config["configurable"]["model_name"] == "gpt-4"
    assert config["tags"] == ["test"]
    assert config["metadata"]["user"] == "alice"


def test_build_runnable_config_custom_agent_injects_agent_name():
    """A non-lead assistant_id is injected as configurable.agent_name."""
    config = _build_runnable_config("thread-1", None, None, assistant_id="finalis")
    assert config["configurable"]["agent_name"] == "finalis"


def test_build_runnable_config_lead_agent_no_agent_name():
    """The lead_agent assistant does not get an agent_name entry."""
    config = _build_runnable_config("thread-1", None, None, assistant_id="lead_agent")
    assert "agent_name" not in config["configurable"]


def test_build_runnable_config_none_assistant_id_no_agent_name():
    """No assistant_id means no agent_name entry either."""
    config = _build_runnable_config("thread-1", None, None, assistant_id=None)
    assert "agent_name" not in config["configurable"]


def test_build_runnable_config_explicit_agent_name_not_overwritten():
    """An agent_name set in the request config wins over the assistant_id."""
    config = _build_runnable_config(
        "thread-1",
        {"configurable": {"agent_name": "explicit-agent"}},
        None,
        assistant_id="other-agent",
    )
    assert config["configurable"]["agent_name"] == "explicit-agent"
def test_resolve_agent_factory_returns_make_lead_agent():
    """Every assistant id currently resolves to the single lead-agent factory."""
    from deerflow.agents.lead_agent.agent import make_lead_agent

    for assistant_id in (None, "lead_agent", "finalis", "custom-agent-123"):
        assert resolve_agent_factory(assistant_id) is make_lead_agent
def test_run_create_request_accepts_context():
    """RunCreateRequest accepts a free-form context dict alongside input."""
    body = RunCreateRequest(
        input={"messages": [{"role": "user", "content": "hi"}]},
        context={
            "model_name": "deepseek-v3",
            "thinking_enabled": True,
            "is_plan_mode": True,
            "subagent_enabled": True,
            "thread_id": "some-thread-id",
        },
    )
    assert body.context is not None
    assert body.context["model_name"] == "deepseek-v3"
    assert body.context["is_plan_mode"] is True
    assert body.context["subagent_enabled"] is True


def test_run_create_request_context_defaults_to_none():
    """Omitting context leaves the field as None."""
    body = RunCreateRequest(input=None)
    assert body.context is None
def test_context_merges_into_configurable():
    """Top-level context keys merge into configurable, except thread_id.

    The path parameter thread_id always wins over a thread_id supplied in
    the context payload.
    """
    config = _build_runnable_config(
        "thread-1",
        None,
        None,
        context={
            "model_name": "deepseek-v3",
            "mode": "ultra",
            "reasoning_effort": "high",
            "thinking_enabled": True,
            "is_plan_mode": True,
            "subagent_enabled": True,
            "max_concurrent_subagents": 5,
            "thread_id": "should-be-ignored",
        },
    )
    assert config["configurable"]["model_name"] == "deepseek-v3"
    assert config["configurable"]["thinking_enabled"] is True
    assert config["configurable"]["is_plan_mode"] is True
    assert config["configurable"]["subagent_enabled"] is True
    assert config["configurable"]["max_concurrent_subagents"] == 5
    assert config["configurable"]["reasoning_effort"] == "high"
    assert config["configurable"]["mode"] == "ultra"
    assert config["configurable"]["thread_id"] == "thread-1"


def test_context_does_not_override_existing_configurable():
    """Keys already present in request configurable beat context values."""
    config = _build_runnable_config(
        "thread-1",
        {"configurable": {"model_name": "gpt-4", "is_plan_mode": False}},
        None,
        context={
            "model_name": "deepseek-v3",
            "is_plan_mode": True,
            "subagent_enabled": True,
        },
    )
    assert config["configurable"]["model_name"] == "gpt-4"
    assert config["configurable"]["is_plan_mode"] is False
    assert config["configurable"]["subagent_enabled"] is True
def test_build_runnable_config_with_context_wrapper_in_request_config():
    """A request `context` wrapper produces a context-style config (no configurable)."""
    config = _build_runnable_config(
        "thread-1",
        {"context": {"user_id": "u-42", "thread_id": "thread-1"}},
        None,
    )
    assert "context" in config
    assert config["context"]["user_id"] == "u-42"
    assert "configurable" not in config
    assert config["recursion_limit"] == 100


def test_build_runnable_config_context_plus_configurable_prefers_context():
    """When both `context` and `configurable` are supplied, context wins."""
    config = _build_runnable_config(
        "thread-1",
        {
            "context": {"user_id": "u-42"},
            "configurable": {"model_name": "gpt-4"},
        },
        None,
    )
    assert "context" in config
    assert config["context"]["user_id"] == "u-42"
    assert "configurable" not in config


def test_build_runnable_config_context_passthrough_other_keys():
    """Non-context request keys (e.g. tags) survive alongside the context."""
    config = _build_runnable_config(
        "thread-1",
        {"context": {"thread_id": "thread-1"}, "tags": ["prod"]},
        None,
    )
    assert config["context"]["thread_id"] == "thread-1"
    assert "configurable" not in config
    assert config["tags"] == ["prod"]


def test_build_runnable_config_no_request_config():
    """Without a request config, only configurable.thread_id is populated."""
    config = _build_runnable_config("thread-abc", None, None)
    assert config["configurable"] == {"thread_id": "thread-abc"}
    assert "context" not in config
def test_request_adapter_create_background():
    """Background-run creation maps to the create_background intent."""
    adapted = adapt_create_run_request(thread_id="thread-1", body={"input": {"x": 1}})
    assert adapted.intent == "create_background"
    assert adapted.thread_id == "thread-1"
    assert adapted.run_id is None


def test_request_adapter_create_stream():
    """Streaming creation with no thread is stateless create_and_stream."""
    adapted = adapt_create_stream_request(thread_id=None, body={"input": {"x": 1}})
    assert adapted.intent == "create_and_stream"
    assert adapted.thread_id is None
    assert adapted.is_stateless is True


def test_request_adapter_create_wait():
    """Blocking creation maps to the create_and_wait intent."""
    adapted = adapt_create_wait_request(thread_id="thread-1", body={})
    assert adapted.intent == "create_and_wait"
    assert adapted.thread_id == "thread-1"


def test_request_adapter_join_stream():
    """Joining a stream forwards the Last-Event-ID header for resumption."""
    adapted = adapt_join_stream_request(thread_id="thread-1", run_id="run-1", headers={"Last-Event-ID": "123"})
    assert adapted.intent == "join_stream"
    assert adapted.last_event_id == "123"


def test_request_adapter_join_wait():
    """Joining a run in wait mode carries the target run_id."""
    adapted = adapt_join_wait_request(thread_id="thread-1", run_id="run-1")
    assert adapted.intent == "join_wait"
    assert adapted.run_id == "run-1"

View File

@ -5,7 +5,6 @@ initialized, password strength validation,
and public accessibility (no auth cookie required).
"""
import asyncio
import os
import pytest
@ -13,41 +12,32 @@ from fastapi.testclient import TestClient
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-initialize-admin-min-32")
from app.gateway.auth.config import AuthConfig, set_auth_config
from store.config.app_config import AppConfig, set_app_config
from store.config.storage_config import StorageConfig
from app.plugins.auth.domain.config import AuthConfig
from app.plugins.auth.runtime.config_state import set_auth_config
_TEST_SECRET = "test-secret-key-initialize-admin-min-32"
@pytest.fixture(autouse=True)
def _setup_auth(tmp_path):
"""Fresh SQLite engine + auth config per test."""
from app.gateway import deps
from deerflow.persistence.engine import close_engine, init_engine
"""Fresh SQLite app config + auth config per test."""
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
url = f"sqlite+aiosqlite:///{tmp_path}/init_admin.db"
asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
deps._cached_local_provider = None
deps._cached_repo = None
try:
yield
finally:
deps._cached_local_provider = None
deps._cached_repo = None
asyncio.run(close_engine())
set_app_config(AppConfig(storage=StorageConfig(driver="sqlite", sqlite_dir=str(tmp_path))))
yield
@pytest.fixture()
def client(_setup_auth):
from app.gateway.app import create_app
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.plugins.auth.domain.config import AuthConfig
from app.plugins.auth.runtime.config_state import set_auth_config
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
app = create_app()
# Do NOT use TestClient as a context manager — that would trigger the
# full lifespan which requires config.yaml. The auth endpoints work
# without the lifespan (persistence engine is set up by _setup_auth).
yield TestClient(app)
with TestClient(app) as test_client:
yield test_client
def _init_payload(**extra):
@ -110,11 +100,7 @@ def test_initialize_register_does_not_block_initialization(client):
def test_initialize_accessible_without_cookie(client):
"""No access_token cookie needed for /initialize."""
resp = client.post(
"/api/v1/auth/initialize",
json=_init_payload(),
cookies={},
)
resp = client.post("/api/v1/auth/initialize", json=_init_payload())
assert resp.status_code == 201

View File

@ -152,7 +152,7 @@ def test_get_work_dir_uses_base_dir_when_no_thread_id(monkeypatch, tmp_path):
def test_get_work_dir_uses_per_thread_path_when_thread_id_given(monkeypatch, tmp_path):
"""P1.1: _get_work_dir(thread_id) uses {base_dir}/threads/{thread_id}/acp-workspace/."""
from deerflow.config import paths as paths_module
from deerflow.runtime import user_context as uc_module
from deerflow.runtime import actor_context as uc_module
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
@ -312,7 +312,7 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
async def test_invoke_acp_agent_uses_per_thread_workspace_when_thread_id_in_config(monkeypatch, tmp_path):
"""P1.1: When thread_id is in the RunnableConfig, ACP agent uses per-thread workspace."""
from deerflow.config import paths as paths_module
from deerflow.runtime import user_context as uc_module
from deerflow.runtime import actor_context as uc_module
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)

View File

@ -9,19 +9,24 @@ import os
from datetime import timedelta
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32")
from langgraph_sdk import Auth
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.jwt import create_access_token, decode_token
from app.gateway.auth.models import User
from app.gateway.langgraph_auth import add_owner_filter, authenticate
from app.plugins.auth.domain.config import AuthConfig
from app.plugins.auth.domain.jwt import create_access_token, decode_token
from app.plugins.auth.security.langgraph import add_owner_filter, authenticate
from app.plugins.auth.domain.models import User as AuthUser
from app.plugins.auth.runtime.config_state import set_auth_config
from app.plugins.auth.storage import DbUserRepository, UserCreate
from store.persistence import MappedBase
from app.plugins.auth.storage.models import User as UserModel # noqa: F401
# ── Helpers ───────────────────────────────────────────────────────────────
@ -40,13 +45,40 @@ def _req(cookies=None, method="GET", headers=None):
def _user(user_id=None, token_version=0):
return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)
return AuthUser(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)
def _mock_provider(user=None):
p = AsyncMock()
p.get_user = AsyncMock(return_value=user)
return p
async def _attach_auth_session(request, tmp_path, user: AuthUser | None = None):
engine = create_async_engine(
f"sqlite+aiosqlite:///{tmp_path / 'langgraph-auth.db'}",
future=True,
)
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
session = session_factory()
if user is not None:
repo = DbUserRepository(session)
await repo.create_user(
UserCreate(
id=str(user.id),
email=user.email,
password_hash=user.password_hash,
system_role=user.system_role,
oauth_provider=user.oauth_provider,
oauth_id=user.oauth_id,
needs_setup=user.needs_setup,
token_version=user.token_version,
)
)
await session.commit()
request._auth_session = session
return engine, session
# ── @auth.authenticate ───────────────────────────────────────────────────
@ -73,77 +105,103 @@ def test_expired_jwt_raises_401():
assert exc.value.status_code == 401
def test_user_not_found_raises_401():
@pytest.mark.anyio
async def test_user_not_found_raises_401(tmp_path):
token = create_access_token("ghost")
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)):
request = _req({"access_token": token})
engine, session = await _attach_auth_session(request, tmp_path)
try:
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": token})))
await authenticate(request)
assert exc.value.status_code == 401
assert "User not found" in str(exc.value.detail)
finally:
await session.close()
await engine.dispose()
def test_token_version_mismatch_raises_401():
@pytest.mark.anyio
async def test_token_version_mismatch_raises_401(tmp_path):
user = _user(token_version=2)
token = create_access_token(str(user.id), token_version=1)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
request = _req({"access_token": token})
engine, session = await _attach_auth_session(request, tmp_path, user)
try:
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": token})))
await authenticate(request)
assert exc.value.status_code == 401
assert "revoked" in str(exc.value.detail).lower()
finally:
await session.close()
await engine.dispose()
def test_valid_token_returns_user_id():
@pytest.mark.anyio
async def test_valid_token_returns_user_id(tmp_path):
user = _user(token_version=0)
token = create_access_token(str(user.id), token_version=0)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": token})))
request = _req({"access_token": token})
engine, session = await _attach_auth_session(request, tmp_path, user)
try:
result = await authenticate(request)
finally:
await session.close()
await engine.dispose()
assert result == str(user.id)
def test_valid_token_matching_version():
@pytest.mark.anyio
async def test_valid_token_matching_version(tmp_path):
user = _user(token_version=5)
token = create_access_token(str(user.id), token_version=5)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": token})))
request = _req({"access_token": token})
engine, session = await _attach_auth_session(request, tmp_path, user)
try:
result = await authenticate(request)
finally:
await session.close()
await engine.dispose()
assert result == str(user.id)
# ── @auth.authenticate edge cases ────────────────────────────────────────
def test_provider_exception_propagates():
"""Provider raises → should not be swallowed silently."""
token = create_access_token("user-1")
p = AsyncMock()
p.get_user = AsyncMock(side_effect=RuntimeError("DB down"))
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p):
with pytest.raises(RuntimeError, match="DB down"):
asyncio.run(authenticate(_req({"access_token": token})))
def test_jwt_missing_ver_defaults_to_zero():
@pytest.mark.anyio
async def test_jwt_missing_ver_defaults_to_zero(tmp_path):
"""JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0."""
import jwt as pyjwt
uid = str(uuid4())
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
user = _user(user_id=uid, token_version=0)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": raw})))
request = _req({"access_token": raw})
engine, session = await _attach_auth_session(request, tmp_path, user)
try:
result = await authenticate(request)
finally:
await session.close()
await engine.dispose()
assert result == uid
def test_jwt_missing_ver_rejected_when_user_version_nonzero():
@pytest.mark.anyio
async def test_jwt_missing_ver_rejected_when_user_version_nonzero(tmp_path):
"""JWT without 'ver' (defaults 0) vs user with token_version=1 → 401."""
import jwt as pyjwt
uid = str(uuid4())
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
user = _user(user_id=uid, token_version=1)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
request = _req({"access_token": raw})
engine, session = await _attach_auth_session(request, tmp_path, user)
try:
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": raw})))
await authenticate(request)
assert exc.value.status_code == 401
finally:
await session.close()
await engine.dispose()
def test_wrong_secret_raises_401():
@ -223,7 +281,7 @@ def test_filter_with_empty_metadata():
def test_shared_jwt_secret():
token = create_access_token("user-1", token_version=3)
payload = decode_token(token)
from app.gateway.auth.errors import TokenError
from app.plugins.auth.domain.errors import TokenError
assert not isinstance(payload, TokenError)
assert payload.sub == "user-1"
@ -233,13 +291,13 @@ def test_shared_jwt_secret():
def test_langgraph_json_has_auth_path():
import json
config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text())
assert "auth" in config
assert "langgraph_auth" in config["auth"]["path"]
config = json.loads((Path(__file__).resolve().parents[2] / "langgraph.json").read_text())
assert "graphs" in config
assert "lead_agent" in config["graphs"]
def test_auth_handler_has_both_layers():
from app.gateway.langgraph_auth import auth
from app.plugins.auth.security.langgraph import auth
assert auth._authenticate_handler is not None
assert len(auth._global_handlers) == 1

View File

@ -0,0 +1,236 @@
"""Cross-user isolation tests for current app-owned storage adapters.
These tests exercise isolation by binding different ``ActorContext``
values around the app-layer storage adapters. The safety property is:
data written under user A is not visible to user B through the same
adapter surface unless a call explicitly opts out with ``user_id=None``.
"""
from __future__ import annotations
from contextlib import contextmanager
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.infra.storage import AppRunEventStore, FeedbackStoreAdapter, RunStoreAdapter, ThreadMetaStorage, ThreadMetaStoreAdapter
from deerflow.runtime.actor_context import AUTO, ActorContext, bind_actor_context, reset_actor_context
from store.persistence import MappedBase
USER_A = "user-a"
USER_B = "user-b"
async def _make_components(tmp_path):
    """Build a throwaway SQLite engine plus every storage adapter under test.

    Returns (engine, thread_store, run_store, feedback_store, event_store);
    callers are responsible for disposing the engine.
    """
    db_url = f"sqlite+aiosqlite:///{tmp_path / 'isolation.db'}"
    engine = create_async_engine(db_url, future=True)
    async with engine.begin() as conn:
        await conn.run_sync(MappedBase.metadata.create_all)
    make_session = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    thread_store = ThreadMetaStorage(ThreadMetaStoreAdapter(make_session))
    run_store = RunStoreAdapter(make_session)
    feedback_store = FeedbackStoreAdapter(make_session)
    event_store = AppRunEventStore(make_session)
    return engine, thread_store, run_store, feedback_store, event_store
@contextmanager
def _as_user(user_id: str):
    """Bind an ActorContext for *user_id* for the duration of the with-block."""
    ctx_token = bind_actor_context(ActorContext(user_id=user_id))
    try:
        yield
    finally:
        # Restore whatever context (or absence of one) was active before.
        reset_actor_context(ctx_token)
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_thread_meta_cross_user_isolation(tmp_path):
    """Threads created under one actor context are invisible to another."""
    engine, thread_store, _, _, _ = await _make_components(tmp_path)
    try:
        with _as_user(USER_A):
            await thread_store.ensure_thread(thread_id="t-alpha")
        with _as_user(USER_B):
            await thread_store.ensure_thread(thread_id="t-beta")
        with _as_user(USER_A):
            # A sees only its own thread, via both direct get and search.
            assert (await thread_store.get_thread("t-alpha")) is not None
            assert await thread_store.get_thread("t-beta") is None
            rows = await thread_store.search_threads()
            assert [row.thread_id for row in rows] == ["t-alpha"]
        with _as_user(USER_B):
            # Symmetrically, B sees only t-beta.
            assert (await thread_store.get_thread("t-beta")) is not None
            assert await thread_store.get_thread("t-alpha") is None
            rows = await thread_store.search_threads()
            assert [row.thread_id for row in rows] == ["t-beta"]
    finally:
        await engine.dispose()
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_runs_cross_user_isolation(tmp_path):
    """Runs are scoped to the actor that created their thread."""
    engine, thread_store, run_store, _, _ = await _make_components(tmp_path)
    try:
        with _as_user(USER_A):
            await thread_store.ensure_thread(thread_id="t-alpha")
            await run_store.create("run-a1", "t-alpha")
            await run_store.create("run-a2", "t-alpha")
        with _as_user(USER_B):
            await thread_store.ensure_thread(thread_id="t-beta")
            await run_store.create("run-b1", "t-beta")
        with _as_user(USER_A):
            # A can read its own runs but B's run resolves to None.
            assert (await run_store.get("run-a1")) is not None
            assert await run_store.get("run-b1") is None
            rows = await run_store.list_by_thread("t-alpha")
            assert {row["run_id"] for row in rows} == {"run-a1", "run-a2"}
            # Listing another user's thread yields nothing rather than erroring.
            assert await run_store.list_by_thread("t-beta") == []
        with _as_user(USER_B):
            assert await run_store.get("run-a1") is None
            rows = await run_store.list_by_thread("t-beta")
            assert [row["run_id"] for row in rows] == ["run-b1"]
    finally:
        await engine.dispose()
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_run_events_cross_user_isolation(tmp_path):
    """Run events/messages written by one user are invisible to another."""
    engine, thread_store, _, _, event_store = await _make_components(tmp_path)
    try:
        with _as_user(USER_A):
            await thread_store.ensure_thread(thread_id="t-alpha")
            await event_store.put_batch(
                [
                    {
                        "thread_id": "t-alpha",
                        "run_id": "run-a1",
                        "event_type": "human_message",
                        "category": "message",
                        "content": "User A private question",
                    },
                    {
                        "thread_id": "t-alpha",
                        "run_id": "run-a1",
                        "event_type": "ai_message",
                        "category": "message",
                        "content": "User A private answer",
                    },
                ]
            )
        with _as_user(USER_B):
            await thread_store.ensure_thread(thread_id="t-beta")
            await event_store.put_batch(
                [
                    {
                        "thread_id": "t-beta",
                        "run_id": "run-b1",
                        "event_type": "human_message",
                        "category": "message",
                        "content": "User B private question",
                    }
                ]
            )
        with _as_user(USER_A):
            # A sees both of its own messages and none of B's.
            msgs = await event_store.list_messages("t-alpha")
            contents = [msg["content"] for msg in msgs]
            assert "User A private question" in contents
            assert "User A private answer" in contents
            assert "User B private question" not in contents
            # B's thread reads back empty for A across all query surfaces.
            assert await event_store.list_messages("t-beta") == []
            assert await event_store.list_events("t-beta", "run-b1") == []
            assert await event_store.count_messages("t-beta") == 0
        with _as_user(USER_B):
            msgs = await event_store.list_messages("t-beta")
            contents = [msg["content"] for msg in msgs]
            assert "User B private question" in contents
            assert "User A private question" not in contents
            assert await event_store.count_messages("t-alpha") == 0
    finally:
        await engine.dispose()
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_feedback_cross_user_isolation(tmp_path):
    """Feedback listings are scoped per user; direct id lookups are not."""
    engine, thread_store, _, feedback_store, _ = await _make_components(tmp_path)
    try:
        with _as_user(USER_A):
            await thread_store.ensure_thread(thread_id="t-alpha")
            a_feedback = await feedback_store.create(
                run_id="run-a1",
                thread_id="t-alpha",
                rating=1,
                user_id=USER_A,
                comment="A liked this",
            )
        with _as_user(USER_B):
            await thread_store.ensure_thread(thread_id="t-beta")
            b_feedback = await feedback_store.create(
                run_id="run-b1",
                thread_id="t-beta",
                rating=-1,
                user_id=USER_B,
                comment="B disliked this",
            )
        with _as_user(USER_A):
            assert (await feedback_store.get(a_feedback["feedback_id"])) is not None
            # NOTE(review): get() by feedback_id is asserted to succeed across
            # the user boundary — presumably id lookups are deliberately
            # unscoped while list_by_run is filtered; confirm against the
            # adapter's contract.
            assert await feedback_store.get(b_feedback["feedback_id"]) is not None
            assert await feedback_store.list_by_run("t-beta", "run-b1", user_id=USER_A) == []
        with _as_user(USER_B):
            assert await feedback_store.list_by_run("t-alpha", "run-a1", user_id=USER_B) == []
            rows = await feedback_store.list_by_run("t-beta", "run-b1", user_id=USER_B)
            assert len(rows) == 1
            assert rows[0]["comment"] == "B disliked this"
    finally:
        await engine.dispose()
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_repository_without_context_raises(tmp_path):
    """AUTO user resolution fails loudly when no actor context is bound."""
    engine, thread_store, _, _, _ = await _make_components(tmp_path)
    try:
        # No _as_user() here — the repository must refuse to guess an owner.
        with pytest.raises(RuntimeError, match="no actor context is set"):
            await thread_store.search_threads(user_id=AUTO)
    finally:
        await engine.dispose()
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_explicit_none_bypasses_filter(tmp_path):
    """Passing user_id=None explicitly disables per-user scoping."""
    engine, thread_store, _, _, _ = await _make_components(tmp_path)
    try:
        # Seed one thread per user under that user's context.
        for owner, thread_id in ((USER_A, "t-alpha"), (USER_B, "t-beta")):
            with _as_user(owner):
                await thread_store.ensure_thread(thread_id=thread_id)
        # With the filter opted out, both threads are visible at once.
        all_rows = await thread_store.search_threads(user_id=None)
        assert {row.thread_id for row in all_rows} == {"t-alpha", "t-beta"}
        for thread_id in ("t-alpha", "t-beta"):
            assert await thread_store.get_thread(thread_id, user_id=None) is not None
    finally:
        await engine.dispose()

View File

@ -0,0 +1,66 @@
from __future__ import annotations
from langchain_core.messages import HumanMessage
from deerflow.runtime.runs.callbacks.builder import build_run_callbacks
from deerflow.runtime.runs.types import RunRecord, RunStatus
def _record() -> RunRecord:
    """Build a minimal pending RunRecord fixture for the callback-builder tests."""
    fields = {
        "run_id": "run-1",
        "thread_id": "thread-1",
        "assistant_id": None,
        "status": RunStatus.pending,
        "temporary": False,
        "multitask_strategy": "reject",
        "metadata": {},
        "created_at": "",
        "updated_at": "",
    }
    return RunRecord(**fields)
def test_build_run_callbacks_sets_first_human_message_from_string_content():
    """A plain string HumanMessage is captured verbatim as the first human message."""
    graph_input = {"messages": [HumanMessage(content="hello world")]}
    artifacts = build_run_callbacks(
        record=_record(),
        graph_input=graph_input,
        event_store=None,
    )
    assert artifacts.completion_data().first_human_message == "hello world"
def test_build_run_callbacks_sets_first_human_message_from_content_blocks():
    """Text blocks in a HumanMessage content list are joined in order."""
    content_blocks = [
        {"type": "text", "text": "hello "},
        {"type": "text", "text": "world"},
    ]
    artifacts = build_run_callbacks(
        record=_record(),
        graph_input={"messages": [HumanMessage(content=content_blocks)]},
        event_store=None,
    )
    assert artifacts.completion_data().first_human_message == "hello world"
def test_build_run_callbacks_sets_first_human_message_from_dict_payload():
    """A raw role/content dict (not a HumanMessage instance) is also recognized."""
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": "hi from dict"}],
    }
    artifacts = build_run_callbacks(
        record=_record(),
        graph_input={"messages": [user_message]},
        event_store=None,
    )
    assert artifacts.completion_data().first_human_message == "hi from dict"

Some files were not shown because too many files have changed in this diff Show More