mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
refactor(tests): reorganize tests into unittest/ and e2e/ directories
- Move all unit tests from tests/ to tests/unittest/
- Add tests/e2e/ directory for end-to-end tests
- Update conftest.py for new test structure
- Add new tests for auth dependencies, policies, route injection
- Add new tests for run callbacks, create store, execution artifacts
- Remove obsolete tests for deleted persistence layer

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
38a6ec496f
commit
2fe0856e33
@ -1,4 +1,4 @@
|
||||
"""Test configuration for the backend test suite.
|
||||
"""Test configuration shared by unit and end-to-end tests.
|
||||
|
||||
Sets up sys.path and pre-mocks modules that would cause circular import
|
||||
issues when unit-testing lightweight config/registry code in isolation.
|
||||
@ -6,8 +6,8 @@ issues when unit-testing lightweight config/registry code in isolation.
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -38,14 +38,51 @@ _executor_mock.get_background_task_result = MagicMock()
|
||||
sys.modules["deerflow.subagents.executor"] = _executor_mock
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Register the suite's custom marker and silence known third-party warnings.

    Adds the ``no_auto_user`` marker (used to opt out of the autouse user
    context fixture) and installs ignore-filters for deprecation noise from
    dependencies the test suite cannot fix.
    """
    config.addinivalue_line(
        "markers",
        "no_auto_user: disable the conftest autouse contextvar fixture for this test",
    )

    # (message pattern, warning category) pairs for third-party warnings that
    # would otherwise clutter or fail the run; applied in order.
    ignored_warnings = (
        (r"Pydantic serializer warnings:.*field_name='context'.*", UserWarning),
        (r"pkg_resources is deprecated as an API.*", UserWarning),
        (r"Deprecated call to `pkg_resources\.declare_namespace\(.*", DeprecationWarning),
        (r"datetime\.datetime\.utcfromtimestamp\(\) is deprecated.*", DeprecationWarning),
        (r"websockets\.InvalidStatusCode is deprecated", DeprecationWarning),
        (r"websockets\.legacy is deprecated.*", DeprecationWarning),
        (r"websockets\.client\.WebSocketClientProtocol is deprecated", DeprecationWarning),
    )
    for pattern, category in ignored_warnings:
        warnings.filterwarnings("ignore", message=pattern, category=category)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def provisioner_module():
|
||||
"""Load docker/provisioner/app.py as an importable test module.
|
||||
|
||||
Shared by test_provisioner_kubeconfig and test_provisioner_pvc_volumes so
|
||||
that any change to the provisioner entry-point path or module name only
|
||||
needs to be updated in one place.
|
||||
"""
|
||||
"""Load docker/provisioner/app.py as an importable test module."""
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
module_path = repo_root / "docker" / "provisioner" / "app.py"
|
||||
spec = importlib.util.spec_from_file_location("provisioner_app_test", module_path)
|
||||
@ -56,42 +93,25 @@ def provisioner_module():
|
||||
return module
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-set user context for every test unless marked no_auto_user
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Repository methods read ``user_id`` from a contextvar by default
|
||||
# (see ``deerflow.runtime.user_context``). Without this fixture, every
|
||||
# pre-existing persistence test would raise RuntimeError because the
|
||||
# contextvar is unset. The fixture sets a default test user on every
|
||||
# test; tests that explicitly want to verify behaviour *without* a user
|
||||
# context should mark themselves ``@pytest.mark.no_auto_user``.
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _auto_user_context(request):
    """Inject a default ``test-user-autouse`` into the contextvar.

    Tests marked ``@pytest.mark.no_auto_user`` skip the injection entirely.
    The actor-context import is done lazily so suites that never touch the
    runtime layer do not pay its import cost (or fail when it is absent).
    """
    # Opt-out path: marked tests run with no actor context bound.
    if request.node.get_closest_marker("no_auto_user") is not None:
        yield
        return

    try:
        from deerflow.runtime.actor_context import (
            ActorContext,
            bind_actor_context,
            reset_actor_context,
        )
    except ImportError:
        # Runtime package unavailable in this environment — nothing to bind.
        yield
        return

    token = bind_actor_context(ActorContext(user_id="test-user-autouse"))
    try:
        yield
    finally:
        # Always unbind, even when the test body raises.
        reset_actor_context(token)
|
||||
|
||||
33
backend/tests/e2e/conftest.py
Normal file
33
backend/tests/e2e/conftest.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""Shared fixtures for end-to-end API tests."""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.plugins.auth.api.schemas import _login_attempts
|
||||
from app.plugins.auth.domain.config import AuthConfig
|
||||
from app.plugins.auth.runtime.config_state import reset_auth_config, set_auth_config
|
||||
from store.config.app_config import AppConfig, reset_app_config, set_app_config
|
||||
from store.config.storage_config import StorageConfig
|
||||
|
||||
_TEST_SECRET = "test-secret-key-e2e-auth-minimum-32"
|
||||
|
||||
|
||||
@pytest.fixture()
def client(tmp_path):
    """Create a full app client backed by an isolated SQLite directory.

    Resets the login-attempt tracker and the global auth/app config before
    building the app, then yields a ``TestClient`` bound to a fresh SQLite
    directory under ``tmp_path``.

    Fix: the teardown cleanup now runs in a ``finally`` block. In the original,
    the trailing cleanup lines after ``yield`` were skipped whenever the test
    body raised (pytest throws the exception into the generator at the yield
    point), leaking global config state into subsequent tests.
    """
    from app.gateway.app import create_app

    _login_attempts.clear()
    reset_auth_config()
    reset_app_config()
    set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
    set_app_config(AppConfig(storage=StorageConfig(driver="sqlite", sqlite_dir=str(tmp_path))))

    app = create_app()

    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # Restore pristine global state even if the test (or client shutdown)
        # raised, so one failing test cannot poison the rest of the run.
        _login_attempts.clear()
        reset_auth_config()
        reset_app_config()
|
||||
163
backend/tests/e2e/test_auth_initialize_me.py
Normal file
163
backend/tests/e2e/test_auth_initialize_me.py
Normal file
@ -0,0 +1,163 @@
|
||||
"""End-to-end auth API tests for the main auth user journeys."""
|
||||
|
||||
from app.plugins.auth.security.csrf import CSRF_HEADER_NAME
|
||||
|
||||
|
||||
def _initialize_payload(**overrides):
|
||||
return {
|
||||
"email": "admin@example.com",
|
||||
"password": "Str0ng!Pass99",
|
||||
**overrides,
|
||||
}
|
||||
|
||||
|
||||
def _register_payload(**overrides):
|
||||
return {
|
||||
"email": "user@example.com",
|
||||
"password": "Str0ng!Pass99",
|
||||
**overrides,
|
||||
}
|
||||
|
||||
|
||||
def _login(client, *, email="user@example.com", password="Str0ng!Pass99"):
|
||||
return client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": email, "password": password},
|
||||
)
|
||||
|
||||
|
||||
def _csrf_headers(client) -> dict[str, str]:
    """Build the CSRF request header from the client's ``csrf_token`` cookie."""
    cookie_value = client.cookies.get("csrf_token")
    assert cookie_value, "csrf_token cookie is required before calling protected POST endpoints"
    return {CSRF_HEADER_NAME: cookie_value}
|
||||
|
||||
|
||||
def test_initialize_returns_admin_and_sets_session_cookie(client):
    """First-boot initialize creates the admin account and logs it straight in."""
    response = client.post("/api/v1/auth/initialize", json=_initialize_payload())

    assert response.status_code == 201
    body = response.json()
    assert body["email"] == "admin@example.com"
    assert body["system_role"] == "admin"
    # Session cookie is set both on the response and persisted in the client jar.
    assert "access_token" in response.cookies
    assert "access_token" in client.cookies
|
||||
|
||||
|
||||
def test_me_returns_initialized_admin_identity(client):
    """After initialize, GET /me returns exactly the admin identity fields.

    Fix: the original called ``response.json()`` three times (once
    self-referentially inside the expected dict). The body is now parsed once;
    the exact-dict equality is kept so the test still fails if the endpoint
    ever adds or drops a field.
    """
    initialize = client.post("/api/v1/auth/initialize", json=_initialize_payload())
    assert initialize.status_code == 201

    response = client.get("/api/v1/auth/me")

    assert response.status_code == 200
    body = response.json()
    # "id" is server-generated, so it is echoed back into the expected dict;
    # the equality still pins the full key set.
    assert body == {
        "id": body["id"],
        "email": "admin@example.com",
        "system_role": "admin",
        "needs_setup": False,
    }
|
||||
|
||||
|
||||
def test_setup_status_flips_after_initialize(client):
    """/setup-status reports needs_setup=True only until the admin exists."""
    status_before = client.get("/api/v1/auth/setup-status")
    assert status_before.status_code == 200
    assert status_before.json() == {"needs_setup": True}

    initialize = client.post("/api/v1/auth/initialize", json=_initialize_payload())
    assert initialize.status_code == 201

    status_after = client.get("/api/v1/auth/setup-status")
    assert status_after.status_code == 200
    assert status_after.json() == {"needs_setup": False}
|
||||
|
||||
|
||||
def test_register_logs_in_user_and_me_returns_identity(client):
    """Registration creates a regular user and immediately logs them in."""
    response = client.post("/api/v1/auth/register", json=_register_payload())

    assert response.status_code == 201
    registered = response.json()
    assert registered["email"] == "user@example.com"
    assert registered["system_role"] == "user"
    # Both session and CSRF cookies are established by registration.
    assert "access_token" in client.cookies
    assert "csrf_token" in client.cookies

    me = client.get("/api/v1/auth/me")
    assert me.status_code == 200
    identity = me.json()
    assert identity["email"] == "user@example.com"
    assert identity["system_role"] == "user"
    assert identity["needs_setup"] is False
|
||||
|
||||
|
||||
def test_me_requires_authentication(client):
    """GET /me without a session is rejected with a 401 not_authenticated error."""
    response = client.get("/api/v1/auth/me")

    assert response.status_code == 401
    detail = response.json()["detail"]
    assert detail["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_logout_clears_session_and_me_is_denied(client):
    """Logout ends the session; a subsequent /me call returns 401."""
    register = client.post("/api/v1/auth/register", json=_register_payload())
    assert register.status_code == 201

    logout = client.post("/api/v1/auth/logout")
    assert logout.status_code == 200
    assert logout.json() == {"message": "Successfully logged out"}

    denied = client.get("/api/v1/auth/me")
    assert denied.status_code == 401
    assert denied.json()["detail"]["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_login_local_restores_session_after_logout(client):
    """Local login after logout re-establishes the session and CSRF cookies."""
    register = client.post("/api/v1/auth/register", json=_register_payload())
    assert register.status_code == 201
    assert client.post("/api/v1/auth/logout").status_code == 200

    login = _login(client)
    assert login.status_code == 200
    login_body = login.json()
    assert login_body["needs_setup"] is False
    assert "access_token" in client.cookies
    assert "csrf_token" in client.cookies

    me = client.get("/api/v1/auth/me")
    assert me.status_code == 200
    assert me.json()["email"] == "user@example.com"
|
||||
|
||||
|
||||
def test_change_password_updates_credentials_and_rotates_login(client):
    """Changing password + email invalidates the old login and enables the new."""
    register = client.post("/api/v1/auth/register", json=_register_payload())
    assert register.status_code == 201

    # Protected POST endpoints require the CSRF header derived from the cookie.
    change_request = {
        "current_password": "Str0ng!Pass99",
        "new_password": "An0ther!Pass88",
        "new_email": "renamed@example.com",
    }
    change = client.post(
        "/api/v1/auth/change-password",
        json=change_request,
        headers=_csrf_headers(client),
    )
    assert change.status_code == 200
    assert change.json() == {"message": "Password changed successfully"}

    assert client.post("/api/v1/auth/logout").status_code == 200

    # Old credentials must no longer work.
    stale_login = _login(client)
    assert stale_login.status_code == 401
    assert stale_login.json()["detail"]["code"] == "invalid_credentials"

    # New email + password establish a fresh session.
    fresh_login = _login(client, email="renamed@example.com", password="An0ther!Pass88")
    assert fresh_login.status_code == 200

    me = client.get("/api/v1/auth/me")
    assert me.status_code == 200
    assert me.json()["email"] == "renamed@example.com"
|
||||
|
||||
|
||||
def test_oauth_endpoints_expose_current_placeholder_behavior(client):
    """OAuth routes: unknown provider -> 400; known provider -> 501 placeholders."""
    unknown_provider = client.get("/api/v1/auth/oauth/not-a-provider")
    assert unknown_provider.status_code == 400

    github_start = client.get("/api/v1/auth/oauth/github")
    assert github_start.status_code == 501

    github_callback = client.get(
        "/api/v1/auth/callback/github", params={"code": "abc", "state": "xyz"}
    )
    assert github_callback.status_code == 501
|
||||
@ -1,304 +0,0 @@
|
||||
"""Unit tests for checkpointer config and singleton factory."""
|
||||
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import deerflow.config.app_config as app_config_module
|
||||
from deerflow.config.checkpointer_config import (
|
||||
CheckpointerConfig,
|
||||
get_checkpointer_config,
|
||||
load_checkpointer_config_from_dict,
|
||||
set_checkpointer_config,
|
||||
)
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset singleton state before each test."""
|
||||
app_config_module._app_config = None
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
yield
|
||||
app_config_module._app_config = None
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckpointerConfig:
|
||||
def test_load_memory_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
assert config.type == "memory"
|
||||
assert config.connection_string is None
|
||||
|
||||
def test_load_sqlite_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
assert config.type == "sqlite"
|
||||
assert config.connection_string == "/tmp/test.db"
|
||||
|
||||
def test_load_postgres_config(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
config = get_checkpointer_config()
|
||||
assert config is not None
|
||||
assert config.type == "postgres"
|
||||
assert config.connection_string == "postgresql://localhost/db"
|
||||
|
||||
def test_default_connection_string_is_none(self):
|
||||
config = CheckpointerConfig(type="memory")
|
||||
assert config.connection_string is None
|
||||
|
||||
def test_set_config_to_none(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
set_checkpointer_config(None)
|
||||
assert get_checkpointer_config() is None
|
||||
|
||||
def test_invalid_type_raises(self):
|
||||
with pytest.raises(Exception):
|
||||
load_checkpointer_config_from_dict({"type": "unknown"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetCheckpointer:
|
||||
def test_returns_in_memory_saver_when_not_configured(self):
|
||||
"""get_checkpointer should return InMemorySaver when not configured."""
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||
cp = get_checkpointer()
|
||||
assert cp is not None
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
|
||||
def test_memory_returns_in_memory_saver(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
cp = get_checkpointer()
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
|
||||
def test_memory_singleton(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cp1 = get_checkpointer()
|
||||
cp2 = get_checkpointer()
|
||||
assert cp1 is cp2
|
||||
|
||||
def test_reset_clears_singleton(self):
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cp1 = get_checkpointer()
|
||||
reset_checkpointer()
|
||||
cp2 = get_checkpointer()
|
||||
assert cp1 is not cp2
|
||||
|
||||
def test_sqlite_raises_when_package_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
|
||||
get_checkpointer()
|
||||
|
||||
def test_postgres_raises_when_package_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
|
||||
get_checkpointer()
|
||||
|
||||
def test_postgres_raises_when_connection_string_missing(self):
|
||||
load_checkpointer_config_from_dict({"type": "postgres"})
|
||||
mock_saver = MagicMock()
|
||||
mock_module = MagicMock()
|
||||
mock_module.PostgresSaver = mock_saver
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}):
|
||||
reset_checkpointer()
|
||||
with pytest.raises(ValueError, match="connection_string is required"):
|
||||
get_checkpointer()
|
||||
|
||||
def test_sqlite_creates_saver(self):
|
||||
"""SQLite checkpointer is created when package is available."""
|
||||
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
|
||||
|
||||
mock_saver_instance = MagicMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
|
||||
mock_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_saver_cls = MagicMock()
|
||||
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
|
||||
|
||||
mock_module = MagicMock()
|
||||
mock_module.SqliteSaver = mock_saver_cls
|
||||
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}):
|
||||
reset_checkpointer()
|
||||
cp = get_checkpointer()
|
||||
|
||||
assert cp is mock_saver_instance
|
||||
mock_saver_cls.from_conn_string.assert_called_once()
|
||||
mock_saver_instance.setup.assert_called_once()
|
||||
|
||||
def test_postgres_creates_saver(self):
|
||||
"""Postgres checkpointer is created when packages are available."""
|
||||
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
|
||||
|
||||
mock_saver_instance = MagicMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
|
||||
mock_cm.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_saver_cls = MagicMock()
|
||||
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
|
||||
|
||||
mock_pg_module = MagicMock()
|
||||
mock_pg_module.PostgresSaver = mock_saver_cls
|
||||
|
||||
with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}):
|
||||
reset_checkpointer()
|
||||
cp = get_checkpointer()
|
||||
|
||||
assert cp is mock_saver_instance
|
||||
mock_saver_cls.from_conn_string.assert_called_once_with("postgresql://localhost/db")
|
||||
mock_saver_instance.setup.assert_called_once()
|
||||
|
||||
|
||||
class TestAsyncCheckpointer:
|
||||
@pytest.mark.anyio
|
||||
async def test_sqlite_creates_parent_dir_via_to_thread(self):
|
||||
"""Async SQLite setup should move mkdir off the event loop."""
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db")
|
||||
|
||||
mock_saver = AsyncMock()
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__.return_value = mock_saver
|
||||
mock_cm.__aexit__.return_value = False
|
||||
|
||||
mock_saver_cls = MagicMock()
|
||||
mock_saver_cls.from_conn_string.return_value = mock_cm
|
||||
|
||||
mock_module = MagicMock()
|
||||
mock_module.AsyncSqliteSaver = mock_saver_cls
|
||||
|
||||
with (
|
||||
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
||||
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
|
||||
patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
||||
patch(
|
||||
"deerflow.runtime.checkpointer.async_provider.resolve_sqlite_conn_str",
|
||||
return_value="/tmp/resolved/test.db",
|
||||
),
|
||||
):
|
||||
async with make_checkpointer() as saver:
|
||||
assert saver is mock_saver
|
||||
|
||||
mock_to_thread.assert_awaited_once()
|
||||
called_fn, called_path = mock_to_thread.await_args.args
|
||||
assert called_fn.__name__ == "ensure_sqlite_parent_dir"
|
||||
assert called_path == "/tmp/resolved/test.db"
|
||||
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db")
|
||||
mock_saver.setup.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# app_config.py integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAppConfigLoadsCheckpointer:
|
||||
def test_load_checkpointer_section(self):
|
||||
"""load_checkpointer_config_from_dict populates the global config."""
|
||||
set_checkpointer_config(None)
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
cfg = get_checkpointer_config()
|
||||
assert cfg is not None
|
||||
assert cfg.type == "memory"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeerFlowClient falls back to config checkpointer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClientCheckpointerFallback:
|
||||
def test_client_uses_config_checkpointer_when_none_provided(self):
|
||||
"""DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_create_agent(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return MagicMock()
|
||||
|
||||
model_mock = MagicMock()
|
||||
config_mock = MagicMock()
|
||||
config_mock.models = [model_mock]
|
||||
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
|
||||
config_mock.checkpointer = None
|
||||
|
||||
with (
|
||||
patch("deerflow.client.get_app_config", return_value=config_mock),
|
||||
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
|
||||
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value=""),
|
||||
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
|
||||
):
|
||||
client = DeerFlowClient(checkpointer=None)
|
||||
config = client._get_runnable_config("test-thread")
|
||||
client._ensure_agent(config)
|
||||
|
||||
assert "checkpointer" in captured_kwargs
|
||||
assert isinstance(captured_kwargs["checkpointer"], InMemorySaver)
|
||||
|
||||
def test_client_explicit_checkpointer_takes_precedence(self):
|
||||
"""An explicitly provided checkpointer is used even when config checkpointer is set."""
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
load_checkpointer_config_from_dict({"type": "memory"})
|
||||
|
||||
explicit_cp = MagicMock()
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_create_agent(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return MagicMock()
|
||||
|
||||
model_mock = MagicMock()
|
||||
config_mock = MagicMock()
|
||||
config_mock.models = [model_mock]
|
||||
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
|
||||
config_mock.checkpointer = None
|
||||
|
||||
with (
|
||||
patch("deerflow.client.get_app_config", return_value=config_mock),
|
||||
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
|
||||
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value=""),
|
||||
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
|
||||
):
|
||||
client = DeerFlowClient(checkpointer=explicit_cp)
|
||||
config = client._get_runnable_config("test-thread")
|
||||
client._ensure_agent(config)
|
||||
|
||||
assert captured_kwargs["checkpointer"] is explicit_cp
|
||||
@ -1,55 +0,0 @@
|
||||
"""Test for issue #1016: checkpointer should not return None."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
|
||||
class TestCheckpointerNoneFix:
|
||||
"""Tests that checkpointer context managers return InMemorySaver instead of None."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_make_checkpointer_returns_in_memory_saver_when_not_configured(self):
|
||||
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
# Mock get_app_config to return a config with checkpointer=None and database=None
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = None
|
||||
mock_config.database = None
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
||||
async with make_checkpointer() as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
assert isinstance(checkpointer, InMemorySaver)
|
||||
|
||||
# Should be able to call alist() without AttributeError
|
||||
# This is what LangGraph does and what was failing in issue #1016
|
||||
result = []
|
||||
async for item in checkpointer.alist(config={"configurable": {"thread_id": "test"}}):
|
||||
result.append(item)
|
||||
|
||||
# Empty list is expected for a fresh checkpointer
|
||||
assert result == []
|
||||
|
||||
def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self):
|
||||
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
|
||||
from deerflow.runtime.checkpointer.provider import checkpointer_context
|
||||
|
||||
# Mock get_app_config to return a config with checkpointer=None
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = None
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", return_value=mock_config):
|
||||
with checkpointer_context() as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
assert isinstance(checkpointer, InMemorySaver)
|
||||
|
||||
# Should be able to call list() without AttributeError
|
||||
result = list(checkpointer.list(config={"configurable": {"thread_id": "test"}}))
|
||||
|
||||
# Empty list is expected for a fresh checkpointer
|
||||
assert result == []
|
||||
@ -1,296 +0,0 @@
|
||||
"""Tests for _ensure_admin_user() in app.py.
|
||||
|
||||
Covers: first-boot no-op (admin creation removed), orphan migration
|
||||
when admin exists, no-op on no admin found, and edge cases.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
|
||||
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_auth_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||
yield
|
||||
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||
|
||||
|
||||
def _make_app_stub(store=None):
|
||||
"""Minimal app-like object with state.store."""
|
||||
app = SimpleNamespace()
|
||||
app.state = SimpleNamespace()
|
||||
app.state.store = store
|
||||
return app
|
||||
|
||||
|
||||
def _make_provider(admin_count=0):
|
||||
p = AsyncMock()
|
||||
p.count_users = AsyncMock(return_value=admin_count)
|
||||
p.count_admin_users = AsyncMock(return_value=admin_count)
|
||||
p.create_user = AsyncMock()
|
||||
p.update_user = AsyncMock(side_effect=lambda u: u)
|
||||
return p
|
||||
|
||||
|
||||
def _make_session_factory(admin_row=None):
|
||||
"""Build a mock async session factory that returns a row from execute()."""
|
||||
row_result = MagicMock()
|
||||
row_result.scalar_one_or_none.return_value = admin_row
|
||||
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = admin_row
|
||||
|
||||
session = AsyncMock()
|
||||
session.execute = AsyncMock(return_value=execute_result)
|
||||
|
||||
# Async context manager
|
||||
session_cm = AsyncMock()
|
||||
session_cm.__aenter__ = AsyncMock(return_value=session)
|
||||
session_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
sf = MagicMock()
|
||||
sf.return_value = session_cm
|
||||
return sf
|
||||
|
||||
|
||||
# ── First boot: no admin → return early ──────────────────────────────────
|
||||
|
||||
|
||||
def test_first_boot_does_not_create_admin():
|
||||
"""admin_count==0 → do NOT create admin automatically."""
|
||||
provider = _make_provider(admin_count=0)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.create_user.assert_not_called()
|
||||
|
||||
|
||||
def test_first_boot_skips_migration():
|
||||
"""No admin → return early before any migration attempt."""
|
||||
provider = _make_provider(admin_count=0)
|
||||
store = AsyncMock()
|
||||
store.asearch = AsyncMock(return_value=[])
|
||||
app = _make_app_stub(store=store)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
store.asearch.assert_not_called()
|
||||
|
||||
|
||||
# ── Admin exists: migration runs when admin row found ────────────────────
|
||||
|
||||
|
||||
def test_admin_exists_triggers_migration():
    """Admin present and its DB row resolvable → orphan migration runs once."""
    from uuid import uuid4

    admin_record = MagicMock()
    admin_record.id = uuid4()

    fake_provider = _make_provider(admin_count=1)
    session_factory = _make_session_factory(admin_row=admin_record)
    thread_store = AsyncMock()
    thread_store.asearch = AsyncMock(return_value=[])
    stub_app = _make_app_stub(store=thread_store)

    with patch("app.gateway.deps.get_local_provider", return_value=fake_provider), \
            patch("deerflow.persistence.engine.get_session_factory", return_value=session_factory):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(stub_app))

    # Migration scanned the Store exactly once for orphaned threads.
    thread_store.asearch.assert_called_once()
|
||||
|
||||
|
||||
def test_admin_exists_no_admin_row_skips_migration():
    """Admin count > 0 but the DB row is missing → skip migration gracefully."""
    fake_provider = _make_provider(admin_count=2)
    session_factory = _make_session_factory(admin_row=None)
    thread_store = AsyncMock()
    stub_app = _make_app_stub(store=thread_store)

    with patch("app.gateway.deps.get_local_provider", return_value=fake_provider), \
            patch("deerflow.persistence.engine.get_session_factory", return_value=session_factory):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(stub_app))

    # Without a concrete admin row there is no owner to stamp, so no scan.
    thread_store.asearch.assert_not_called()
|
||||
|
||||
|
||||
def test_admin_exists_no_store_skips_migration():
    """Admin exists and row resolves, but the app has no store → no crash."""
    from uuid import uuid4

    admin_record = MagicMock()
    admin_record.id = uuid4()

    fake_provider = _make_provider(admin_count=1)
    session_factory = _make_session_factory(admin_row=admin_record)
    stub_app = _make_app_stub(store=None)

    with patch("app.gateway.deps.get_local_provider", return_value=fake_provider), \
            patch("deerflow.persistence.engine.get_session_factory", return_value=session_factory):
        from app.gateway.app import _ensure_admin_user

        # The success criterion is simply the absence of an exception.
        asyncio.run(_ensure_admin_user(stub_app))
|
||||
|
||||
|
||||
def test_admin_exists_session_factory_none_skips_migration():
    """get_session_factory() returning None aborts migration without raising."""
    fake_provider = _make_provider(admin_count=1)
    thread_store = AsyncMock()
    stub_app = _make_app_stub(store=thread_store)

    with patch("app.gateway.deps.get_local_provider", return_value=fake_provider), \
            patch("deerflow.persistence.engine.get_session_factory", return_value=None):
        from app.gateway.app import _ensure_admin_user

        asyncio.run(_ensure_admin_user(stub_app))

    # No session factory → no admin-row lookup → no store scan.
    thread_store.asearch.assert_not_called()
|
||||
|
||||
|
||||
def test_migration_failure_is_non_fatal():
    """A crash inside _migrate_orphaned_threads is caught, not propagated."""
    from uuid import uuid4

    admin_record = MagicMock()
    admin_record.id = uuid4()

    fake_provider = _make_provider(admin_count=1)
    session_factory = _make_session_factory(admin_row=admin_record)
    failing_store = AsyncMock()
    failing_store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
    stub_app = _make_app_stub(store=failing_store)

    with patch("app.gateway.deps.get_local_provider", return_value=fake_provider), \
            patch("deerflow.persistence.engine.get_session_factory", return_value=session_factory):
        from app.gateway.app import _ensure_admin_user

        # Must complete without re-raising the RuntimeError from the store.
        asyncio.run(_ensure_admin_user(stub_app))
|
||||
|
||||
|
||||
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
|
||||
|
||||
|
||||
def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
    """First boot finds Store-only legacy threads → stamps admin's id.

    Validates the **TC-UPG-02 upgrade story**: an operator running main
    (no auth) accumulates threads in the LangGraph Store namespace
    ``("threads",)`` with no ``metadata.user_id``. After upgrading to
    feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
    rewrite each unowned item with the freshly created admin's id.
    """
    from app.gateway.app import _migrate_orphaned_threads

    # Three orphan items + one already-owned item that should be left alone.
    items = [
        SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
        SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
        SimpleNamespace(key="t3", value={"metadata": {}}),
        SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}}),
    ]
    store = AsyncMock()
    # asearch returns the entire batch on first call, then an empty page
    # to terminate _iter_store_items.
    store.asearch = AsyncMock(side_effect=[items, []])
    aput_calls: list[tuple[tuple, str, dict]] = []

    # Side-effect recorder so we can inspect every (namespace, key, value)
    # that the migration tried to write back.
    async def _record_aput(namespace, key, value):
        aput_calls.append((namespace, key, value))

    store.aput = AsyncMock(side_effect=_record_aput)

    migrated = asyncio.run(_migrate_orphaned_threads(store, "admin-id-42"))

    # Three orphan rows migrated, one preserved.
    assert migrated == 3
    assert len(aput_calls) == 3
    rewritten_keys = {call[1] for call in aput_calls}
    assert rewritten_keys == {"t1", "t2", "t3"}
    # Each rewrite carries the new user_id; titles preserved where present.
    by_key = {call[1]: call[2] for call in aput_calls}
    assert by_key["t1"]["metadata"]["user_id"] == "admin-id-42"
    assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
    assert by_key["t3"]["metadata"]["user_id"] == "admin-id-42"
    # The pre-owned item must NOT have been rewritten.
    assert "t4" not in rewritten_keys
|
||||
|
||||
|
||||
def test_migrate_orphaned_threads_empty_store_is_noop():
    """A store with no threads at all → zero migrated and no writes."""
    from app.gateway.app import _migrate_orphaned_threads

    empty_store = AsyncMock()
    empty_store.asearch = AsyncMock(return_value=[])
    empty_store.aput = AsyncMock()

    count = asyncio.run(_migrate_orphaned_threads(empty_store, "admin-id-42"))

    assert count == 0
    empty_store.aput.assert_not_called()
|
||||
|
||||
|
||||
def test_iter_store_items_walks_multiple_pages():
    """Cursor-style iterator pulls every page until a short page terminates.

    Closes the regression where the old hardcoded ``limit=1000`` could
    silently drop orphans on a large pre-upgrade dataset. The migration
    code path uses the default ``page_size=500``; this test pins the
    iterator with ``page_size=2`` so it stays fast.
    """
    from app.gateway.app import _iter_store_items

    first_page = [
        SimpleNamespace(key="t0", value={"metadata": {}}),
        SimpleNamespace(key="t1", value={"metadata": {}}),
    ]
    second_page = [
        SimpleNamespace(key="t2", value={"metadata": {}}),
        SimpleNamespace(key="t3", value={"metadata": {}}),
    ]
    terminator: list = []  # empty page → loop terminates

    store = AsyncMock()
    store.asearch = AsyncMock(side_effect=[first_page, second_page, terminator])

    async def _drain():
        iterator = _iter_store_items(store, ("threads",), page_size=2)
        return [item.key async for item in iterator]

    assert asyncio.run(_drain()) == ["t0", "t1", "t2", "t3"]
    # Two full pages plus the empty terminator probe.
    assert store.asearch.await_count == 3
|
||||
|
||||
|
||||
def test_iter_store_items_terminates_on_short_page():
    """A page shorter than page_size ends iteration without an extra probe."""
    from app.gateway.app import _iter_store_items

    short_page = [SimpleNamespace(key=f"t{i}", value={}) for i in range(3)]
    store = AsyncMock()
    store.asearch = AsyncMock(return_value=short_page)

    async def _drain():
        iterator = _iter_store_items(store, ("threads",), page_size=10)
        return [item.key async for item in iterator]

    assert asyncio.run(_drain()) == ["t0", "t1", "t2"]
    # len(batch) < page_size already proves exhaustion → a single asearch call.
    assert store.asearch.await_count == 1
|
||||
@ -1,289 +0,0 @@
|
||||
"""Tests for FeedbackRepository and follow-up association.
|
||||
|
||||
Uses temp SQLite DB for ORM tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.persistence.feedback import FeedbackRepository
|
||||
|
||||
|
||||
async def _make_feedback_repo(tmp_path):
    """Initialize a temp-SQLite engine and return a FeedbackRepository on it."""
    from deerflow.persistence.engine import get_session_factory, init_engine

    db_url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
    await init_engine("sqlite", url=db_url, sqlite_dir=str(tmp_path))
    return FeedbackRepository(get_session_factory())
|
||||
|
||||
|
||||
async def _cleanup():
    """Dispose of the engine opened by ``_make_feedback_repo``.

    Must run after every test so the next test can re-initialize the
    module-global engine cleanly.
    """
    from deerflow.persistence.engine import close_engine

    await close_engine()
|
||||
|
||||
|
||||
# -- FeedbackRepository --
|
||||
|
||||
|
||||
class TestFeedbackRepository:
    """CRUD, aggregation and upsert behavior of FeedbackRepository.

    Each test provisions its own temp-SQLite engine via
    ``_make_feedback_repo`` and disposes it in a ``finally`` block, so a
    failing assertion (or a raising repo call) can no longer skip
    ``_cleanup()`` and leak the open engine into the next test — the
    original tests called ``await _cleanup()`` as a plain last statement.
    """

    @pytest.mark.anyio
    async def test_create_positive(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            record = await repo.create(run_id="r1", thread_id="t1", rating=1)
            assert record["feedback_id"]
            assert record["rating"] == 1
            assert record["run_id"] == "r1"
            assert record["thread_id"] == "t1"
            assert "created_at" in record
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_create_negative_with_comment(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            record = await repo.create(
                run_id="r1",
                thread_id="t1",
                rating=-1,
                comment="Response was inaccurate",
            )
            assert record["rating"] == -1
            assert record["comment"] == "Response was inaccurate"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_create_with_message_id(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            record = await repo.create(run_id="r1", thread_id="t1", rating=1, message_id="msg-42")
            assert record["message_id"] == "msg-42"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_create_with_owner(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
            assert record["user_id"] == "user-1"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_create_invalid_rating_zero(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            # Only +1 / -1 are valid ratings.
            with pytest.raises(ValueError):
                await repo.create(run_id="r1", thread_id="t1", rating=0)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_create_invalid_rating_five(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            with pytest.raises(ValueError):
                await repo.create(run_id="r1", thread_id="t1", rating=5)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_get(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            created = await repo.create(run_id="r1", thread_id="t1", rating=1)
            fetched = await repo.get(created["feedback_id"])
            assert fetched is not None
            assert fetched["feedback_id"] == created["feedback_id"]
            assert fetched["rating"] == 1
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_get_nonexistent(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            assert await repo.get("nonexistent") is None
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_run(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
            await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2")
            await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1")
            results = await repo.list_by_run("t1", "r1", user_id=None)
            assert len(results) == 2
            assert all(r["run_id"] == "r1" for r in results)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            await repo.create(run_id="r1", thread_id="t1", rating=1)
            await repo.create(run_id="r2", thread_id="t1", rating=-1)
            await repo.create(run_id="r3", thread_id="t2", rating=1)
            results = await repo.list_by_thread("t1")
            assert len(results) == 2
            assert all(r["thread_id"] == "t1" for r in results)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            created = await repo.create(run_id="r1", thread_id="t1", rating=1)
            deleted = await repo.delete(created["feedback_id"])
            assert deleted is True
            assert await repo.get(created["feedback_id"]) is None
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete_nonexistent(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            assert await repo.delete("nonexistent") is False
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_aggregate_by_run(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
            await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2")
            await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3")
            stats = await repo.aggregate_by_run("t1", "r1")
            assert stats["total"] == 3
            assert stats["positive"] == 2
            assert stats["negative"] == 1
            assert stats["run_id"] == "r1"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_aggregate_empty(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            stats = await repo.aggregate_by_run("t1", "r1")
            assert stats["total"] == 0
            assert stats["positive"] == 0
            assert stats["negative"] == 0
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_upsert_creates_new(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
            assert record["rating"] == 1
            assert record["feedback_id"]
            assert record["user_id"] == "u1"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_upsert_updates_existing(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
            second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind")
            # Same logical row: upsert keeps the original feedback id.
            assert second["feedback_id"] == first["feedback_id"]
            assert second["rating"] == -1
            assert second["comment"] == "changed my mind"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_upsert_different_users_separate(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
            r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2")
            assert r1["feedback_id"] != r2["feedback_id"]
            assert r1["rating"] == 1
            assert r2["rating"] == -1
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_upsert_invalid_rating(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            with pytest.raises(ValueError):
                await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1")
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete_by_run(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
            deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
            assert deleted is True
            results = await repo.list_by_run("t1", "r1", user_id="u1")
            assert len(results) == 0
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete_by_run_nonexistent(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            assert await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1") is False
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread_grouped(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
            await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1")
            await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1")
            grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
            assert "r1" in grouped
            assert "r2" in grouped
            assert "r3" not in grouped
            assert grouped["r1"]["rating"] == 1
            assert grouped["r2"]["rating"] == -1
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread_grouped_empty(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        try:
            assert await repo.list_by_thread_grouped("t1", user_id="u1") == {}
        finally:
            await _cleanup()
|
||||
|
||||
|
||||
# -- Follow-up association --
|
||||
|
||||
|
||||
class TestFollowUpAssociation:
    """Follow-up association is carried through ``metadata``, not a dedicated column."""

    @pytest.mark.anyio
    async def test_run_records_follow_up_via_memory_store(self):
        """MemoryRunStore stores follow_up_to_run_id in kwargs."""
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        store = MemoryRunStore()
        await store.put("r1", thread_id="t1", status="success")
        # MemoryRunStore doesn't have follow_up_to_run_id as a top-level param,
        # but it can be passed via metadata
        await store.put("r2", thread_id="t1", metadata={"follow_up_to_run_id": "r1"})
        run = await store.get("r2")
        assert run["metadata"]["follow_up_to_run_id"] == "r1"

    @pytest.mark.anyio
    async def test_human_message_has_follow_up_metadata(self):
        """human_message event metadata includes follow_up_to_run_id."""
        from deerflow.runtime.events.store.memory import MemoryRunEventStore

        event_store = MemoryRunEventStore()
        await event_store.put(
            thread_id="t1",
            run_id="r2",
            event_type="human_message",
            category="message",
            content="Tell me more about that",
            metadata={"follow_up_to_run_id": "r1"},
        )
        messages = await event_store.list_messages("t1")
        # The stored human message surfaces the association in its metadata.
        assert messages[0]["metadata"]["follow_up_to_run_id"] == "r1"

    @pytest.mark.anyio
    async def test_follow_up_auto_detection_logic(self):
        """Simulate the auto-detection: latest successful run becomes follow_up_to."""
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        store = MemoryRunStore()
        await store.put("r1", thread_id="t1", status="success")
        await store.put("r2", thread_id="t1", status="error")

        # Auto-detect: list_by_thread returns newest first
        recent = await store.list_by_thread("t1", limit=1)
        follow_up = None
        if recent and recent[0].get("status") == "success":
            follow_up = recent[0]["run_id"]
        # r2 (error) is newest, so no follow_up detected
        assert follow_up is None

        # Now add a successful run
        await store.put("r3", thread_id="t1", status="success")
        recent = await store.list_by_thread("t1", limit=1)
        follow_up = None
        if recent and recent[0].get("status") == "success":
            follow_up = recent[0]["run_id"]
        assert follow_up == "r3"
|
||||
@ -1,342 +0,0 @@
|
||||
"""Tests for app.gateway.services — run lifecycle service layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def test_format_sse_basic():
    """A frame carries the event name and a JSON-encoded data payload."""
    from app.gateway.services import format_sse

    sse_frame = format_sse("metadata", {"run_id": "abc"})

    assert sse_frame.startswith("event: metadata\n")
    assert "data: " in sse_frame
    data_line = sse_frame.split("data: ")[1].split("\n")[0]
    payload = json.loads(data_line)
    assert payload["run_id"] == "abc"
|
||||
|
||||
|
||||
def test_format_sse_with_event_id():
    """An explicit event_id is emitted as an ``id:`` field in the frame."""
    from app.gateway.services import format_sse

    sse_frame = format_sse("metadata", {"run_id": "abc"}, event_id="123-0")
    assert "id: 123-0" in sse_frame
|
||||
|
||||
|
||||
def test_format_sse_end_event_null():
    """A None payload serializes to a literal ``data: null`` line."""
    from app.gateway.services import format_sse

    sse_frame = format_sse("end", None)
    assert "data: null" in sse_frame
|
||||
|
||||
|
||||
def test_format_sse_no_event_id():
    """Without an event_id, no ``id:`` field appears in the frame."""
    from app.gateway.services import format_sse

    sse_frame = format_sse("values", {"x": 1})
    assert "id:" not in sse_frame
|
||||
|
||||
|
||||
def test_normalize_stream_modes_none():
    """None falls back to the default ['values'] stream mode."""
    from app.gateway.services import normalize_stream_modes

    modes = normalize_stream_modes(None)
    assert modes == ["values"]
|
||||
|
||||
|
||||
def test_normalize_stream_modes_string():
    """A bare string is wrapped into a single-element list."""
    from app.gateway.services import normalize_stream_modes

    modes = normalize_stream_modes("messages-tuple")
    assert modes == ["messages-tuple"]
|
||||
|
||||
|
||||
def test_normalize_stream_modes_list():
    """A list of modes passes through unchanged."""
    from app.gateway.services import normalize_stream_modes

    modes = normalize_stream_modes(["values", "messages-tuple"])
    assert modes == ["values", "messages-tuple"]
|
||||
|
||||
|
||||
def test_normalize_stream_modes_empty_list():
    """An empty list also falls back to the default ['values']."""
    from app.gateway.services import normalize_stream_modes

    modes = normalize_stream_modes([])
    assert modes == ["values"]
|
||||
|
||||
|
||||
def test_normalize_input_none():
    """None input normalizes to an empty dict."""
    from app.gateway.services import normalize_input

    normalized = normalize_input(None)
    assert normalized == {}
|
||||
|
||||
|
||||
def test_normalize_input_with_messages():
    """A 'messages' payload is converted into message objects with .content."""
    from app.gateway.services import normalize_input

    normalized = normalize_input({"messages": [{"role": "user", "content": "hi"}]})
    messages = normalized["messages"]
    assert len(messages) == 1
    assert messages[0].content == "hi"
|
||||
|
||||
|
||||
def test_normalize_input_passthrough():
    """Keys other than 'messages' pass through untouched."""
    from app.gateway.services import normalize_input

    assert normalize_input({"custom_key": "value"}) == {"custom_key": "value"}
|
||||
|
||||
|
||||
def test_build_run_config_basic():
    """thread_id lands in configurable; recursion_limit defaults to 100."""
    from app.gateway.services import build_run_config

    run_config = build_run_config("thread-1", None, None)
    assert run_config["configurable"]["thread_id"] == "thread-1"
    assert run_config["recursion_limit"] == 100
|
||||
|
||||
|
||||
def test_build_run_config_with_overrides():
    """configurable / tags from the request and extra metadata are all folded in."""
    from app.gateway.services import build_run_config

    request_config = {"configurable": {"model_name": "gpt-4"}, "tags": ["test"]}
    run_config = build_run_config("thread-1", request_config, {"user": "alice"})

    assert run_config["configurable"]["model_name"] == "gpt-4"
    assert run_config["tags"] == ["test"]
    assert run_config["metadata"]["user"] == "alice"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression tests for issue #1644:
|
||||
# assistant_id not mapped to agent_name → custom agent SOUL.md never loaded
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_run_config_custom_agent_injects_agent_name():
    """Custom assistant_id must be forwarded as configurable['agent_name']."""
    from app.gateway.services import build_run_config

    run_config = build_run_config("thread-1", None, None, assistant_id="finalis")
    assert run_config["configurable"]["agent_name"] == "finalis"
|
||||
|
||||
|
||||
def test_build_run_config_lead_agent_no_agent_name():
    """The reserved 'lead_agent' id must NOT inject configurable['agent_name']."""
    from app.gateway.services import build_run_config

    run_config = build_run_config("thread-1", None, None, assistant_id="lead_agent")
    assert "agent_name" not in run_config["configurable"]
|
||||
|
||||
|
||||
def test_build_run_config_none_assistant_id_no_agent_name():
    """A None assistant_id must NOT inject configurable['agent_name']."""
    from app.gateway.services import build_run_config

    run_config = build_run_config("thread-1", None, None, assistant_id=None)
    assert "agent_name" not in run_config["configurable"]
|
||||
|
||||
|
||||
def test_build_run_config_explicit_agent_name_not_overwritten():
    """An explicit configurable['agent_name'] wins over the assistant_id."""
    from app.gateway.services import build_run_config

    request_config = {"configurable": {"agent_name": "explicit-agent"}}
    run_config = build_run_config(
        "thread-1",
        request_config,
        None,
        assistant_id="other-agent",
    )
    assert run_config["configurable"]["agent_name"] == "explicit-agent"
|
||||
|
||||
|
||||
def test_resolve_agent_factory_returns_make_lead_agent():
    """resolve_agent_factory always returns make_lead_agent regardless of assistant_id."""
    from app.gateway.services import resolve_agent_factory
    from deerflow.agents.lead_agent.agent import make_lead_agent

    for assistant_id in (None, "lead_agent", "finalis", "custom-agent-123"):
        assert resolve_agent_factory(assistant_id) is make_lead_agent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression tests for issue #1699:
|
||||
# context field in langgraph-compat requests not merged into configurable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_create_request_accepts_context():
    """RunCreateRequest must accept the ``context`` field without dropping it."""
    from app.gateway.routers.thread_runs import RunCreateRequest

    context_payload = {
        "model_name": "deepseek-v3",
        "thinking_enabled": True,
        "is_plan_mode": True,
        "subagent_enabled": True,
        "thread_id": "some-thread-id",
    }
    body = RunCreateRequest(
        input={"messages": [{"role": "user", "content": "hi"}]},
        context=context_payload,
    )

    assert body.context is not None
    assert body.context["model_name"] == "deepseek-v3"
    assert body.context["is_plan_mode"] is True
    assert body.context["subagent_enabled"] is True
|
||||
|
||||
|
||||
def test_run_create_request_context_defaults_to_none():
    """Omitting context keeps it None (backward compatibility)."""
    from app.gateway.routers.thread_runs import RunCreateRequest

    assert RunCreateRequest(input=None).context is None
|
||||
|
||||
|
||||
def test_context_merges_into_configurable():
    """Context values must be merged into config['configurable'] by start_run.

    Since start_run is async and requires many dependencies, we test the
    merging logic directly by simulating what start_run does.
    """
    from app.gateway.services import build_run_config

    # Simulate the context merging logic from start_run
    config = build_run_config("thread-1", None, None)

    context = {
        "model_name": "deepseek-v3",
        "mode": "ultra",
        "reasoning_effort": "high",
        "thinking_enabled": True,
        "is_plan_mode": True,
        "subagent_enabled": True,
        "max_concurrent_subagents": 5,
        "thread_id": "should-be-ignored",
    }

    # Allowlist mirrored from start_run: only these context keys may be
    # copied into configurable.
    _CONTEXT_CONFIGURABLE_KEYS = {
        "model_name",
        "mode",
        "thinking_enabled",
        "reasoning_effort",
        "is_plan_mode",
        "subagent_enabled",
        "max_concurrent_subagents",
    }
    configurable = config.setdefault("configurable", {})
    for key in _CONTEXT_CONFIGURABLE_KEYS:
        if key in context:
            configurable.setdefault(key, context[key])

    assert config["configurable"]["model_name"] == "deepseek-v3"
    assert config["configurable"]["thinking_enabled"] is True
    assert config["configurable"]["is_plan_mode"] is True
    assert config["configurable"]["subagent_enabled"] is True
    assert config["configurable"]["max_concurrent_subagents"] == 5
    assert config["configurable"]["reasoning_effort"] == "high"
    assert config["configurable"]["mode"] == "ultra"
    # thread_id from context should NOT override the one from build_run_config
    assert config["configurable"]["thread_id"] == "thread-1"
    # Non-allowlisted context keys must not leak into configurable: the only
    # keys present are the allowlisted ones plus build_run_config's thread_id.
    # (The original assertion here was a tautology over `context` itself and
    # verified nothing about `configurable`.)
    assert set(config["configurable"]) == _CONTEXT_CONFIGURABLE_KEYS | {"thread_id"}
|
||||
|
||||
|
||||
def test_context_does_not_override_existing_configurable():
    """Values already in config.configurable must NOT be overridden by context."""
    from app.gateway.services import build_run_config

    run_config = build_run_config(
        "thread-1",
        {"configurable": {"model_name": "gpt-4", "is_plan_mode": False}},
        None,
    )

    context = {
        "model_name": "deepseek-v3",
        "is_plan_mode": True,
        "subagent_enabled": True,
    }

    # Same allowlist that start_run applies before copying context keys.
    allowlisted_keys = {
        "model_name",
        "mode",
        "thinking_enabled",
        "reasoning_effort",
        "is_plan_mode",
        "subagent_enabled",
        "max_concurrent_subagents",
    }
    configurable = run_config.setdefault("configurable", {})
    for context_key in allowlisted_keys:
        if context_key in context:
            configurable.setdefault(context_key, context[context_key])

    # Pre-existing values survive untouched …
    assert configurable["model_name"] == "gpt-4"
    assert configurable["is_plan_mode"] is False
    # … while genuinely new keys are added.
    assert configurable["subagent_enabled"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_run_config_with_context():
    """When the caller sends 'context', it is preferred over 'configurable'."""
    from app.gateway.services import build_run_config

    run_config = build_run_config(
        "thread-1",
        {"context": {"user_id": "u-42", "thread_id": "thread-1"}},
        None,
    )

    assert "context" in run_config
    assert run_config["context"]["user_id"] == "u-42"
    assert "configurable" not in run_config
    assert run_config["recursion_limit"] == 100
|
||||
|
||||
|
||||
def test_build_run_config_context_plus_configurable_warns(caplog):
    """Both 'context' and 'configurable' present → context wins and a warning is logged."""
    import logging

    from app.gateway.services import build_run_config

    request_config = {
        "context": {"user_id": "u-42"},
        "configurable": {"model_name": "gpt-4"},
    }
    with caplog.at_level(logging.WARNING, logger="app.gateway.services"):
        run_config = build_run_config("thread-1", request_config, None)

    assert "context" in run_config
    assert run_config["context"]["user_id"] == "u-42"
    assert "configurable" not in run_config
    warning_hits = [r for r in caplog.records if "both 'context' and 'configurable'" in r.message]
    assert warning_hits
|
||||
|
||||
|
||||
def test_build_run_config_context_passthrough_other_keys():
    """Non-conflicting request_config keys survive alongside 'context'."""
    from app.gateway.services import build_run_config

    run_config = build_run_config(
        "thread-1",
        {"context": {"thread_id": "thread-1"}, "tags": ["prod"]},
        None,
    )

    assert run_config["context"]["thread_id"] == "thread-1"
    assert "configurable" not in run_config
    assert run_config["tags"] == ["prod"]
|
||||
|
||||
|
||||
def test_build_run_config_no_request_config():
    """When request_config is None, fall back to basic configurable with thread_id."""
    from app.gateway.services import build_run_config

    result = build_run_config("thread-abc", None, None)

    assert result["configurable"] == {"thread_id": "thread-abc"}
    assert "context" not in result
|
||||
@ -1,156 +0,0 @@
|
||||
"""Owner isolation tests for MemoryThreadMetaStore.
|
||||
|
||||
Mirrors the SQL-backed tests in test_owner_isolation.py but exercises
|
||||
the in-memory LangGraph Store backend used when database.backend=memory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||
|
||||
USER_A = SimpleNamespace(id="user-a", email="a@test.local")
|
||||
USER_B = SimpleNamespace(id="user-b", email="b@test.local")
|
||||
|
||||
|
||||
def _as_user(user):
    """Return a context manager that installs *user* as the current user.

    Idiomatic replacement for the hand-rolled ``_Ctx`` class: uses
    ``contextlib.contextmanager`` and restores the contextvar token in a
    ``finally`` block, so the previous user is reinstated even if the body
    raises. Yields the user object for optional ``as`` binding.
    """
    from contextlib import contextmanager

    @contextmanager
    def _ctx():
        token = set_current_user(user)
        try:
            yield user
        finally:
            reset_current_user(token)

    return _ctx()
|
||||
|
||||
|
||||
@pytest.fixture
def store():
    """Fresh MemoryThreadMetaStore backed by an in-memory LangGraph store."""
    backing = InMemoryStore()
    return MemoryThreadMetaStore(backing)
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_search_isolation(store):
    """search() returns only threads owned by the current user."""
    with _as_user(USER_A):
        await store.create("t-alpha", display_name="A's thread")
    with _as_user(USER_B):
        await store.create("t-beta", display_name="B's thread")

    # Each user sees exactly their own thread and nothing else.
    with _as_user(USER_A):
        visible = [row["thread_id"] for row in await store.search()]
        assert visible == ["t-alpha"]

    with _as_user(USER_B):
        visible = [row["thread_id"] for row in await store.search()]
        assert visible == ["t-beta"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_get_isolation(store):
    """get() returns None for threads owned by another user."""
    with _as_user(USER_A):
        await store.create("t-alpha", display_name="A's thread")

    # B cannot read A's thread even by guessing the id.
    with _as_user(USER_B):
        assert await store.get("t-alpha") is None

    # The owner still sees it.
    with _as_user(USER_A):
        row = await store.get("t-alpha")
        assert row is not None
        assert row["display_name"] == "A's thread"
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_display_name_denied(store):
    """User B cannot rename User A's thread."""
    with _as_user(USER_A):
        await store.create("t-alpha", display_name="original")

    # Hostile rename attempt by a different owner must be a no-op.
    with _as_user(USER_B):
        await store.update_display_name("t-alpha", "hacked")

    with _as_user(USER_A):
        row = await store.get("t-alpha")
        assert row is not None
        assert row["display_name"] == "original"
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_status_denied(store):
    """User B cannot change status of User A's thread."""
    with _as_user(USER_A):
        await store.create("t-alpha")

    # Cross-user status write must silently do nothing.
    with _as_user(USER_B):
        await store.update_status("t-alpha", "error")

    with _as_user(USER_A):
        row = await store.get("t-alpha")
        assert row is not None
        assert row["status"] == "idle"
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_update_metadata_denied(store):
    """User B cannot modify metadata of User A's thread."""
    with _as_user(USER_A):
        await store.create("t-alpha", metadata={"key": "original"})

    # Metadata tampering from another user is ignored.
    with _as_user(USER_B):
        await store.update_metadata("t-alpha", {"key": "hacked"})

    with _as_user(USER_A):
        row = await store.get("t-alpha")
        assert row is not None
        assert row["metadata"]["key"] == "original"
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_delete_denied(store):
    """User B cannot delete User A's thread."""
    with _as_user(USER_A):
        await store.create("t-alpha")

    # Cross-user delete must not remove the row.
    with _as_user(USER_B):
        await store.delete("t-alpha")

    with _as_user(USER_A):
        row = await store.get("t-alpha")
        assert row is not None
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_no_context_raises(store):
    """Calling methods without user context raises RuntimeError."""
    # The autouse fixture is disabled, so no contextvar is set here.
    with pytest.raises(RuntimeError, match="no user context is set"):
        await store.search()
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_explicit_none_bypasses_filter(store):
    """user_id=None bypasses isolation (migration/CLI escape hatch)."""
    with _as_user(USER_A):
        await store.create("t-alpha")
    with _as_user(USER_B):
        await store.create("t-beta")

    # With the explicit None bypass, every row is visible regardless of owner.
    everything = await store.search(user_id=None)
    assert {row["thread_id"] for row in everything} == {"t-alpha", "t-beta"}

    assert await store.get("t-alpha", user_id=None) is not None
|
||||
@ -1,465 +0,0 @@
|
||||
"""Cross-user isolation tests — non-negotiable safety gate.
|
||||
|
||||
Mirrors TC-API-17..20 from backend/docs/AUTH_TEST_PLAN.md. A failure
|
||||
here means users can see each other's data; PR must not merge.
|
||||
|
||||
Architecture note
|
||||
-----------------
|
||||
These tests bypass the HTTP layer and exercise the storage-layer
|
||||
owner filter directly by switching the ``user_context`` contextvar
|
||||
between two users. The safety property under test is:
|
||||
|
||||
After a repository write with user_id=A, a subsequent read with
|
||||
user_id=B must not return the row, and vice versa.
|
||||
|
||||
The HTTP layer is covered by test_auth_middleware.py, which proves
|
||||
that a request cookie reaches the ``set_current_user`` call. Together
|
||||
the two suites prove the full chain:
|
||||
|
||||
cookie → middleware → contextvar → repository → isolation
|
||||
|
||||
Every test in this file opts out of the autouse contextvar fixture
|
||||
(``@pytest.mark.no_auto_user``) so it can set the contextvar to the
|
||||
specific users it cares about.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.user_context import (
|
||||
reset_current_user,
|
||||
set_current_user,
|
||||
)
|
||||
|
||||
USER_A = SimpleNamespace(id="user-a", email="a@test.local")
|
||||
USER_B = SimpleNamespace(id="user-b", email="b@test.local")
|
||||
|
||||
|
||||
async def _make_engines(tmp_path):
    """Initialize the shared engine against a per-test SQLite DB.

    Returns a cleanup coroutine the caller should await at the end.
    """
    from deerflow.persistence.engine import close_engine, init_engine

    db_file = tmp_path / "isolation.db"
    await init_engine(
        "sqlite",
        url=f"sqlite+aiosqlite:///{db_file}",
        sqlite_dir=str(tmp_path),
    )
    return close_engine
|
||||
|
||||
|
||||
def _as_user(user):
    """Context manager-like helper that set/reset the contextvar.

    Idiomatic replacement for the hand-rolled ``_Ctx`` class: uses
    ``contextlib.contextmanager`` and restores the contextvar token in a
    ``finally`` block, so the previous user is reinstated even if the body
    raises. Yields the user object for optional ``as`` binding.
    """
    from contextlib import contextmanager

    @contextmanager
    def _ctx():
        token = set_current_user(user)
        try:
            yield user
        finally:
            reset_current_user(token)

    return _ctx()
|
||||
|
||||
|
||||
# ── TC-API-17 — threads_meta isolation ────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_thread_meta_cross_user_isolation(tmp_path):
    """TC-API-17: threads_meta reads are scoped to the owning user."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = ThreadMetaRepository(get_session_factory())

        # Seed one thread per user.
        with _as_user(USER_A):
            await repo.create("t-alpha", display_name="A's private thread")
        with _as_user(USER_B):
            await repo.create("t-beta", display_name="B's private thread")

        # A's view: only A's thread, by get() and by search().
        with _as_user(USER_A):
            own = await repo.get("t-alpha")
            assert own is not None
            assert own["display_name"] == "A's private thread"

            leaked = await repo.get("t-beta")
            assert leaked is None, f"User A leaked User B's thread: {leaked}"

            assert [r["thread_id"] for r in await repo.search()] == ["t-alpha"]

        # B's view: only B's thread.
        with _as_user(USER_B):
            own = await repo.get("t-beta")
            assert own is not None
            assert own["display_name"] == "B's private thread"

            leaked = await repo.get("t-alpha")
            assert leaked is None, f"User B leaked User A's thread: {leaked}"

            assert [r["thread_id"] for r in await repo.search()] == ["t-beta"]
    finally:
        await cleanup()
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_thread_meta_cross_user_mutation_denied(tmp_path):
    """User B cannot update or delete a thread owned by User A."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = ThreadMetaRepository(get_session_factory())

        with _as_user(USER_A):
            await repo.create("t-alpha", display_name="original")

        # Cross-user rename attempt: silently ignored.
        with _as_user(USER_B):
            await repo.update_display_name("t-alpha", "hacked")

        with _as_user(USER_A):
            row = await repo.get("t-alpha")
            assert row is not None
            assert row["display_name"] == "original"

        # Cross-user delete attempt: also a no-op.
        with _as_user(USER_B):
            await repo.delete("t-alpha")

        with _as_user(USER_A):
            row = await repo.get("t-alpha")
            assert row is not None
    finally:
        await cleanup()
|
||||
|
||||
|
||||
# ── TC-API-18 — runs isolation ────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_runs_cross_user_isolation(tmp_path):
    """TC-API-18: run rows are scoped to the owning user."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.run import RunRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = RunRepository(get_session_factory())

        with _as_user(USER_A):
            await repo.put("run-a1", thread_id="t-alpha")
            await repo.put("run-a2", thread_id="t-alpha")
        with _as_user(USER_B):
            await repo.put("run-b1", thread_id="t-beta")

        # A's view.
        with _as_user(USER_A):
            row = await repo.get("run-a1")
            assert row is not None
            assert row["run_id"] == "run-a1"

            leaked = await repo.get("run-b1")
            assert leaked is None, "User A leaked User B's run"

            own_runs = await repo.list_by_thread("t-alpha")
            assert {r["run_id"] for r in own_runs} == {"run-a1", "run-a2"}

            # B's thread looks empty from A's perspective.
            assert await repo.list_by_thread("t-beta") == []

        # B's view.
        with _as_user(USER_B):
            leaked = await repo.get("run-a1")
            assert leaked is None, "User B leaked User A's run"

            own_runs = await repo.list_by_thread("t-beta")
            assert [r["run_id"] for r in own_runs] == ["run-b1"]
    finally:
        await cleanup()
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_runs_cross_user_delete_denied(tmp_path):
    """Deleting another user's run is a silent no-op."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.run import RunRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = RunRepository(get_session_factory())

        with _as_user(USER_A):
            await repo.put("run-a1", thread_id="t-alpha")

        with _as_user(USER_B):
            await repo.delete("run-a1")

        # A's run survives the hostile delete.
        with _as_user(USER_A):
            row = await repo.get("run-a1")
            assert row is not None
    finally:
        await cleanup()
|
||||
|
||||
|
||||
# ── TC-API-19 — run_events isolation (CRITICAL: content leak) ─────────────
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_run_events_cross_user_isolation(tmp_path):
    """run_events holds raw conversation content — most sensitive leak vector."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.runtime.events.store.db import DbRunEventStore

    cleanup = await _make_engines(tmp_path)
    try:
        store = DbRunEventStore(get_session_factory())

        # Seed A's conversation (question + answer) and one message for B.
        with _as_user(USER_A):
            for etype, body in (
                ("human_message", "User A private question"),
                ("ai_message", "User A private answer"),
            ):
                await store.put(
                    thread_id="t-alpha",
                    run_id="run-a1",
                    event_type=etype,
                    category="message",
                    content=body,
                )

        with _as_user(USER_B):
            await store.put(
                thread_id="t-beta",
                run_id="run-b1",
                event_type="human_message",
                category="message",
                content="User B private question",
            )

        # A's view — CRITICAL: none of B's content may surface.
        with _as_user(USER_A):
            contents = [m["content"] for m in await store.list_messages("t-alpha")]
            assert "User A private question" in contents
            assert "User A private answer" in contents
            assert "User B private question" not in contents

            leaked = await store.list_messages("t-beta")
            assert leaked == [], f"User A leaked User B's messages: {leaked}"

            leaked_events = await store.list_events("t-beta", "run-b1")
            assert leaked_events == [], "User A leaked User B's events"

            assert await store.count_messages("t-beta") == 0

        # B's view: only B's content.
        with _as_user(USER_B):
            contents = [m["content"] for m in await store.list_messages("t-beta")]
            assert "User B private question" in contents
            assert "User A private question" not in contents
            assert "User A private answer" not in contents

            assert await store.count_messages("t-alpha") == 0
    finally:
        await cleanup()
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_run_events_cross_user_delete_denied(tmp_path):
    """User B cannot delete User A's event stream."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.runtime.events.store.db import DbRunEventStore

    cleanup = await _make_engines(tmp_path)
    try:
        store = DbRunEventStore(get_session_factory())

        with _as_user(USER_A):
            await store.put(
                thread_id="t-alpha",
                run_id="run-a1",
                event_type="human_message",
                category="message",
                content="hello",
            )

        # A hostile bulk-delete must report zero rows removed.
        with _as_user(USER_B):
            removed = await store.delete_by_thread("t-alpha")
            assert removed == 0, f"User B deleted {removed} of User A's events"

        with _as_user(USER_A):
            assert await store.count_messages("t-alpha") == 1
    finally:
        await cleanup()
|
||||
|
||||
|
||||
# ── TC-API-20 — feedback isolation ────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_feedback_cross_user_isolation(tmp_path):
    """TC-API-20: feedback rows are scoped to the submitting user."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.feedback import FeedbackRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = FeedbackRepository(get_session_factory())

        # One feedback row per user.
        with _as_user(USER_A):
            a_feedback = await repo.create(
                run_id="run-a1",
                thread_id="t-alpha",
                rating=1,
                comment="A liked this",
            )
        with _as_user(USER_B):
            b_feedback = await repo.create(
                run_id="run-b1",
                thread_id="t-beta",
                rating=-1,
                comment="B disliked this",
            )

        # A's view.
        with _as_user(USER_A):
            mine = await repo.get(a_feedback["feedback_id"])
            assert mine is not None
            assert mine["comment"] == "A liked this"

            # CRITICAL: reading B's feedback by id must fail.
            leaked = await repo.get(b_feedback["feedback_id"])
            assert leaked is None, "User A leaked User B's feedback"

            assert await repo.list_by_run("t-beta", "run-b1") == []

        # B's view.
        with _as_user(USER_B):
            leaked = await repo.get(a_feedback["feedback_id"])
            assert leaked is None, "User B leaked User A's feedback"

            rows = await repo.list_by_run("t-beta", "run-b1")
            assert len(rows) == 1
            assert rows[0]["comment"] == "B disliked this"
    finally:
        await cleanup()
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_feedback_cross_user_delete_denied(tmp_path):
    """Deleting another user's feedback returns False and keeps the row."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.feedback import FeedbackRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = FeedbackRepository(get_session_factory())

        with _as_user(USER_A):
            fb = await repo.create(run_id="run-a1", thread_id="t-alpha", rating=1)

        with _as_user(USER_B):
            deleted = await repo.delete(fb["feedback_id"])
            assert deleted is False, "User B deleted User A's feedback"

        # Row is still there for the owner.
        with _as_user(USER_A):
            row = await repo.get(fb["feedback_id"])
            assert row is not None
    finally:
        await cleanup()
|
||||
|
||||
|
||||
# ── Regression: AUTO sentinel without contextvar must raise ───────────────
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_repository_without_context_raises(tmp_path):
    """Defense-in-depth: calling repo methods without a user context errors."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = ThreadMetaRepository(get_session_factory())
        # With @pytest.mark.no_auto_user no contextvar is installed, so the
        # AUTO sentinel has nothing to resolve and must raise.
        with pytest.raises(RuntimeError, match="no user context is set"):
            await repo.get("anything")
    finally:
        await cleanup()
|
||||
|
||||
|
||||
# ── Escape hatch: explicit user_id=None bypasses filter (for migration) ──
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.no_auto_user
async def test_explicit_none_bypasses_filter(tmp_path):
    """Migration scripts pass user_id=None to see all rows regardless of owner."""
    from deerflow.persistence.engine import get_session_factory
    from deerflow.persistence.thread_meta import ThreadMetaRepository

    cleanup = await _make_engines(tmp_path)
    try:
        repo = ThreadMetaRepository(get_session_factory())

        with _as_user(USER_A):
            await repo.create("t-alpha")
        with _as_user(USER_B):
            await repo.create("t-beta")

        # No contextvar set; explicit None disables the owner filter.
        everything = await repo.search(user_id=None)
        assert {r["thread_id"] for r in everything} == {"t-alpha", "t-beta"}

        # get() honours the same bypass for both owners' rows.
        assert await repo.get("t-alpha", user_id=None) is not None
        assert await repo.get("t-beta", user_id=None) is not None
    finally:
        await cleanup()
|
||||
@ -1,233 +0,0 @@
|
||||
"""Tests for the persistence layer scaffolding.
|
||||
|
||||
Tests:
|
||||
1. DatabaseConfig property derivation (paths, URLs)
|
||||
2. MemoryRunStore CRUD + user_id filtering
|
||||
3. Base.to_dict() via inspect mixin
|
||||
4. Engine init/close lifecycle (memory + SQLite)
|
||||
5. Postgres missing-dep error message
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.database_config import DatabaseConfig
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
# -- DatabaseConfig --
|
||||
|
||||
|
||||
class TestDatabaseConfig:
    """Property derivation on DatabaseConfig: paths, URLs, defaults."""

    def test_defaults(self):
        cfg = DatabaseConfig()
        assert cfg.backend == "memory"
        assert cfg.pool_size == 5

    def test_sqlite_paths_unified(self):
        cfg = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata")
        assert cfg.sqlite_path.endswith("deerflow.db")
        assert "mydata" in cfg.sqlite_path
        # The legacy per-purpose path aliases all resolve to the one file.
        assert cfg.checkpointer_sqlite_path == cfg.sqlite_path
        assert cfg.app_sqlite_path == cfg.sqlite_path

    def test_app_sqlalchemy_url_sqlite(self):
        cfg = DatabaseConfig(backend="sqlite", sqlite_dir="./data")
        assert cfg.app_sqlalchemy_url.startswith("sqlite+aiosqlite:///")
        assert "deerflow.db" in cfg.app_sqlalchemy_url

    def test_app_sqlalchemy_url_postgres(self):
        cfg = DatabaseConfig(
            backend="postgres",
            postgres_url="postgresql://u:p@h:5432/db",
        )
        # The plain postgresql:// scheme is upgraded to the async driver.
        assert cfg.app_sqlalchemy_url.startswith("postgresql+asyncpg://")
        assert "u:p@h:5432/db" in cfg.app_sqlalchemy_url

    def test_app_sqlalchemy_url_postgres_already_asyncpg(self):
        cfg = DatabaseConfig(
            backend="postgres",
            postgres_url="postgresql+asyncpg://u:p@h:5432/db",
        )
        # Already-async URLs are not double-prefixed.
        assert cfg.app_sqlalchemy_url.count("asyncpg") == 1

    def test_memory_has_no_url(self):
        cfg = DatabaseConfig(backend="memory")
        with pytest.raises(ValueError, match="No SQLAlchemy URL"):
            _ = cfg.app_sqlalchemy_url
|
||||
|
||||
|
||||
# -- MemoryRunStore --
|
||||
|
||||
|
||||
class TestMemoryRunStore:
    """CRUD, owner filtering, and pending-queue semantics of MemoryRunStore."""

    @pytest.fixture
    def store(self):
        return MemoryRunStore()

    @pytest.mark.anyio
    async def test_put_and_get(self, store):
        await store.put("r1", thread_id="t1", status="pending")
        record = await store.get("r1")
        assert record is not None
        assert record["run_id"] == "r1"
        assert record["status"] == "pending"

    @pytest.mark.anyio
    async def test_get_missing_returns_none(self, store):
        assert await store.get("nope") is None

    @pytest.mark.anyio
    async def test_update_status(self, store):
        await store.put("r1", thread_id="t1")
        await store.update_status("r1", "running")
        record = await store.get("r1")
        assert record["status"] == "running"

    @pytest.mark.anyio
    async def test_update_status_with_error(self, store):
        await store.put("r1", thread_id="t1")
        await store.update_status("r1", "error", error="boom")
        record = await store.get("r1")
        assert record["status"] == "error"
        assert record["error"] == "boom"

    @pytest.mark.anyio
    async def test_list_by_thread(self, store):
        await store.put("r1", thread_id="t1")
        await store.put("r2", thread_id="t1")
        await store.put("r3", thread_id="t2")
        records = await store.list_by_thread("t1")
        assert len(records) == 2
        assert all(rec["thread_id"] == "t1" for rec in records)

    @pytest.mark.anyio
    async def test_list_by_thread_owner_filter(self, store):
        await store.put("r1", thread_id="t1", user_id="alice")
        await store.put("r2", thread_id="t1", user_id="bob")
        records = await store.list_by_thread("t1", user_id="alice")
        assert len(records) == 1
        assert records[0]["user_id"] == "alice"

    @pytest.mark.anyio
    async def test_owner_none_returns_all(self, store):
        await store.put("r1", thread_id="t1", user_id="alice")
        await store.put("r2", thread_id="t1", user_id="bob")
        # None disables the owner filter entirely.
        records = await store.list_by_thread("t1", user_id=None)
        assert len(records) == 2

    @pytest.mark.anyio
    async def test_delete(self, store):
        await store.put("r1", thread_id="t1")
        await store.delete("r1")
        assert await store.get("r1") is None

    @pytest.mark.anyio
    async def test_delete_nonexistent_is_noop(self, store):
        # Deleting an unknown id must not raise.
        await store.delete("nope")

    @pytest.mark.anyio
    async def test_list_pending(self, store):
        await store.put("r1", thread_id="t1", status="pending")
        await store.put("r2", thread_id="t1", status="running")
        await store.put("r3", thread_id="t2", status="pending")
        pending = await store.list_pending()
        assert len(pending) == 2
        assert all(rec["status"] == "pending" for rec in pending)

    @pytest.mark.anyio
    async def test_list_pending_respects_before(self, store):
        past = "2020-01-01T00:00:00+00:00"
        future = "2099-01-01T00:00:00+00:00"
        await store.put("r1", thread_id="t1", status="pending", created_at=past)
        await store.put("r2", thread_id="t1", status="pending", created_at=future)
        # Only runs created strictly before the cutoff are returned.
        pending = await store.list_pending(before=datetime.now(UTC).isoformat())
        assert len(pending) == 1
        assert pending[0]["run_id"] == "r1"

    @pytest.mark.anyio
    async def test_list_pending_fifo_order(self, store):
        # Inserted out of order; list_pending must sort by created_at.
        await store.put("r2", thread_id="t1", status="pending", created_at="2024-01-02T00:00:00+00:00")
        await store.put("r1", thread_id="t1", status="pending", created_at="2024-01-01T00:00:00+00:00")
        pending = await store.list_pending()
        assert pending[0]["run_id"] == "r1"
|
||||
|
||||
|
||||
# -- Base.to_dict mixin --
|
||||
|
||||
|
||||
class TestBaseToDictMixin:
    """Base.to_dict() / __repr__ behaviour on a throwaway SQLAlchemy model."""

    @pytest.mark.anyio
    async def test_to_dict_and_exclude(self, tmp_path):
        """Create a temp SQLite DB with a minimal model, verify to_dict."""
        from sqlalchemy import String
        from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
        from sqlalchemy.orm import Mapped, mapped_column

        from deerflow.persistence.base import Base

        class _Tmp(Base):
            __tablename__ = "_tmp_test"
            id: Mapped[str] = mapped_column(String(64), primary_key=True)
            name: Mapped[str] = mapped_column(String(128))

        engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}")
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        async with session_factory() as session:
            session.add(_Tmp(id="1", name="hello"))
            await session.commit()
            loaded = await session.get(_Tmp, "1")

        # Full dict, exclusion, and a repr containing the class name.
        assert loaded.to_dict() == {"id": "1", "name": "hello"}
        assert loaded.to_dict(exclude={"name"}) == {"id": "1"}
        assert "_Tmp" in repr(loaded)

        await engine.dispose()
|
||||
|
||||
|
||||
# -- Engine lifecycle --
|
||||
|
||||
|
||||
class TestEngineLifecycle:
    """init_engine / close_engine behaviour per backend."""

    @pytest.mark.anyio
    async def test_memory_is_noop(self):
        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine

        # The memory backend never builds an engine or session factory.
        await init_engine("memory")
        assert get_session_factory() is None
        await close_engine()

    @pytest.mark.anyio
    async def test_sqlite_creates_engine(self, tmp_path):
        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine

        await init_engine(
            "sqlite",
            url=f"sqlite+aiosqlite:///{tmp_path / 'test.db'}",
            sqlite_dir=str(tmp_path),
        )
        session_factory = get_session_factory()
        assert session_factory is not None
        async with session_factory() as session:
            assert session is not None

        # close_engine tears the factory back down.
        await close_engine()
        assert get_session_factory() is None

    @pytest.mark.anyio
    async def test_postgres_without_asyncpg_gives_actionable_error(self):
        """If asyncpg is not installed, error message tells user what to do."""
        from deerflow.persistence.engine import init_engine

        try:
            import asyncpg  # noqa: F401
        except ImportError:
            asyncpg = None

        if asyncpg is not None:
            pytest.skip("asyncpg is installed -- cannot test missing-dep path")

        # Missing driver must surface an ImportError naming the install command.
        with pytest.raises(ImportError, match="uv sync --extra postgres"):
            await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x")
|
||||
@ -1,500 +0,0 @@
|
||||
"""Tests for RunEventStore contract across all backends.
|
||||
|
||||
Uses a helper to create the store for each backend type.
|
||||
Memory tests run directly; DB and JSONL tests create stores inside each test.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
|
||||
@pytest.fixture
def store():
    """Fresh in-memory event store for each test."""
    return MemoryRunEventStore()
|
||||
|
||||
|
||||
# -- Basic write and query --
|
||||
|
||||
|
||||
class TestPutAndSeq:
    """put() record shape and per-thread sequence numbering."""

    @pytest.mark.anyio
    async def test_put_returns_dict_with_seq(self, store):
        record = await store.put(
            thread_id="t1",
            run_id="r1",
            event_type="human_message",
            category="message",
            content="hello",
        )
        # Every field echoes back, plus a seq and a timestamp.
        assert "seq" in record
        assert record["seq"] == 1
        assert record["thread_id"] == "t1"
        assert record["run_id"] == "r1"
        assert record["event_type"] == "human_message"
        assert record["category"] == "message"
        assert record["content"] == "hello"
        assert "created_at" in record

    @pytest.mark.anyio
    async def test_seq_strictly_increasing_same_thread(self, store):
        first = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        second = await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
        third = await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
        assert (first["seq"], second["seq"], third["seq"]) == (1, 2, 3)

    @pytest.mark.anyio
    async def test_seq_independent_across_threads(self, store):
        # Each thread keeps its own counter starting from 1.
        on_t1 = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        on_t2 = await store.put(thread_id="t2", run_id="r2", event_type="human_message", category="message")
        assert on_t1["seq"] == 1
        assert on_t2["seq"] == 1

    @pytest.mark.anyio
    async def test_put_respects_provided_created_at(self, store):
        ts = "2024-06-01T12:00:00+00:00"
        record = await store.put(
            thread_id="t1",
            run_id="r1",
            event_type="human_message",
            category="message",
            created_at=ts,
        )
        assert record["created_at"] == ts

    @pytest.mark.anyio
    async def test_put_metadata_preserved(self, store):
        meta = {"model": "gpt-4", "tokens": 100}
        record = await store.put(
            thread_id="t1",
            run_id="r1",
            event_type="llm_end",
            category="trace",
            metadata=meta,
        )
        assert record["metadata"] == meta
|
||||
|
||||
|
||||
# -- list_messages --
|
||||
|
||||
|
||||
class TestListMessages:
|
||||
@pytest.mark.anyio
|
||||
async def test_only_returns_message_category(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="run_start", category="lifecycle")
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["category"] == "message"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ascending_seq_order(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="first")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="second")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="third")
|
||||
messages = await store.list_messages("t1")
|
||||
seqs = [m["seq"] for m in messages]
|
||||
assert seqs == sorted(seqs)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_before_seq_pagination(self, store):
|
||||
for i in range(10):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
|
||||
messages = await store.list_messages("t1", before_seq=6, limit=3)
|
||||
assert len(messages) == 3
|
||||
assert [m["seq"] for m in messages] == [3, 4, 5]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_after_seq_pagination(self, store):
|
||||
for i in range(10):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
|
||||
messages = await store.list_messages("t1", after_seq=7, limit=3)
|
||||
assert len(messages) == 3
|
||||
assert [m["seq"] for m in messages] == [8, 9, 10]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_limit_restricts_count(self, store):
|
||||
for _ in range(20):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
messages = await store.list_messages("t1", limit=5)
|
||||
assert len(messages) == 5
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cross_run_unified_ordering(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message")
|
||||
messages = await store.list_messages("t1")
|
||||
assert [m["seq"] for m in messages] == [1, 2, 3, 4]
|
||||
assert messages[0]["run_id"] == "r1"
|
||||
assert messages[2]["run_id"] == "r2"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_default_returns_latest(self, store):
|
||||
for _ in range(10):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
messages = await store.list_messages("t1", limit=3)
|
||||
assert [m["seq"] for m in messages] == [8, 9, 10]
|
||||
|
||||
|
||||
# -- list_events --
|
||||
|
||||
|
||||
class TestListEvents:
|
||||
@pytest.mark.anyio
|
||||
async def test_returns_all_categories_for_run(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="run_start", category="lifecycle")
|
||||
events = await store.list_events("t1", "r1")
|
||||
assert len(events) == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_event_types_filter(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_start", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="tool_start", category="trace")
|
||||
events = await store.list_events("t1", "r1", event_types=["llm_end"])
|
||||
assert len(events) == 1
|
||||
assert events[0]["event_type"] == "llm_end"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_only_returns_specified_run(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
|
||||
events = await store.list_events("t1", "r1")
|
||||
assert len(events) == 1
|
||||
assert events[0]["run_id"] == "r1"
|
||||
|
||||
|
||||
# -- list_messages_by_run --
|
||||
|
||||
|
||||
class TestListMessagesByRun:
|
||||
@pytest.mark.anyio
|
||||
async def test_only_messages_for_specified_run(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
|
||||
messages = await store.list_messages_by_run("t1", "r1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["run_id"] == "r1"
|
||||
assert messages[0]["category"] == "message"
|
||||
|
||||
|
||||
# -- count_messages --
|
||||
|
||||
|
||||
class TestCountMessages:
|
||||
@pytest.mark.anyio
|
||||
async def test_counts_only_message_category(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace")
|
||||
assert await store.count_messages("t1") == 2
|
||||
|
||||
|
||||
# -- put_batch --
|
||||
|
||||
|
||||
class TestPutBatch:
|
||||
@pytest.mark.anyio
|
||||
async def test_batch_assigns_seq(self, store):
|
||||
events = [
|
||||
{"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message", "content": "a"},
|
||||
{"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message", "content": "b"},
|
||||
{"thread_id": "t1", "run_id": "r1", "event_type": "llm_end", "category": "trace"},
|
||||
]
|
||||
results = await store.put_batch(events)
|
||||
assert len(results) == 3
|
||||
assert all("seq" in r for r in results)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_batch_seq_strictly_increasing(self, store):
|
||||
events = [
|
||||
{"thread_id": "t1", "run_id": "r1", "event_type": "human_message", "category": "message"},
|
||||
{"thread_id": "t1", "run_id": "r1", "event_type": "ai_message", "category": "message"},
|
||||
]
|
||||
results = await store.put_batch(events)
|
||||
assert results[0]["seq"] == 1
|
||||
assert results[1]["seq"] == 2
|
||||
|
||||
|
||||
# -- delete --
|
||||
|
||||
|
||||
class TestDelete:
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_thread(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
|
||||
count = await store.delete_by_thread("t1")
|
||||
assert count == 3
|
||||
assert await store.list_messages("t1") == []
|
||||
assert await store.count_messages("t1") == 0
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_run(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
|
||||
count = await store.delete_by_run("t1", "r2")
|
||||
assert count == 2
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["run_id"] == "r1"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent_thread_returns_zero(self, store):
|
||||
assert await store.delete_by_thread("nope") == 0
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent_run_returns_zero(self, store):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
assert await store.delete_by_run("t1", "nope") == 0
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent_thread_for_run_returns_zero(self, store):
|
||||
assert await store.delete_by_run("nope", "r1") == 0
|
||||
|
||||
|
||||
# -- Edge cases --
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_thread_list_messages(self, store):
|
||||
assert await store.list_messages("empty") == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_run_list_events(self, store):
|
||||
assert await store.list_events("empty", "r1") == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_thread_count_messages(self, store):
|
||||
assert await store.count_messages("empty") == 0
|
||||
|
||||
|
||||
# -- DB-specific tests --
|
||||
|
||||
|
||||
class TestDbRunEventStore:
|
||||
"""Tests for DbRunEventStore with temp SQLite."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_basic_crud(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
|
||||
assert r["seq"] == 1
|
||||
r2 = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="hello")
|
||||
assert r2["seq"] == 2
|
||||
|
||||
messages = await s.list_messages("t1")
|
||||
assert len(messages) == 2
|
||||
|
||||
count = await s.count_messages("t1")
|
||||
assert count == 2
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_trace_content_truncation(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory(), max_trace_content=100)
|
||||
|
||||
long = "x" * 200
|
||||
r = await s.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", content=long)
|
||||
assert len(r["content"]) == 100
|
||||
assert r["metadata"].get("content_truncated") is True
|
||||
|
||||
# message content NOT truncated
|
||||
m = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=long)
|
||||
assert len(m["content"]) == 200
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_pagination(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
for i in range(10):
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
|
||||
|
||||
# before_seq
|
||||
msgs = await s.list_messages("t1", before_seq=6, limit=3)
|
||||
assert [m["seq"] for m in msgs] == [3, 4, 5]
|
||||
|
||||
# after_seq
|
||||
msgs = await s.list_messages("t1", after_seq=7, limit=3)
|
||||
assert [m["seq"] for m in msgs] == [8, 9, 10]
|
||||
|
||||
# default (latest)
|
||||
msgs = await s.list_messages("t1", limit=3)
|
||||
assert [m["seq"] for m in msgs] == [8, 9, 10]
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await s.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message")
|
||||
c = await s.delete_by_run("t1", "r2")
|
||||
assert c == 1
|
||||
assert await s.count_messages("t1") == 1
|
||||
|
||||
c = await s.delete_by_thread("t1")
|
||||
assert c == 1
|
||||
assert await s.count_messages("t1") == 0
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_put_batch_seq_continuity(self, tmp_path):
|
||||
"""Batch write produces continuous seq values with no gaps."""
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
events = [{"thread_id": "t1", "run_id": "r1", "event_type": "trace", "category": "trace"} for _ in range(50)]
|
||||
results = await s.put_batch(events)
|
||||
seqs = [r["seq"] for r in results]
|
||||
assert seqs == list(range(1, 51))
|
||||
await close_engine()
|
||||
|
||||
|
||||
# -- Factory tests --
|
||||
|
||||
|
||||
class TestMakeRunEventStore:
|
||||
"""Tests for the make_run_event_store factory function."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_memory_backend_default(self):
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
store = make_run_event_store(None)
|
||||
assert type(store).__name__ == "MemoryRunEventStore"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_memory_backend_explicit(self):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
config = MagicMock()
|
||||
config.backend = "memory"
|
||||
store = make_run_event_store(config)
|
||||
assert type(store).__name__ == "MemoryRunEventStore"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_db_backend_with_engine(self, tmp_path):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.persistence.engine import close_engine, init_engine
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
|
||||
config = MagicMock()
|
||||
config.backend = "db"
|
||||
config.max_trace_content = 10240
|
||||
store = make_run_event_store(config)
|
||||
assert type(store).__name__ == "DbRunEventStore"
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_db_backend_no_engine_falls_back(self):
|
||||
"""db backend without engine falls back to memory."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.persistence.engine import close_engine, init_engine
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
await init_engine("memory") # no engine created
|
||||
|
||||
config = MagicMock()
|
||||
config.backend = "db"
|
||||
store = make_run_event_store(config)
|
||||
assert type(store).__name__ == "MemoryRunEventStore"
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_jsonl_backend(self):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
config = MagicMock()
|
||||
config.backend = "jsonl"
|
||||
store = make_run_event_store(config)
|
||||
assert type(store).__name__ == "JsonlRunEventStore"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_unknown_backend_raises(self):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
config = MagicMock()
|
||||
config.backend = "redis"
|
||||
with pytest.raises(ValueError, match="Unknown"):
|
||||
make_run_event_store(config)
|
||||
|
||||
|
||||
# -- JSONL-specific tests --
|
||||
|
||||
|
||||
class TestJsonlRunEventStore:
|
||||
@pytest.mark.anyio
|
||||
async def test_basic_crud(self, tmp_path):
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
|
||||
r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
|
||||
assert r["seq"] == 1
|
||||
messages = await s.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_file_at_correct_path(self, tmp_path):
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
assert (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r1.jsonl").exists()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cross_run_messages(self, tmp_path):
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
|
||||
messages = await s.list_messages("t1")
|
||||
assert len(messages) == 2
|
||||
assert [m["seq"] for m in messages] == [1, 2]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_run(self, tmp_path):
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
|
||||
c = await s.delete_by_run("t1", "r2")
|
||||
assert c == 1
|
||||
assert not (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r2.jsonl").exists()
|
||||
assert await s.count_messages("t1") == 1
|
||||
@ -1,107 +0,0 @@
|
||||
"""Tests for paginated list_messages_by_run across all RunEventStore backends."""
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_store():
|
||||
return MemoryRunEventStore()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_default_returns_all(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
for i in range(3):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-b",
|
||||
event_type="human_message", category="message", content=f"msg-b-{i}",
|
||||
)
|
||||
await store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace")
|
||||
|
||||
msgs = await store.list_messages_by_run("t1", "run-a")
|
||||
assert len(msgs) == 7
|
||||
assert all(m["category"] == "message" for m in msgs)
|
||||
assert all(m["run_id"] == "run-a" for m in msgs)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_with_limit(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
msgs = await store.list_messages_by_run("t1", "run-a", limit=3)
|
||||
assert len(msgs) == 3
|
||||
seqs = [m["seq"] for m in msgs]
|
||||
assert seqs == sorted(seqs)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_after_seq(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
||||
cursor_seq = all_msgs[2]["seq"]
|
||||
msgs = await store.list_messages_by_run("t1", "run-a", after_seq=cursor_seq, limit=50)
|
||||
assert all(m["seq"] > cursor_seq for m in msgs)
|
||||
assert len(msgs) == 4
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_before_seq(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
||||
cursor_seq = all_msgs[4]["seq"]
|
||||
msgs = await store.list_messages_by_run("t1", "run-a", before_seq=cursor_seq, limit=50)
|
||||
assert all(m["seq"] < cursor_seq for m in msgs)
|
||||
assert len(msgs) == 4
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_does_not_include_other_run(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message", category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
for i in range(3):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-b",
|
||||
event_type="human_message", category="message", content=f"msg-b-{i}",
|
||||
)
|
||||
|
||||
msgs = await store.list_messages_by_run("t1", "run-b")
|
||||
assert len(msgs) == 3
|
||||
assert all(m["run_id"] == "run-b" for m in msgs)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_empty_run(base_store):
|
||||
store = base_store
|
||||
msgs = await store.list_messages_by_run("t1", "nonexistent")
|
||||
assert msgs == []
|
||||
@ -1,385 +0,0 @@
|
||||
"""Tests for RunJournal callback handler.
|
||||
|
||||
Uses MemoryRunEventStore as the backend for direct event inspection.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def journal_setup():
|
||||
store = MemoryRunEventStore()
|
||||
j = RunJournal("r1", "t1", store, flush_threshold=100)
|
||||
return j, store
|
||||
|
||||
|
||||
def _make_llm_response(content="Hello", usage=None, tool_calls=None, additional_kwargs=None):
|
||||
"""Create a mock LLM response with a message.
|
||||
|
||||
model_dump() returns checkpoint-aligned format matching real AIMessage.
|
||||
"""
|
||||
msg = MagicMock()
|
||||
msg.type = "ai"
|
||||
msg.content = content
|
||||
msg.id = f"msg-{id(msg)}"
|
||||
msg.tool_calls = tool_calls or []
|
||||
msg.invalid_tool_calls = []
|
||||
msg.response_metadata = {"model_name": "test-model"}
|
||||
msg.usage_metadata = usage
|
||||
msg.additional_kwargs = additional_kwargs or {}
|
||||
msg.name = None
|
||||
# model_dump returns checkpoint-aligned format
|
||||
msg.model_dump.return_value = {
|
||||
"content": content,
|
||||
"additional_kwargs": additional_kwargs or {},
|
||||
"response_metadata": {"model_name": "test-model"},
|
||||
"type": "ai",
|
||||
"name": None,
|
||||
"id": msg.id,
|
||||
"tool_calls": tool_calls or [],
|
||||
"invalid_tool_calls": [],
|
||||
"usage_metadata": usage,
|
||||
}
|
||||
|
||||
gen = MagicMock()
|
||||
gen.message = msg
|
||||
|
||||
response = MagicMock()
|
||||
response.generations = [[gen]]
|
||||
return response
|
||||
|
||||
|
||||
class TestLlmCallbacks:
|
||||
@pytest.mark.anyio
|
||||
async def test_on_llm_end_produces_trace_event(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
trace_events = [e for e in events if e["event_type"] == "llm.ai.response"]
|
||||
assert len(trace_events) == 1
|
||||
assert trace_events[0]["category"] == "message"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "llm.ai.response"
|
||||
# Content is checkpoint-aligned model_dump format
|
||||
assert messages[0]["content"]["type"] == "ai"
|
||||
assert messages[0]["content"]["content"] == "Answer"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup):
|
||||
"""LLM response with pending tool_calls emits llm.ai.response with tool_calls in content."""
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_llm_end(
|
||||
_make_llm_response("Let me search", tool_calls=[{"id": "call_1", "name": "search", "args": {}}]),
|
||||
run_id=run_id,
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "llm.ai.response"
|
||||
assert len(messages[0]["content"]["tool_calls"]) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_on_llm_end_subagent_no_ai_message(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"])
|
||||
j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, parent_run_id=None, tags=["subagent:research"])
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
# subagent responses still emit llm.ai.response with category="message"
|
||||
assert len(messages) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_accumulation(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
|
||||
j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
assert j._total_input_tokens == 30
|
||||
assert j._total_output_tokens == 15
|
||||
assert j._total_tokens == 45
|
||||
assert j._llm_call_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_total_tokens_computed_from_input_output(self, journal_setup):
|
||||
"""If total_tokens is 0, it should be computed from input + output."""
|
||||
j, store = journal_setup
|
||||
j.on_llm_end(
|
||||
_make_llm_response("Hi", usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
assert j._total_tokens == 150
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_caller_token_classification(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
|
||||
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:summarization"])
|
||||
# token tracking not broken by caller type
|
||||
assert j._total_tokens == 45
|
||||
assert j._llm_call_count == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_usage_metadata_none_no_crash(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
await j.flush()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_latency_tracking(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
llm_resp = [e for e in events if e["event_type"] == "llm.ai.response"][0]
|
||||
assert "latency_ms" in llm_resp["metadata"]
|
||||
assert llm_resp["metadata"]["latency_ms"] is not None
|
||||
|
||||
|
||||
class TestLifecycleCallbacks:
|
||||
@pytest.mark.anyio
|
||||
async def test_chain_start_end_produce_trace_events(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
|
||||
j.on_chain_end({}, run_id=uuid4())
|
||||
await asyncio.sleep(0.05)
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
types = {e["event_type"] for e in events}
|
||||
assert "run.start" in types
|
||||
assert "run.end" in types
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_nested_chain_no_run_start(self, journal_setup):
|
||||
"""Nested chains (parent_run_id set) should NOT produce run.start."""
|
||||
j, store = journal_setup
|
||||
parent_id = uuid4()
|
||||
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
|
||||
j.on_chain_end({}, run_id=uuid4())
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
assert not any(e["event_type"] == "run.start" for e in events)
|
||||
|
||||
|
||||
class TestToolCallbacks:
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_end_with_tool_message(self, journal_setup):
|
||||
"""on_tool_end with a ToolMessage stores it as llm.tool.result."""
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
j, store = journal_setup
|
||||
tool_msg = ToolMessage(content="results", tool_call_id="call_1", name="web_search")
|
||||
j.on_tool_end(tool_msg, run_id=uuid4())
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "llm.tool.result"
|
||||
assert messages[0]["content"]["type"] == "tool"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_end_with_command_unwraps_tool_message(self, journal_setup):
|
||||
"""on_tool_end with Command(update={'messages':[ToolMessage]}) unwraps inner message."""
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
j, store = journal_setup
|
||||
inner = ToolMessage(content="file list", tool_call_id="call_2", name="present_files")
|
||||
cmd = Command(update={"messages": [inner]})
|
||||
j.on_tool_end(cmd, run_id=uuid4())
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "llm.tool.result"
|
||||
assert messages[0]["content"]["content"] == "file list"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_on_tool_error_no_crash(self, journal_setup):
|
||||
"""on_tool_error should not crash (no event emitted by default)."""
|
||||
j, store = journal_setup
|
||||
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
|
||||
await j.flush()
|
||||
# Base implementation does not emit tool_error — just verify no crash
|
||||
events = await store.list_events("t1", "r1")
|
||||
assert isinstance(events, list)
|
||||
|
||||
|
||||
class TestCustomEvents:
|
||||
@pytest.mark.anyio
|
||||
async def test_on_custom_event_not_implemented(self, journal_setup):
|
||||
"""RunJournal does not implement on_custom_event — no crash expected."""
|
||||
j, store = journal_setup
|
||||
# BaseCallbackHandler.on_custom_event is a no-op by default
|
||||
j.on_custom_event("task_running", {"task_id": "t1"}, run_id=uuid4())
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
assert isinstance(events, list)
|
||||
|
||||
|
||||
class TestBufferFlush:
|
||||
@pytest.mark.anyio
|
||||
async def test_flush_threshold(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
j._flush_threshold = 2
|
||||
# Each on_llm_end emits 1 event
|
||||
j.on_llm_end(_make_llm_response("A"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
assert len(j._buffer) == 1
|
||||
j.on_llm_end(_make_llm_response("B"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
# At threshold the buffer should have been flushed asynchronously
|
||||
await asyncio.sleep(0.1)
|
||||
events = await store.list_events("t1", "r1")
|
||||
assert len(events) >= 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_events_retained_when_no_loop(self, journal_setup):
|
||||
"""Events buffered in a sync (no-loop) context should survive
|
||||
until the async flush() in the finally block."""
|
||||
j, store = journal_setup
|
||||
j._flush_threshold = 1
|
||||
|
||||
original = asyncio.get_running_loop
|
||||
|
||||
def no_loop():
|
||||
raise RuntimeError("no running event loop")
|
||||
|
||||
asyncio.get_running_loop = no_loop
|
||||
try:
|
||||
j._put(event_type="llm.ai.response", category="message", content="test")
|
||||
finally:
|
||||
asyncio.get_running_loop = original
|
||||
|
||||
assert len(j._buffer) == 1
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
assert any(e["event_type"] == "llm.ai.response" for e in events)
|
||||
|
||||
|
||||
class TestIdentifyCaller:
|
||||
def test_lead_agent_tag(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
assert j._identify_caller(["lead_agent"]) == "lead_agent"
|
||||
|
||||
def test_subagent_tag(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
assert j._identify_caller(["subagent:research"]) == "subagent:research"
|
||||
|
||||
def test_middleware_tag(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
assert j._identify_caller(["middleware:summarization"]) == "middleware:summarization"
|
||||
|
||||
def test_no_tags_returns_lead_agent(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
assert j._identify_caller([]) == "lead_agent"
|
||||
assert j._identify_caller(None) == "lead_agent"
|
||||
|
||||
|
||||
class TestChainErrorCallback:
|
||||
@pytest.mark.anyio
|
||||
async def test_on_chain_error_writes_run_error(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
j.on_chain_error(ValueError("boom"), run_id=uuid4())
|
||||
await asyncio.sleep(0.05)
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
error_events = [e for e in events if e["event_type"] == "run.error"]
|
||||
assert len(error_events) == 1
|
||||
assert "boom" in error_events[0]["content"]
|
||||
assert error_events[0]["metadata"]["error_type"] == "ValueError"
|
||||
|
||||
|
||||
class TestTokenTrackingDisabled:
|
||||
@pytest.mark.anyio
|
||||
async def test_track_token_usage_false(self):
|
||||
store = MemoryRunEventStore()
|
||||
j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
|
||||
j.on_llm_end(
|
||||
_make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
data = j.get_completion_data()
|
||||
assert data["total_tokens"] == 0
|
||||
assert data["llm_call_count"] == 0
|
||||
|
||||
|
||||
class TestConvenienceFields:
|
||||
@pytest.mark.anyio
|
||||
async def test_first_human_message_via_set(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
j.set_first_human_message("What is AI?")
|
||||
data = j.get_completion_data()
|
||||
assert data["first_human_message"] == "What is AI?"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_completion_data(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
j._total_tokens = 100
|
||||
j._msg_count = 5
|
||||
data = j.get_completion_data()
|
||||
assert data["total_tokens"] == 100
|
||||
assert data["message_count"] == 5
|
||||
|
||||
|
||||
class TestMiddlewareEvents:
|
||||
@pytest.mark.anyio
|
||||
async def test_record_middleware_uses_middleware_category(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
j.record_middleware(
|
||||
"title",
|
||||
name="TitleMiddleware",
|
||||
hook="after_model",
|
||||
action="generate_title",
|
||||
changes={"title": "Test Title", "thread_id": "t1"},
|
||||
)
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
mw_events = [e for e in events if e["event_type"] == "middleware:title"]
|
||||
assert len(mw_events) == 1
|
||||
assert mw_events[0]["category"] == "middleware"
|
||||
assert mw_events[0]["content"]["name"] == "TitleMiddleware"
|
||||
assert mw_events[0]["content"]["hook"] == "after_model"
|
||||
assert mw_events[0]["content"]["action"] == "generate_title"
|
||||
assert mw_events[0]["content"]["changes"]["title"] == "Test Title"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_middleware_tag_variants(self, journal_setup):
|
||||
"""Different middleware tags produce distinct event_types."""
|
||||
j, store = journal_setup
|
||||
j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={})
|
||||
j.record_middleware("guardrail", name="GuardrailMiddleware", hook="before_tool", action="deny", changes={})
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
event_types = {e["event_type"] for e in events}
|
||||
assert "middleware:title" in event_types
|
||||
assert "middleware:guardrail" in event_types
|
||||
|
||||
|
||||
@ -1,196 +0,0 @@
|
||||
"""Tests for RunRepository (SQLAlchemy-backed RunStore).
|
||||
|
||||
Uses a temp SQLite DB to test ORM-backed CRUD operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.persistence.run import RunRepository
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
return RunRepository(get_session_factory())
|
||||
|
||||
|
||||
async def _cleanup():
|
||||
from deerflow.persistence.engine import close_engine
|
||||
|
||||
await close_engine()
|
||||
|
||||
|
||||
class TestRunRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_put_and_get(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="pending")
|
||||
row = await repo.get("r1")
|
||||
assert row is not None
|
||||
assert row["run_id"] == "r1"
|
||||
assert row["thread_id"] == "t1"
|
||||
assert row["status"] == "pending"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_missing_returns_none(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
assert await repo.get("nope") is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_status(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1")
|
||||
await repo.update_status("r1", "running")
|
||||
row = await repo.get("r1")
|
||||
assert row["status"] == "running"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_status_with_error(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1")
|
||||
await repo.update_status("r1", "error", error="boom")
|
||||
row = await repo.get("r1")
|
||||
assert row["status"] == "error"
|
||||
assert row["error"] == "boom"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1")
|
||||
await repo.put("r2", thread_id="t1")
|
||||
await repo.put("r3", thread_id="t2")
|
||||
rows = await repo.list_by_thread("t1")
|
||||
assert len(rows) == 2
|
||||
assert all(r["thread_id"] == "t1" for r in rows)
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_owner_filter(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await repo.list_by_thread("t1", user_id="alice")
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["user_id"] == "alice"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1")
|
||||
await repo.delete("r1")
|
||||
assert await repo.get("r1") is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent_is_noop(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.delete("nope") # should not raise
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_pending(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="pending")
|
||||
await repo.put("r2", thread_id="t1", status="running")
|
||||
await repo.put("r3", thread_id="t2", status="pending")
|
||||
pending = await repo.list_pending()
|
||||
assert len(pending) == 2
|
||||
assert all(r["status"] == "pending" for r in pending)
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_run_completion(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="running")
|
||||
await repo.update_run_completion(
|
||||
"r1",
|
||||
status="success",
|
||||
total_input_tokens=100,
|
||||
total_output_tokens=50,
|
||||
total_tokens=150,
|
||||
llm_call_count=2,
|
||||
lead_agent_tokens=120,
|
||||
subagent_tokens=20,
|
||||
middleware_tokens=10,
|
||||
message_count=3,
|
||||
last_ai_message="The answer is 42",
|
||||
first_human_message="What is the meaning?",
|
||||
)
|
||||
row = await repo.get("r1")
|
||||
assert row["status"] == "success"
|
||||
assert row["total_tokens"] == 150
|
||||
assert row["llm_call_count"] == 2
|
||||
assert row["lead_agent_tokens"] == 120
|
||||
assert row["message_count"] == 3
|
||||
assert row["last_ai_message"] == "The answer is 42"
|
||||
assert row["first_human_message"] == "What is the meaning?"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_metadata_preserved(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", metadata={"key": "value"})
|
||||
row = await repo.get("r1")
|
||||
assert row["metadata"] == {"key": "value"}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_kwargs_with_non_serializable(self, tmp_path):
|
||||
"""kwargs containing non-JSON-serializable objects should be safely handled."""
|
||||
repo = await _make_repo(tmp_path)
|
||||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
await repo.put("r1", thread_id="t1", kwargs={"obj": Dummy()})
|
||||
row = await repo.get("r1")
|
||||
assert "obj" in row["kwargs"]
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_run_completion_preserves_existing_fields(self, tmp_path):
|
||||
"""update_run_completion does not overwrite thread_id or assistant_id."""
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", assistant_id="agent1", status="running")
|
||||
await repo.update_run_completion("r1", status="success", total_tokens=100)
|
||||
row = await repo.get("r1")
|
||||
assert row["thread_id"] == "t1"
|
||||
assert row["assistant_id"] == "agent1"
|
||||
assert row["total_tokens"] == 100
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_ordered_desc(self, tmp_path):
|
||||
"""list_by_thread returns newest first."""
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", created_at="2024-01-01T00:00:00+00:00")
|
||||
await repo.put("r2", thread_id="t1", created_at="2024-01-02T00:00:00+00:00")
|
||||
rows = await repo.list_by_thread("t1")
|
||||
assert rows[0]["run_id"] == "r2"
|
||||
assert rows[1]["run_id"] == "r1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_limit(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
for i in range(5):
|
||||
await repo.put(f"r{i}", thread_id="t1")
|
||||
rows = await repo.list_by_thread("t1", limit=2)
|
||||
assert len(rows) == 2
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_owner_none_returns_all(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await repo.list_by_thread("t1", user_id=None)
|
||||
assert len(rows) == 2
|
||||
await _cleanup()
|
||||
@ -1,214 +0,0 @@
|
||||
from unittest.mock import AsyncMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint
|
||||
|
||||
|
||||
class FakeCheckpointer:
|
||||
def __init__(self, *, put_result):
|
||||
self.adelete_thread = AsyncMock()
|
||||
self.aput = AsyncMock(return_value=put_result)
|
||||
self.aput_writes = AsyncMock()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_restores_snapshot_without_deleting_thread():
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {
|
||||
"id": "ckpt-1",
|
||||
"channel_versions": {"messages": 3},
|
||||
"channel_values": {"messages": ["before"]},
|
||||
},
|
||||
"metadata": {"source": "input"},
|
||||
"pending_writes": [
|
||||
("task-a", "messages", {"content": "first"}),
|
||||
("task-a", "status", "done"),
|
||||
("task-b", "events", {"type": "tool"}),
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_not_awaited()
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{
|
||||
"id": "ckpt-1",
|
||||
"channel_versions": {"messages": 3},
|
||||
"channel_values": {"messages": ["before"]},
|
||||
},
|
||||
{"source": "input"},
|
||||
{"messages": 3},
|
||||
)
|
||||
assert checkpointer.aput_writes.await_args_list == [
|
||||
call(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
||||
[("messages", {"content": "first"}), ("status", "done")],
|
||||
task_id="task-a",
|
||||
),
|
||||
call(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
||||
[("events", {"type": "tool"})],
|
||||
task_id="task-b",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_deletes_thread_when_no_snapshot_exists():
|
||||
checkpointer = FakeCheckpointer(put_result=None)
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id=None,
|
||||
pre_run_snapshot=None,
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_awaited_once_with("thread-1")
|
||||
checkpointer.aput.assert_not_awaited()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_raises_when_restore_config_has_no_checkpoint_id():
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}})
|
||||
|
||||
with pytest.raises(RuntimeError, match="did not return checkpoint_id"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [("task-a", "messages", "value")],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_not_awaited()
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": None,
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{"id": "ckpt-1", "channel_versions": {}},
|
||||
{},
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_raises_on_malformed_pending_write_not_a_tuple():
|
||||
"""pending_writes containing a non-3-tuple item should raise RuntimeError."""
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
with pytest.raises(RuntimeError, match="rollback failed: pending_write is not a 3-tuple"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [
|
||||
("task-a", "messages", "valid"), # valid
|
||||
["only", "two"], # malformed: only 2 elements
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
# aput succeeded but aput_writes should not be called due to malformed data
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_raises_on_malformed_pending_write_non_string_channel():
|
||||
"""pending_writes containing a non-string channel should raise RuntimeError."""
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
with pytest.raises(RuntimeError, match="rollback failed: pending_write has non-string channel"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [
|
||||
("task-a", 123, "value"), # malformed: channel is not a string
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_propagates_aput_writes_failure():
|
||||
"""If aput_writes fails, the exception should propagate (not be swallowed)."""
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
# Simulate aput_writes failure
|
||||
checkpointer.aput_writes.side_effect = RuntimeError("Database connection lost")
|
||||
|
||||
with pytest.raises(RuntimeError, match="Database connection lost"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [
|
||||
("task-a", "messages", "value"),
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
# aput succeeded, aput_writes was called but failed
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_awaited_once()
|
||||
@ -1,243 +0,0 @@
|
||||
"""Tests for GET /api/runs/{run_id}/messages and GET /api/runs/{run_id}/feedback endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import runs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app(run_store=None, event_store=None, feedback_repo=None):
|
||||
"""Build a test FastAPI app with stub auth and mocked state."""
|
||||
app = make_authed_test_app()
|
||||
app.include_router(runs.router)
|
||||
|
||||
if run_store is not None:
|
||||
app.state.run_store = run_store
|
||||
if event_store is not None:
|
||||
app.state.run_event_store = event_store
|
||||
if feedback_repo is not None:
|
||||
app.state.feedback_repo = feedback_repo
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _make_run_store(run_record: dict | None):
|
||||
"""Return an AsyncMock run store whose get() returns run_record."""
|
||||
store = MagicMock()
|
||||
store.get = AsyncMock(return_value=run_record)
|
||||
return store
|
||||
|
||||
|
||||
def _make_event_store(rows: list[dict]):
|
||||
"""Return an AsyncMock event store whose list_messages_by_run() returns rows."""
|
||||
store = MagicMock()
|
||||
store.list_messages_by_run = AsyncMock(return_value=rows)
|
||||
return store
|
||||
|
||||
|
||||
def _make_message(seq: int) -> dict:
|
||||
return {"seq": seq, "event_type": "on_chat_model_stream", "category": "message", "content": f"msg-{seq}"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_messages_returns_envelope():
|
||||
"""GET /api/runs/{run_id}/messages returns {data: [...], has_more: bool}."""
|
||||
rows = [_make_message(i) for i in range(1, 4)]
|
||||
run_record = {"run_id": "run-1", "thread_id": "thread-1"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=_make_event_store(rows),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-1/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "data" in body
|
||||
assert "has_more" in body
|
||||
assert body["has_more"] is False
|
||||
assert len(body["data"]) == 3
|
||||
|
||||
|
||||
def test_run_messages_404_when_run_not_found():
|
||||
"""Returns 404 when the run store returns None."""
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(None),
|
||||
event_store=_make_event_store([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/missing-run/messages")
|
||||
assert response.status_code == 404
|
||||
assert "missing-run" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_run_messages_has_more_true_when_extra_row_returned():
|
||||
"""has_more=True when event store returns limit+1 rows."""
|
||||
# Default limit is 50; provide 51 rows
|
||||
rows = [_make_message(i) for i in range(1, 52)] # 51 rows
|
||||
run_record = {"run_id": "run-2", "thread_id": "thread-2"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=_make_event_store(rows),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-2/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["has_more"] is True
|
||||
assert len(body["data"]) == 50 # trimmed to limit
|
||||
|
||||
|
||||
def test_run_messages_passes_after_seq_to_event_store():
|
||||
"""after_seq query param is forwarded to event_store.list_messages_by_run."""
|
||||
rows = [_make_message(10)]
|
||||
run_record = {"run_id": "run-3", "thread_id": "thread-3"}
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=event_store,
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-3/messages?after_seq=5")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-3", "run-3",
|
||||
limit=51, # default limit(50) + 1
|
||||
before_seq=None,
|
||||
after_seq=5,
|
||||
)
|
||||
|
||||
|
||||
def test_run_messages_respects_custom_limit():
|
||||
"""Custom limit is respected and capped at 200."""
|
||||
rows = [_make_message(i) for i in range(1, 6)]
|
||||
run_record = {"run_id": "run-4", "thread_id": "thread-4"}
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=event_store,
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-4/messages?limit=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-4", "run-4",
|
||||
limit=11, # 10 + 1
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
)
|
||||
|
||||
|
||||
def test_run_messages_passes_before_seq_to_event_store():
|
||||
"""before_seq query param is forwarded to event_store.list_messages_by_run."""
|
||||
rows = [_make_message(3)]
|
||||
run_record = {"run_id": "run-5", "thread_id": "thread-5"}
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=event_store,
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-5/messages?before_seq=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-5", "run-5",
|
||||
limit=51,
|
||||
before_seq=10,
|
||||
after_seq=None,
|
||||
)
|
||||
|
||||
|
||||
def test_run_messages_empty_data():
|
||||
"""Returns empty data list when no messages exist."""
|
||||
run_record = {"run_id": "run-6", "thread_id": "thread-6"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=_make_event_store([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-6/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["data"] == []
|
||||
assert body["has_more"] is False
|
||||
|
||||
|
||||
def _make_feedback_repo(rows: list[dict]):
|
||||
"""Return an AsyncMock feedback repo whose list_by_run() returns rows."""
|
||||
repo = MagicMock()
|
||||
repo.list_by_run = AsyncMock(return_value=rows)
|
||||
return repo
|
||||
|
||||
|
||||
def _make_feedback(run_id: str, idx: int) -> dict:
|
||||
return {"id": f"fb-{idx}", "run_id": run_id, "thread_id": "thread-x", "value": "up"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestRunFeedback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunFeedback:
|
||||
def test_returns_list_of_feedback_dicts(self):
|
||||
"""GET /api/runs/{run_id}/feedback returns a list of feedback dicts."""
|
||||
run_record = {"run_id": "run-fb-1", "thread_id": "thread-fb-1"}
|
||||
rows = [_make_feedback("run-fb-1", i) for i in range(3)]
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
feedback_repo=_make_feedback_repo(rows),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-fb-1/feedback")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert isinstance(body, list)
|
||||
assert len(body) == 3
|
||||
|
||||
def test_404_when_run_not_found(self):
|
||||
"""Returns 404 when run store returns None."""
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(None),
|
||||
feedback_repo=_make_feedback_repo([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/missing-run/feedback")
|
||||
assert response.status_code == 404
|
||||
assert "missing-run" in response.json()["detail"]
|
||||
|
||||
def test_empty_list_when_no_feedback(self):
|
||||
"""Returns empty list when no feedback exists for the run."""
|
||||
run_record = {"run_id": "run-fb-2", "thread_id": "thread-fb-2"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
feedback_repo=_make_feedback_repo([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-fb-2/feedback")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
def test_503_when_feedback_repo_not_configured(self):
|
||||
"""Returns 503 when feedback_repo is None (no DB configured)."""
|
||||
run_record = {"run_id": "run-fb-3", "thread_id": "thread-fb-3"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
)
|
||||
# Explicitly set feedback_repo to None to simulate missing DB
|
||||
app.state.feedback_repo = None
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-fb-3/feedback")
|
||||
assert response.status_code == 503
|
||||
@ -1,178 +0,0 @@
|
||||
"""Tests for ThreadMetaRepository (SQLAlchemy-backed)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
return ThreadMetaRepository(get_session_factory())
|
||||
|
||||
|
||||
async def _cleanup():
|
||||
from deerflow.persistence.engine import close_engine
|
||||
|
||||
await close_engine()
|
||||
|
||||
|
||||
class TestThreadMetaRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_create_and_get(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
record = await repo.create("t1")
|
||||
assert record["thread_id"] == "t1"
|
||||
assert record["status"] == "idle"
|
||||
assert "created_at" in record
|
||||
|
||||
fetched = await repo.get("t1")
|
||||
assert fetched is not None
|
||||
assert fetched["thread_id"] == "t1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_assistant_id(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
record = await repo.create("t1", assistant_id="agent1")
|
||||
assert record["assistant_id"] == "agent1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_owner_and_display_name(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
record = await repo.create("t1", user_id="user1", display_name="My Thread")
|
||||
assert record["user_id"] == "user1"
|
||||
assert record["display_name"] == "My Thread"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_metadata(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
record = await repo.create("t1", metadata={"key": "value"})
|
||||
assert record["metadata"] == {"key": "value"}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_nonexistent(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
assert await repo.get("nonexistent") is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_no_record_allows(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
assert await repo.check_access("unknown", "user1") is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_owner_matches(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user1") is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_owner_mismatch(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user2") is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_no_owner_allows_all(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
# Explicit user_id=None to bypass the new AUTO default that
|
||||
# would otherwise pick up the test user from the autouse fixture.
|
||||
await repo.create("t1", user_id=None)
|
||||
assert await repo.check_access("t1", "anyone") is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_missing_row_denied(self, tmp_path):
|
||||
"""require_existing=True flips the missing-row case to *denied*.
|
||||
|
||||
Closes the delete-idempotence cross-user gap: after a thread is
|
||||
deleted, the row is gone, and the permissive default would let any
|
||||
caller "claim" it as untracked. The strict mode demands a row.
|
||||
"""
|
||||
repo = await _make_repo(tmp_path)
|
||||
assert await repo.check_access("never-existed", "user1", require_existing=True) is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_owner_match_allowed(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user1", require_existing=True) is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", user_id="user1")
|
||||
assert await repo.check_access("t1", "user2", require_existing=True) is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
|
||||
"""Even in strict mode, a row with NULL user_id stays shared.
|
||||
|
||||
The strict flag tightens the *missing row* case, not the *shared
|
||||
row* case — legacy pre-auth rows that survived a clean migration
|
||||
without an owner are still everyone's.
|
||||
"""
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", user_id=None)
|
||||
assert await repo.check_access("t1", "anyone", require_existing=True) is True
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_status(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1")
|
||||
await repo.update_status("t1", "busy")
|
||||
record = await repo.get("t1")
|
||||
assert record["status"] == "busy"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1")
|
||||
await repo.delete("t1")
|
||||
assert await repo.get("t1") is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent_is_noop(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.delete("nonexistent") # should not raise
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_metadata_merges(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1", metadata={"a": 1, "b": 2})
|
||||
await repo.update_metadata("t1", {"b": 99, "c": 3})
|
||||
record = await repo.get("t1")
|
||||
# Existing key preserved, overlapping key overwritten, new key added
|
||||
assert record["metadata"] == {"a": 1, "b": 99, "c": 3}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_metadata_on_empty(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.create("t1")
|
||||
await repo.update_metadata("t1", {"k": "v"})
|
||||
record = await repo.get("t1")
|
||||
assert record["metadata"] == {"k": "v"}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_metadata_nonexistent_is_noop(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
|
||||
await _cleanup()
|
||||
@ -1,110 +0,0 @@
|
||||
"""Tests for runtime.user_context — contextvar three-state semantics.
|
||||
|
||||
These tests opt out of the autouse contextvar fixture (added in
|
||||
commit 6) because they explicitly test the cases where the contextvar
|
||||
is set or unset.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.user_context import (
|
||||
CurrentUser,
|
||||
DEFAULT_USER_ID,
|
||||
get_current_user,
|
||||
get_effective_user_id,
|
||||
require_current_user,
|
||||
reset_current_user,
|
||||
set_current_user,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_default_is_none():
|
||||
"""Before any set, contextvar returns None."""
|
||||
assert get_current_user() is None
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_set_and_reset_roundtrip():
|
||||
"""set_current_user returns a token that reset restores."""
|
||||
user = SimpleNamespace(id="user-1")
|
||||
token = set_current_user(user)
|
||||
try:
|
||||
assert get_current_user() is user
|
||||
finally:
|
||||
reset_current_user(token)
|
||||
assert get_current_user() is None
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_require_current_user_raises_when_unset():
|
||||
"""require_current_user raises RuntimeError if contextvar is unset."""
|
||||
assert get_current_user() is None
|
||||
with pytest.raises(RuntimeError, match="without user context"):
|
||||
require_current_user()
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_require_current_user_returns_user_when_set():
|
||||
"""require_current_user returns the user when contextvar is set."""
|
||||
user = SimpleNamespace(id="user-2")
|
||||
token = set_current_user(user)
|
||||
try:
|
||||
assert require_current_user() is user
|
||||
finally:
|
||||
reset_current_user(token)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_protocol_accepts_duck_typed():
|
||||
"""CurrentUser is a runtime_checkable Protocol matching any .id-bearing object."""
|
||||
user = SimpleNamespace(id="user-3")
|
||||
assert isinstance(user, CurrentUser)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_protocol_rejects_no_id():
|
||||
"""Objects without .id do not satisfy CurrentUser Protocol."""
|
||||
not_a_user = SimpleNamespace(email="no-id@example.com")
|
||||
assert not isinstance(not_a_user, CurrentUser)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_user_id / DEFAULT_USER_ID tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_default_user_id_is_default():
|
||||
assert DEFAULT_USER_ID == "default"
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_effective_user_id_returns_default_when_no_user():
|
||||
"""No user in context -> fallback to DEFAULT_USER_ID."""
|
||||
assert get_effective_user_id() == "default"
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_effective_user_id_returns_user_id_when_set():
|
||||
user = SimpleNamespace(id="u-abc-123")
|
||||
token = set_current_user(user)
|
||||
try:
|
||||
assert get_effective_user_id() == "u-abc-123"
|
||||
finally:
|
||||
reset_current_user(token)
|
||||
|
||||
|
||||
@pytest.mark.no_auto_user
|
||||
def test_effective_user_id_coerces_to_str():
|
||||
"""User.id might be a UUID object; must come back as str."""
|
||||
import uuid
|
||||
uid = uuid.uuid4()
|
||||
|
||||
user = SimpleNamespace(id=uid)
|
||||
token = set_current_user(user)
|
||||
try:
|
||||
assert get_effective_user_id() == str(uid)
|
||||
finally:
|
||||
reset_current_user(token)
|
||||
@ -1,12 +1,9 @@
|
||||
"""Helpers for router-level tests that need a stubbed auth context.
|
||||
"""Helpers for router-level tests that need an authenticated request.
|
||||
|
||||
The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
|
||||
ahead of every router, plus ``@require_permission(owner_check=True)``
|
||||
decorators that read ``request.state.auth`` and call
|
||||
``thread_store.check_access``. Router-level unit tests construct
|
||||
**bare** FastAPI apps that include only one router — they have neither
|
||||
the auth middleware nor a real thread_store, so the decorators raise
|
||||
401 (TestClient path) or ValueError (direct-call path).
|
||||
The production gateway stamps ``request.user`` / ``request.auth`` in the
|
||||
auth middleware, then route decorators read that authenticated context.
|
||||
Router-level unit tests build very small FastAPI apps that include only
|
||||
one router, so they need a lightweight stand-in for that middleware.
|
||||
|
||||
This module provides two surfaces:
|
||||
|
||||
@ -15,10 +12,9 @@ This module provides two surfaces:
|
||||
request, plus a permissive ``thread_store`` mock on
|
||||
``app.state``. Use from TestClient-based router tests.
|
||||
|
||||
2. :func:`call_unwrapped` — invokes the underlying function bypassing
|
||||
the ``@require_permission`` decorator chain by walking ``__wrapped__``.
|
||||
Use from direct-call tests that previously imported the route
|
||||
function and called it positionally.
|
||||
2. :func:`call_unwrapped` — invokes the underlying function by walking
|
||||
``__wrapped__``. Use from direct-call tests that want to bypass the
|
||||
route decorators entirely.
|
||||
|
||||
Both helpers are deliberately permissive: they never deny a request.
|
||||
Tests that want to verify the *auth boundary itself* (e.g.
|
||||
@ -37,8 +33,8 @@ from fastapi import FastAPI, Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.authz import AuthContext, Permissions
|
||||
from app.plugins.auth.domain.models import User
|
||||
from app.plugins.auth.authorization import AuthContext, Permissions
|
||||
|
||||
# Default permission set granted to the stub user. Mirrors `_ALL_PERMISSIONS`
|
||||
# in authz.py — kept inline so the tests don't import a private symbol.
|
||||
@ -63,12 +59,7 @@ def _make_stub_user() -> User:
|
||||
|
||||
|
||||
class _StubAuthMiddleware(BaseHTTPMiddleware):
|
||||
"""Stamp a fake user / AuthContext onto every request.
|
||||
|
||||
Mirrors what production ``AuthMiddleware`` does after the JWT decode
|
||||
+ DB lookup short-circuit, so ``@require_permission`` finds an
|
||||
authenticated context and skips its own re-authentication path.
|
||||
"""
|
||||
"""Stamp a fake user / AuthContext onto every request."""
|
||||
|
||||
def __init__(self, app: ASGIApp, user_factory: Callable[[], User]) -> None:
|
||||
super().__init__(app)
|
||||
@ -76,8 +67,11 @@ class _StubAuthMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
user = self._user_factory()
|
||||
auth_context = AuthContext(user=user, permissions=list(_STUB_PERMISSIONS))
|
||||
request.scope["user"] = user
|
||||
request.scope["auth"] = auth_context
|
||||
request.state.user = user
|
||||
request.state.auth = AuthContext(user=user, permissions=list(_STUB_PERMISSIONS))
|
||||
request.state.auth = auth_context
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
@ -93,9 +87,8 @@ def make_authed_test_app(
|
||||
populated :class:`User`. Useful for cross-user isolation tests
|
||||
that need a stable id across requests.
|
||||
owner_check_passes: When True (default), ``thread_store.check_access``
|
||||
returns True for every call so ``@require_permission(owner_check=True)``
|
||||
never blocks the route under test. Pass False to verify that
|
||||
permission failures surface correctly.
|
||||
returns True for every call so owner-gated routes do not block
|
||||
the handler under test. Pass False to verify denial paths.
|
||||
|
||||
Returns:
|
||||
A ``FastAPI`` app with the stub middleware installed and
|
||||
@ -121,12 +114,9 @@ def call_unwrapped(decorated: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.
|
||||
"""Invoke the underlying function of a ``@require_permission``-decorated route.
|
||||
|
||||
``functools.wraps`` sets ``__wrapped__`` on each layer; we walk all
|
||||
the way down to the original handler, bypassing every authz +
|
||||
require_auth wrapper. Use from tests that need to call route
|
||||
functions directly (without TestClient) and don't want to construct
|
||||
a fake ``Request`` just to satisfy the decorator. The ``ParamSpec``
|
||||
propagates the wrapped route's signature so call sites still get
|
||||
parameter checking despite the unwrapping.
|
||||
the way down to the original handler. Use from tests that call route
|
||||
functions directly and do not want to build a full request/middleware
|
||||
stack.
|
||||
"""
|
||||
fn: Callable = decorated
|
||||
while hasattr(fn, "__wrapped__"):
|
||||
@ -1,22 +1,22 @@
|
||||
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""
|
||||
"""Tests for authentication module: JWT, password hashing, and auth context behavior."""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.authz import (
|
||||
from app.plugins.auth.authorization import (
|
||||
AuthContext,
|
||||
Permissions,
|
||||
get_auth_context,
|
||||
require_auth,
|
||||
require_permission,
|
||||
)
|
||||
from app.plugins.auth.authorization.hooks import build_authz_hooks
|
||||
from app.plugins.auth.domain import create_access_token, decode_token, hash_password, verify_password
|
||||
from app.plugins.auth.domain.models import User
|
||||
from store.persistence import MappedBase
|
||||
|
||||
# ── Password Hashing ────────────────────────────────────────────────────────
|
||||
|
||||
@ -67,7 +67,7 @@ def test_create_and_decode_token():
|
||||
|
||||
def test_decode_token_expired():
|
||||
"""Expired token returns TokenError.EXPIRED."""
|
||||
from app.gateway.auth.errors import TokenError
|
||||
from app.plugins.auth.domain.errors import TokenError
|
||||
|
||||
user_id = str(uuid4())
|
||||
# Create token that expires immediately
|
||||
@ -78,7 +78,7 @@ def test_decode_token_expired():
|
||||
|
||||
def test_decode_token_invalid():
|
||||
"""Invalid token returns TokenError."""
|
||||
from app.gateway.auth.errors import TokenError
|
||||
from app.plugins.auth.domain.errors import TokenError
|
||||
|
||||
assert isinstance(decode_token("not.a.valid.token"), TokenError)
|
||||
assert isinstance(decode_token(""), TokenError)
|
||||
@ -101,6 +101,8 @@ def test_auth_context_unauthenticated():
|
||||
"""AuthContext with no user."""
|
||||
ctx = AuthContext(user=None, permissions=[])
|
||||
assert ctx.is_authenticated is False
|
||||
assert ctx.principal_id is None
|
||||
assert ctx.capabilities == ()
|
||||
assert ctx.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
@ -109,6 +111,8 @@ def test_auth_context_authenticated_no_perms():
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[])
|
||||
assert ctx.is_authenticated is True
|
||||
assert ctx.principal_id == str(user.id)
|
||||
assert ctx.capabilities == ()
|
||||
assert ctx.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
@ -117,6 +121,7 @@ def test_auth_context_has_permission():
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE]
|
||||
ctx = AuthContext(user=user, permissions=perms)
|
||||
assert ctx.capabilities == tuple(perms)
|
||||
assert ctx.has_permission("threads", "read") is True
|
||||
assert ctx.has_permission("threads", "write") is True
|
||||
assert ctx.has_permission("threads", "delete") is False
|
||||
@ -162,79 +167,12 @@ def test_get_auth_context_set():
|
||||
assert get_auth_context(mock_request) == ctx
|
||||
|
||||
|
||||
# ── require_auth decorator ──────────────────────────────────────────────────
|
||||
def test_register_app_sets_default_authz_hooks():
|
||||
from app.gateway.registrar import register_app
|
||||
|
||||
app = register_app()
|
||||
|
||||
def test_require_auth_sets_auth_context():
|
||||
"""require_auth sets auth context on request from cookie."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
@require_auth
|
||||
async def endpoint(request: Request):
|
||||
ctx = get_auth_context(request)
|
||||
return {"authenticated": ctx.is_authenticated}
|
||||
|
||||
with TestClient(app) as client:
|
||||
# No cookie → anonymous
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["authenticated"] is False
|
||||
|
||||
|
||||
def test_require_auth_requires_request_param():
|
||||
"""require_auth raises ValueError if request parameter is missing."""
|
||||
import asyncio
|
||||
|
||||
@require_auth
|
||||
async def bad_endpoint(): # Missing `request` parameter
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"):
|
||||
asyncio.run(bad_endpoint())
|
||||
|
||||
|
||||
# ── require_permission decorator ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_require_permission_requires_auth():
|
||||
"""require_permission raises 401 when not authenticated."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
@require_permission("threads", "read")
|
||||
async def endpoint(request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication required" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_require_permission_denies_wrong_permission():
|
||||
"""User without required permission gets 403."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
|
||||
@app.get("/test")
|
||||
@require_permission("threads", "delete")
|
||||
async def endpoint(request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||
|
||||
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 403
|
||||
assert "Permission denied" in response.json()["detail"]
|
||||
assert app.state.authz_hooks == build_authz_hooks()
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
@ -271,45 +209,55 @@ def test_sqlite_round_trip_new_fields():
|
||||
import asyncio
|
||||
import tempfile
|
||||
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
from app.plugins.auth.storage import DbUserRepository, UserCreate
|
||||
|
||||
async def _run() -> None:
|
||||
from deerflow.persistence.engine import (
|
||||
close_engine,
|
||||
get_session_factory,
|
||||
init_engine,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
url = f"sqlite+aiosqlite:///{tmpdir}/scratch.db"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=tmpdir)
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{tmpdir}/scratch.db", future=True)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
try:
|
||||
repo = SQLiteUserRepository(get_session_factory())
|
||||
user = User(
|
||||
email="setup@test.com",
|
||||
password_hash="fakehash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
)
|
||||
created = await repo.create_user(user)
|
||||
async with session_factory() as session:
|
||||
repo = DbUserRepository(session)
|
||||
created = await repo.create_user(
|
||||
UserCreate(
|
||||
email="setup@test.com",
|
||||
password_hash="fakehash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
assert created.needs_setup is True
|
||||
assert created.token_version == 3
|
||||
|
||||
fetched = await repo.get_user_by_email("setup@test.com")
|
||||
async with session_factory() as session:
|
||||
repo = DbUserRepository(session)
|
||||
fetched = await repo.get_user_by_email("setup@test.com")
|
||||
assert fetched is not None
|
||||
assert fetched.needs_setup is True
|
||||
assert fetched.token_version == 3
|
||||
|
||||
fetched.needs_setup = False
|
||||
fetched.token_version = 4
|
||||
await repo.update_user(fetched)
|
||||
refetched = await repo.get_user_by_id(str(fetched.id))
|
||||
updated = fetched.model_copy(update={"needs_setup": False, "token_version": 4})
|
||||
async with session_factory() as session:
|
||||
repo = DbUserRepository(session)
|
||||
await repo.update_user(updated)
|
||||
await session.commit()
|
||||
async with session_factory() as session:
|
||||
repo = DbUserRepository(session)
|
||||
refetched = await repo.get_user_by_id(fetched.id)
|
||||
assert refetched is not None
|
||||
assert refetched.needs_setup is False
|
||||
assert refetched.token_version == 4
|
||||
finally:
|
||||
await close_engine()
|
||||
await engine.dispose()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
@ -320,49 +268,54 @@ def test_update_user_raises_when_row_concurrently_deleted(tmp_path):
|
||||
Earlier the SQLite repo returned the input unchanged when the row was
|
||||
missing, making a phantom success path that admin password reset
|
||||
callers (`reset_admin`, `_ensure_admin_user`) would happily log as
|
||||
'password reset'. The new contract: raise ``UserNotFoundError`` so
|
||||
'password reset'. The new contract: raise ``LookupError`` so
|
||||
a vanished row never looks like a successful update.
|
||||
"""
|
||||
import asyncio
|
||||
import tempfile
|
||||
|
||||
from app.gateway.auth.repositories.base import UserNotFoundError
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
from app.plugins.auth.storage import DbUserRepository, UserCreate
|
||||
|
||||
async def _run() -> None:
|
||||
from deerflow.persistence.engine import (
|
||||
close_engine,
|
||||
get_session_factory,
|
||||
init_engine,
|
||||
)
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
from app.plugins.auth.storage.models import User as UserModel
|
||||
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
url = f"sqlite+aiosqlite:///{d}/scratch.db"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=d)
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{d}/scratch.db", future=True)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
sf = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
try:
|
||||
sf = get_session_factory()
|
||||
repo = SQLiteUserRepository(sf)
|
||||
user = User(
|
||||
email="ghost@test.com",
|
||||
password_hash="fakehash",
|
||||
system_role="user",
|
||||
)
|
||||
created = await repo.create_user(user)
|
||||
async with sf() as session:
|
||||
repo = DbUserRepository(session)
|
||||
created = await repo.create_user(
|
||||
UserCreate(
|
||||
email="ghost@test.com",
|
||||
password_hash="fakehash",
|
||||
system_role="user",
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Simulate "row vanished underneath us" by deleting the row
|
||||
# via the raw ORM session, then attempt to update.
|
||||
async with sf() as session:
|
||||
row = await session.get(UserRow, str(created.id))
|
||||
row = await session.get(UserModel, created.id)
|
||||
assert row is not None
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
created.needs_setup = True
|
||||
with pytest.raises(UserNotFoundError):
|
||||
await repo.update_user(created)
|
||||
updated = created.model_copy(update={"needs_setup": True})
|
||||
async with sf() as session:
|
||||
repo = DbUserRepository(session)
|
||||
with pytest.raises(LookupError):
|
||||
await repo.update_user(updated)
|
||||
finally:
|
||||
await close_engine()
|
||||
await engine.dispose()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
@ -374,7 +327,7 @@ def test_jwt_encodes_ver():
|
||||
"""JWT payload includes ver field."""
|
||||
import os
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
from app.plugins.auth.domain.errors import TokenError
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(str(uuid4()), token_version=3)
|
||||
@ -387,7 +340,7 @@ def test_jwt_default_ver_zero():
|
||||
"""JWT ver defaults to 0."""
|
||||
import os
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
from app.plugins.auth.domain.errors import TokenError
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(str(uuid4()))
|
||||
@ -398,30 +351,34 @@ def test_jwt_default_ver_zero():
|
||||
|
||||
def test_token_version_mismatch_rejects():
|
||||
"""Token with stale ver is rejected by get_current_user_from_request."""
|
||||
import asyncio
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.plugins.auth.security.dependencies import get_current_user_from_request
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
|
||||
user_id = str(uuid4())
|
||||
token = create_access_token(user_id, token_version=0)
|
||||
request = SimpleNamespace(
|
||||
cookies={"access_token": token},
|
||||
state=SimpleNamespace(
|
||||
_auth_session=MagicMock(),
|
||||
),
|
||||
)
|
||||
stale_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
|
||||
request.state._auth_session.__aenter__ = AsyncMock(return_value=request.state._auth_session)
|
||||
request.state._auth_session.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.cookies = {"access_token": token}
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_user = AsyncMock(return_value=mock_user)
|
||||
mock_provider_fn.return_value = mock_provider
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
with patch(
|
||||
"app.plugins.auth.security.dependencies.DbUserRepository.get_user_by_id",
|
||||
new=AsyncMock(return_value=stale_user),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "revoked" in str(exc_info.value.detail).lower()
|
||||
import asyncio
|
||||
|
||||
asyncio.run(get_current_user_from_request(request))
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "revoked" in str(exc_info.value.detail).lower()
|
||||
|
||||
|
||||
# ── change-password extension ──────────────────────────────────────────────
|
||||
@ -429,7 +386,7 @@ def test_token_version_mismatch_rejects():
|
||||
|
||||
def test_change_password_request_accepts_new_email():
|
||||
"""ChangePasswordRequest model accepts optional new_email."""
|
||||
from app.gateway.routers.auth import ChangePasswordRequest
|
||||
from app.plugins.auth.api.schemas import ChangePasswordRequest
|
||||
|
||||
req = ChangePasswordRequest(
|
||||
current_password="old",
|
||||
@ -441,7 +398,7 @@ def test_change_password_request_accepts_new_email():
|
||||
|
||||
def test_change_password_request_new_email_optional():
|
||||
"""ChangePasswordRequest model works without new_email."""
|
||||
from app.gateway.routers.auth import ChangePasswordRequest
|
||||
from app.plugins.auth.api.schemas import ChangePasswordRequest
|
||||
|
||||
req = ChangePasswordRequest(current_password="old", new_password="newpassword")
|
||||
assert req.new_email is None
|
||||
@ -449,7 +406,7 @@ def test_change_password_request_new_email_optional():
|
||||
|
||||
def test_login_response_includes_needs_setup():
|
||||
"""LoginResponse includes needs_setup field."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
from app.plugins.auth.api.schemas import LoginResponse
|
||||
|
||||
resp = LoginResponse(expires_in=3600, needs_setup=True)
|
||||
assert resp.needs_setup is True
|
||||
@ -462,7 +419,7 @@ def test_login_response_includes_needs_setup():
|
||||
|
||||
def test_rate_limiter_allows_under_limit():
|
||||
"""Requests under the limit are allowed."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts
|
||||
from app.plugins.auth.api.schemas import _check_rate_limit, _login_attempts
|
||||
|
||||
_login_attempts.clear()
|
||||
_check_rate_limit("192.168.1.1") # Should not raise
|
||||
@ -470,7 +427,7 @@ def test_rate_limiter_allows_under_limit():
|
||||
|
||||
def test_rate_limiter_blocks_after_max_failures():
|
||||
"""IP is blocked after 5 consecutive failures."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure
|
||||
from app.plugins.auth.api.schemas import _check_rate_limit, _login_attempts, _record_login_failure
|
||||
|
||||
_login_attempts.clear()
|
||||
ip = "10.0.0.1"
|
||||
@ -483,7 +440,7 @@ def test_rate_limiter_blocks_after_max_failures():
|
||||
|
||||
def test_rate_limiter_resets_on_success():
|
||||
"""Successful login clears the failure counter."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
|
||||
from app.plugins.auth.api.schemas import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
|
||||
|
||||
_login_attempts.clear()
|
||||
ip = "10.0.0.2"
|
||||
@ -499,7 +456,7 @@ def test_rate_limiter_resets_on_success():
|
||||
def test_get_client_ip_direct_connection_no_proxy(monkeypatch):
|
||||
"""Direct mode (no AUTH_TRUSTED_PROXIES): use TCP peer regardless of X-Real-IP."""
|
||||
monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
from app.plugins.auth.api.schemas import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "203.0.113.42"
|
||||
@ -514,7 +471,7 @@ def test_get_client_ip_x_real_ip_ignored_when_no_trusted_proxy(monkeypatch):
|
||||
request to dodge per-IP rate limits in dev / direct mode.
|
||||
"""
|
||||
monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
from app.plugins.auth.api.schemas import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "127.0.0.1"
|
||||
@ -525,7 +482,7 @@ def test_get_client_ip_x_real_ip_ignored_when_no_trusted_proxy(monkeypatch):
|
||||
def test_get_client_ip_x_real_ip_honored_from_trusted_proxy(monkeypatch):
|
||||
"""X-Real-IP is honored when the TCP peer matches AUTH_TRUSTED_PROXIES."""
|
||||
monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
from app.plugins.auth.api.schemas import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "10.5.6.7" # in trusted CIDR
|
||||
@ -536,7 +493,7 @@ def test_get_client_ip_x_real_ip_honored_from_trusted_proxy(monkeypatch):
|
||||
def test_get_client_ip_x_real_ip_rejected_from_untrusted_peer(monkeypatch):
|
||||
"""X-Real-IP is rejected when the TCP peer is NOT in the trusted list."""
|
||||
monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
from app.plugins.auth.api.schemas import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "8.8.8.8" # NOT in trusted CIDR
|
||||
@ -547,7 +504,7 @@ def test_get_client_ip_x_real_ip_rejected_from_untrusted_peer(monkeypatch):
|
||||
def test_get_client_ip_xff_never_honored(monkeypatch):
|
||||
"""X-Forwarded-For is never used; only X-Real-IP from a trusted peer."""
|
||||
monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "10.0.0.0/8")
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
from app.plugins.auth.api.schemas import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "10.0.0.1"
|
||||
@ -558,7 +515,7 @@ def test_get_client_ip_xff_never_honored(monkeypatch):
|
||||
def test_get_client_ip_invalid_trusted_proxy_entry_skipped(monkeypatch, caplog):
|
||||
"""Garbage entries in AUTH_TRUSTED_PROXIES are warned and skipped."""
|
||||
monkeypatch.setenv("AUTH_TRUSTED_PROXIES", "not-an-ip,10.0.0.0/8")
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
from app.plugins.auth.api.schemas import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "10.5.6.7"
|
||||
@ -569,7 +526,7 @@ def test_get_client_ip_invalid_trusted_proxy_entry_skipped(monkeypatch, caplog):
|
||||
def test_get_client_ip_no_client_returns_unknown(monkeypatch):
|
||||
"""No request.client → 'unknown' marker (no crash)."""
|
||||
monkeypatch.delenv("AUTH_TRUSTED_PROXIES", raising=False)
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
from app.plugins.auth.api.schemas import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client = None
|
||||
@ -584,7 +541,7 @@ def test_register_rejects_literal_password():
|
||||
"""Pydantic validator rejects 'password' as a registration password."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.gateway.routers.auth import RegisterRequest
|
||||
from app.plugins.auth.api.schemas import RegisterRequest
|
||||
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
RegisterRequest(email="x@example.com", password="password")
|
||||
@ -595,7 +552,7 @@ def test_register_rejects_common_password_case_insensitive():
|
||||
"""Case variants of common passwords are also rejected."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.gateway.routers.auth import RegisterRequest
|
||||
from app.plugins.auth.api.schemas import RegisterRequest
|
||||
|
||||
for variant in ["PASSWORD", "Password1", "qwerty123", "letmein1"]:
|
||||
with pytest.raises(ValidationError):
|
||||
@ -604,7 +561,7 @@ def test_register_rejects_common_password_case_insensitive():
|
||||
|
||||
def test_register_accepts_strong_password():
|
||||
"""A non-blocklisted password of length >=8 is accepted."""
|
||||
from app.gateway.routers.auth import RegisterRequest
|
||||
from app.plugins.auth.api.schemas import RegisterRequest
|
||||
|
||||
req = RegisterRequest(email="x@example.com", password="Tr0ub4dor&3-Horse")
|
||||
assert req.password == "Tr0ub4dor&3-Horse"
|
||||
@ -614,7 +571,7 @@ def test_change_password_rejects_common_password():
|
||||
"""The same blocklist applies to change-password."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.gateway.routers.auth import ChangePasswordRequest
|
||||
from app.plugins.auth.api.schemas import ChangePasswordRequest
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ChangePasswordRequest(current_password="anything", new_password="iloveyou")
|
||||
@ -624,7 +581,7 @@ def test_password_blocklist_keeps_short_passwords_for_length_check():
|
||||
"""Short passwords still fail the min_length check (not the blocklist)."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.gateway.routers.auth import RegisterRequest
|
||||
from app.plugins.auth.api.schemas import RegisterRequest
|
||||
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
RegisterRequest(email="x@example.com", password="abc")
|
||||
@ -639,16 +596,17 @@ def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
|
||||
"""get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset."""
|
||||
import logging
|
||||
|
||||
import app.gateway.auth.config as config_module
|
||||
import app.plugins.auth.runtime.config_state as config_module
|
||||
from app.plugins.auth.runtime.config_state import reset_auth_config
|
||||
|
||||
config_module._auth_config = None
|
||||
monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
reset_auth_config()
|
||||
config = config_module.get_auth_config()
|
||||
|
||||
assert config.jwt_secret # non-empty ephemeral secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
|
||||
# Cleanup
|
||||
config_module._auth_config = None
|
||||
reset_auth_config()
|
||||
@ -5,7 +5,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.gateway.auth.config import AuthConfig
|
||||
from app.plugins.auth.domain.config import AuthConfig
|
||||
from app.plugins.auth.runtime.config_state import reset_auth_config
|
||||
|
||||
|
||||
def test_auth_config_defaults():
|
||||
@ -25,30 +26,28 @@ def test_auth_config_token_expiry_range():
|
||||
def test_auth_config_from_env():
|
||||
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
import app.gateway.auth.config as cfg
|
||||
import app.plugins.auth.runtime.config_state as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
reset_auth_config()
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret == "test-jwt-secret-from-env"
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
reset_auth_config()
|
||||
|
||||
|
||||
def test_auth_config_missing_secret_generates_ephemeral(caplog):
|
||||
import logging
|
||||
|
||||
import app.gateway.auth.config as cfg
|
||||
import app.plugins.auth.runtime.config_state as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
reset_auth_config()
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
reset_auth_config()
|
||||
157
backend/tests/unittest/test_auth_dependencies.py
Normal file
157
backend/tests/unittest/test_auth_dependencies.py
Normal file
@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.plugins.auth.domain.config import AuthConfig
|
||||
from app.plugins.auth.security.dependencies import (
|
||||
get_current_user_from_request,
|
||||
get_current_user_id,
|
||||
get_optional_user_from_request,
|
||||
)
|
||||
from app.plugins.auth.domain.jwt import create_access_token
|
||||
from app.plugins.auth.runtime.config_state import set_auth_config
|
||||
from app.plugins.auth.storage import DbUserRepository, UserCreate
|
||||
from store.persistence import MappedBase
|
||||
from app.plugins.auth.storage.models import User as UserModel # noqa: F401
|
||||
|
||||
_TEST_SECRET = "test-secret-auth-dependencies-min-32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_auth_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
yield
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
|
||||
|
||||
async def _make_request(tmp_path, *, cookie: str | None = None, users: list[UserCreate] | None = None):
|
||||
engine = create_async_engine(
|
||||
f"sqlite+aiosqlite:///{tmp_path / 'auth-deps.db'}",
|
||||
future=True,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
session = session_factory()
|
||||
if users:
|
||||
repo = DbUserRepository(session)
|
||||
for user in users:
|
||||
await repo.create_user(user)
|
||||
await session.commit()
|
||||
request = SimpleNamespace(
|
||||
cookies={"access_token": cookie} if cookie is not None else {},
|
||||
state=SimpleNamespace(_auth_session=session),
|
||||
)
|
||||
return request, session, engine
|
||||
|
||||
|
||||
class TestAuthDependencies:
|
||||
@pytest.mark.anyio
|
||||
async def test_no_cookie_returns_401(self, tmp_path):
|
||||
request, session, engine = await _make_request(tmp_path)
|
||||
try:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_from_request(request)
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail["code"] == "not_authenticated"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_token_returns_401(self, tmp_path):
|
||||
request, session, engine = await _make_request(tmp_path, cookie="garbage")
|
||||
try:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_from_request(request)
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail["code"] == "token_invalid"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_missing_user_returns_401(self, tmp_path):
|
||||
token = create_access_token("missing-user", token_version=0)
|
||||
request, session, engine = await _make_request(tmp_path, cookie=token)
|
||||
try:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_from_request(request)
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail["code"] == "user_not_found"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_version_mismatch_returns_401(self, tmp_path):
|
||||
token = create_access_token("user-1", token_version=0)
|
||||
request, session, engine = await _make_request(
|
||||
tmp_path,
|
||||
cookie=token,
|
||||
users=[
|
||||
UserCreate(
|
||||
id="user-1",
|
||||
email="user1@example.com",
|
||||
token_version=2,
|
||||
)
|
||||
],
|
||||
)
|
||||
try:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user_from_request(request)
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail["code"] == "token_invalid"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_valid_token_returns_user(self, tmp_path):
|
||||
token = create_access_token("user-2", token_version=3)
|
||||
request, session, engine = await _make_request(
|
||||
tmp_path,
|
||||
cookie=token,
|
||||
users=[
|
||||
UserCreate(
|
||||
id="user-2",
|
||||
email="user2@example.com",
|
||||
token_version=3,
|
||||
)
|
||||
],
|
||||
)
|
||||
try:
|
||||
user = await get_current_user_from_request(request)
|
||||
user_id = await get_current_user_id(request)
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
|
||||
assert user.id == "user-2"
|
||||
assert user.email == "user2@example.com"
|
||||
assert user_id == "user-2"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_optional_user_returns_none_on_failure(self, tmp_path):
|
||||
request, session, engine = await _make_request(tmp_path, cookie="bad-token")
|
||||
try:
|
||||
user = await get_optional_user_from_request(request)
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
|
||||
assert user is None
|
||||
@ -4,9 +4,10 @@ from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import create_access_token, decode_token
|
||||
from app.plugins.auth.domain.config import AuthConfig
|
||||
from app.plugins.auth.domain.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.plugins.auth.domain.jwt import create_access_token, decode_token
|
||||
from app.plugins.auth.runtime.config_state import set_auth_config
|
||||
|
||||
|
||||
def test_auth_error_code_values():
|
||||
@ -56,7 +57,7 @@ def test_decode_token_returns_token_error_on_expired():
|
||||
def test_decode_token_returns_token_error_on_bad_signature():
|
||||
_setup_config()
|
||||
payload = {"sub": "user-1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||
token = pyjwt.encode(payload, "wrong-secret-key-for-tests-minimum-32", algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.INVALID_SIGNATURE
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
"""Tests for the global AuthMiddleware (fail-closed safety net)."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.gateway.auth_middleware import AuthMiddleware, _is_public
|
||||
from app.plugins.auth.security.middleware import AuthMiddleware, _is_public
|
||||
|
||||
# ── _is_public unit tests ─────────────────────────────────────────────────
|
||||
|
||||
@ -165,7 +167,8 @@ def test_protected_path_with_junk_cookie_rejected(client):
|
||||
"""Junk cookie → 401. Middleware strictly validates the JWT now
|
||||
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
|
||||
tokens through to the route handler."""
|
||||
res = client.get("/api/models", cookies={"access_token": "some-token"})
|
||||
client.cookies.set("access_token", "some-token")
|
||||
res = client.get("/api/models")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
@ -220,3 +223,44 @@ def test_unknown_endpoint_with_junk_cookie_rejected(client):
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.get("/api/future-endpoint")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_middleware_populates_request_user_and_auth(monkeypatch):
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
from app.plugins.auth.security import middleware as middleware_module
|
||||
|
||||
user = SimpleNamespace(id="user-123", email="test@example.com")
|
||||
|
||||
async def _fake_get_current_user_from_request(request):
|
||||
return user
|
||||
|
||||
monkeypatch.setattr(
|
||||
middleware_module,
|
||||
"get_current_user_from_request",
|
||||
_fake_get_current_user_from_request,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
@app.get("/api/models")
|
||||
async def models_get(request: Request):
|
||||
return {
|
||||
"request_user_id": request.user.id,
|
||||
"state_user_id": request.state.user.id,
|
||||
"request_auth_user_id": request.auth.user.id,
|
||||
"state_auth_user_id": request.state.auth.user.id,
|
||||
}
|
||||
|
||||
client = TestClient(app)
|
||||
client.cookies.set("access_token", "valid")
|
||||
response = client.get("/api/models")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"request_user_id": "user-123",
|
||||
"state_user_id": "user-123",
|
||||
"request_auth_user_id": "user-123",
|
||||
"state_auth_user_id": "user-123",
|
||||
}
|
||||
97
backend/tests/unittest/test_auth_policies.py
Normal file
97
backend/tests/unittest/test_auth_policies.py
Normal file
@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from app.plugins.auth.authorization import AuthContext, Permissions
|
||||
from app.plugins.auth.authorization.policies import require_thread_owner
|
||||
from app.plugins.auth.domain.models import User
|
||||
|
||||
|
||||
def _make_auth_context() -> AuthContext:
|
||||
user = User(id=uuid4(), email="user@example.com", password_hash="hash")
|
||||
return AuthContext(user=user, permissions=[Permissions.THREADS_READ, Permissions.RUNS_READ])
|
||||
|
||||
|
||||
def _make_request(*, thread_repo, run_repo=None, checkpointer=None) -> Request:
|
||||
app = SimpleNamespace(
|
||||
state=SimpleNamespace(
|
||||
thread_meta_repo=thread_repo,
|
||||
run_store=run_repo,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
)
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"path": "/api/threads/thread-1/runs",
|
||||
"headers": [],
|
||||
"app": app,
|
||||
"route": SimpleNamespace(path="/api/threads/{thread_id}/runs"),
|
||||
"path_params": {"thread_id": "thread-1"},
|
||||
}
|
||||
return Request(scope)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_require_thread_owner_uses_thread_row_user_id() -> None:
|
||||
auth = _make_auth_context()
|
||||
thread_repo = SimpleNamespace(
|
||||
get_thread_meta=AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
user_id=str(auth.user.id),
|
||||
metadata={"user_id": "someone-else"},
|
||||
)
|
||||
)
|
||||
)
|
||||
request = _make_request(thread_repo=thread_repo)
|
||||
|
||||
await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_require_thread_owner_falls_back_to_user_owned_runs() -> None:
|
||||
auth = _make_auth_context()
|
||||
thread_repo = SimpleNamespace(get_thread_meta=AsyncMock(return_value=None))
|
||||
run_repo = SimpleNamespace(
|
||||
list_by_thread=AsyncMock(return_value=[{"run_id": "run-1", "thread_id": "thread-1"}])
|
||||
)
|
||||
request = _make_request(thread_repo=thread_repo, run_repo=run_repo)
|
||||
|
||||
await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True)
|
||||
|
||||
run_repo.list_by_thread.assert_awaited_once_with("thread-1", limit=1, user_id=str(auth.user.id))
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_require_thread_owner_falls_back_to_checkpoint_threads() -> None:
|
||||
auth = _make_auth_context()
|
||||
thread_repo = SimpleNamespace(get_thread_meta=AsyncMock(return_value=None))
|
||||
run_repo = SimpleNamespace(list_by_thread=AsyncMock(return_value=[]))
|
||||
checkpointer = SimpleNamespace(aget_tuple=AsyncMock(return_value=object()))
|
||||
request = _make_request(thread_repo=thread_repo, run_repo=run_repo, checkpointer=checkpointer)
|
||||
|
||||
await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True)
|
||||
|
||||
checkpointer.aget_tuple.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_require_thread_owner_denies_missing_thread() -> None:
|
||||
auth = _make_auth_context()
|
||||
thread_repo = SimpleNamespace(get_thread_meta=AsyncMock(return_value=None))
|
||||
run_repo = SimpleNamespace(list_by_thread=AsyncMock(return_value=[]))
|
||||
checkpointer = SimpleNamespace(aget_tuple=AsyncMock(return_value=None))
|
||||
request = _make_request(thread_repo=thread_repo, run_repo=run_repo, checkpointer=checkpointer)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await require_thread_owner(request, auth, thread_id="thread-1", require_existing=True)
|
||||
|
||||
assert getattr(exc_info.value, "status_code", None) == 404
|
||||
assert getattr(exc_info.value, "detail", "") == "Thread thread-1 not found"
|
||||
146
backend/tests/unittest/test_auth_route_injection.py
Normal file
146
backend/tests/unittest/test_auth_route_injection.py
Normal file
@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from starlette.requests import Request
|
||||
|
||||
from app.plugins.auth.authorization import AuthContext
|
||||
from app.plugins.auth.domain.models import User
|
||||
from app.plugins.auth.injection import load_route_policy_registry, validate_route_policy_registry
|
||||
from app.plugins.auth.injection.registry_loader import RoutePolicyRegistry, RoutePolicySpec
|
||||
from app.plugins.auth.injection.route_guard import enforce_route_policy
|
||||
from app.plugins.auth.injection.route_injector import install_route_guards
|
||||
|
||||
|
||||
def test_load_route_policy_registry_flattens_yaml_sections() -> None:
|
||||
registry = load_route_policy_registry()
|
||||
|
||||
public_spec = registry.get("POST", "/api/v1/auth/login/local")
|
||||
assert public_spec is not None
|
||||
assert public_spec.public is True
|
||||
|
||||
run_stream_spec = registry.get("GET", "/api/threads/{thread_id}/runs/{run_id}/stream")
|
||||
assert run_stream_spec is not None
|
||||
assert run_stream_spec.capability == "runs:read"
|
||||
assert run_stream_spec.policies == ("owner:run",)
|
||||
|
||||
post_stream_spec = registry.get("POST", "/api/threads/{thread_id}/runs/{run_id}/stream")
|
||||
assert post_stream_spec == run_stream_spec
|
||||
|
||||
|
||||
def test_validate_route_policy_registry_rejects_missing_entry() -> None:
|
||||
app = FastAPI()
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/api/needs-policy")
|
||||
async def needs_policy() -> dict[str, bool]:
|
||||
return {"ok": True}
|
||||
|
||||
app.include_router(router)
|
||||
registry = RoutePolicyRegistry([])
|
||||
|
||||
with pytest.raises(RuntimeError, match="Missing route policy entries"):
|
||||
validate_route_policy_registry(app, registry)
|
||||
|
||||
|
||||
def test_install_route_guards_appends_route_dependency() -> None:
|
||||
app = FastAPI()
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/api/demo")
|
||||
async def demo() -> dict[str, bool]:
|
||||
return {"ok": True}
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
route = next(route for route in app.routes if getattr(route, "path", None) == "/api/demo")
|
||||
before = len(route.dependencies)
|
||||
|
||||
install_route_guards(app)
|
||||
|
||||
assert len(route.dependencies) == before + 1
|
||||
assert route.dependencies[-1].dependency is enforce_route_policy
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_enforce_route_policy_denies_missing_capability() -> None:
|
||||
user = User(id=uuid4(), email="user@example.com", password_hash="hash")
|
||||
auth = AuthContext(user=user, permissions=["threads:read"])
|
||||
registry = RoutePolicyRegistry(
|
||||
[
|
||||
SimpleNamespace(
|
||||
method="GET",
|
||||
path="/api/threads/{thread_id}/uploads/list",
|
||||
spec=RoutePolicySpec(capability="threads:delete"),
|
||||
matches_request=lambda *_args, **_kwargs: True,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
app = SimpleNamespace(state=SimpleNamespace(auth_route_policy_registry=registry))
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"path": "/api/threads/thread-1/uploads/list",
|
||||
"headers": [],
|
||||
"app": app,
|
||||
"route": SimpleNamespace(path="/api/threads/{thread_id}/uploads/list"),
|
||||
"path_params": {"thread_id": "thread-1"},
|
||||
"auth": auth,
|
||||
}
|
||||
request = Request(scope)
|
||||
request.state.auth = auth
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await enforce_route_policy(request)
|
||||
|
||||
assert getattr(exc_info.value, "status_code", None) == 403
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_enforce_route_policy_runs_owner_policy(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = User(id=uuid4(), email="user@example.com", password_hash="hash")
|
||||
auth = AuthContext(user=user, permissions=["threads:read"])
|
||||
registry = RoutePolicyRegistry(
|
||||
[
|
||||
SimpleNamespace(
|
||||
method="GET",
|
||||
path="/api/threads/{thread_id}/state",
|
||||
spec=RoutePolicySpec(capability="threads:read", policies=("owner:thread",)),
|
||||
matches_request=lambda *_args, **_kwargs: True,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
called: dict[str, object] = {}
|
||||
|
||||
async def fake_owner_check(request: Request, auth_context: AuthContext, *, thread_id: str, require_existing: bool) -> None:
|
||||
called["request"] = request
|
||||
called["auth"] = auth_context
|
||||
called["thread_id"] = thread_id
|
||||
called["require_existing"] = require_existing
|
||||
|
||||
monkeypatch.setattr("app.plugins.auth.injection.route_guard.require_thread_owner", fake_owner_check)
|
||||
|
||||
app = SimpleNamespace(state=SimpleNamespace(auth_route_policy_registry=registry))
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"path": "/api/threads/thread-1/state",
|
||||
"headers": [],
|
||||
"app": app,
|
||||
"route": SimpleNamespace(path="/api/threads/{thread_id}/state"),
|
||||
"path_params": {"thread_id": "thread-1"},
|
||||
"auth": auth,
|
||||
}
|
||||
request = Request(scope)
|
||||
request.state.auth = auth
|
||||
|
||||
await enforce_route_policy(request)
|
||||
|
||||
assert called["thread_id"] == "thread-1"
|
||||
assert called["auth"] is auth
|
||||
assert called["require_existing"] is True
|
||||
86
backend/tests/unittest/test_auth_service.py
Normal file
86
backend/tests/unittest/test_auth_service.py
Normal file
@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.plugins.auth.domain.service import AuthService, AuthServiceError
|
||||
from app.plugins.auth.storage.models import User as UserModel # noqa: F401
|
||||
from store.persistence import MappedBase
|
||||
|
||||
|
||||
async def _make_service(tmp_path):
|
||||
engine = create_async_engine(
|
||||
f"sqlite+aiosqlite:///{tmp_path / 'auth-service.db'}",
|
||||
future=True,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
return engine, AuthService(session_factory)
|
||||
|
||||
|
||||
class TestAuthService:
|
||||
@pytest.mark.anyio
|
||||
async def test_register_and_login_local(self, tmp_path):
|
||||
engine, service = await _make_service(tmp_path)
|
||||
try:
|
||||
created = await service.register("user@example.com", "Str0ng!Pass99")
|
||||
logged_in = await service.login_local("user@example.com", "Str0ng!Pass99")
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
assert created.email == "user@example.com"
|
||||
assert created.password_hash is not None
|
||||
assert logged_in.id == created.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_register_duplicate_email_raises(self, tmp_path):
|
||||
engine, service = await _make_service(tmp_path)
|
||||
try:
|
||||
await service.register("dupe@example.com", "Str0ng!Pass99")
|
||||
with pytest.raises(AuthServiceError) as exc_info:
|
||||
await service.register("dupe@example.com", "An0ther!Pass99")
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
assert exc_info.value.code.value == "email_already_exists"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_initialize_admin_only_once(self, tmp_path):
|
||||
engine, service = await _make_service(tmp_path)
|
||||
try:
|
||||
admin = await service.initialize_admin("admin@example.com", "Str0ng!Pass99")
|
||||
with pytest.raises(AuthServiceError) as exc_info:
|
||||
await service.initialize_admin("other@example.com", "An0ther!Pass99")
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
assert admin.system_role == "admin"
|
||||
assert admin.needs_setup is False
|
||||
assert exc_info.value.code.value == "system_already_initialized"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_change_password_updates_token_version_and_clears_setup(self, tmp_path):
|
||||
engine, service = await _make_service(tmp_path)
|
||||
try:
|
||||
user = await service.register("setup@example.com", "Str0ng!Pass99")
|
||||
user.needs_setup = True
|
||||
updated = await service.change_password(
|
||||
user,
|
||||
current_password="Str0ng!Pass99",
|
||||
new_password="N3wer!Pass99",
|
||||
new_email="final@example.com",
|
||||
)
|
||||
relogged = await service.login_local("final@example.com", "N3wer!Pass99")
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
assert updated.email == "final@example.com"
|
||||
assert updated.needs_setup is False
|
||||
assert updated.token_version == 1
|
||||
assert relogged.id == updated.id
|
||||
@ -16,10 +16,11 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import decode_token
|
||||
from app.gateway.csrf_middleware import (
|
||||
from app.plugins.auth.domain.config import AuthConfig
|
||||
from app.plugins.auth.domain.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.plugins.auth.domain.jwt import decode_token
|
||||
from app.plugins.auth.runtime.config_state import set_auth_config
|
||||
from app.plugins.auth.security.csrf import (
|
||||
CSRF_COOKIE_NAME,
|
||||
CSRF_HEADER_NAME,
|
||||
CSRFMiddleware,
|
||||
@ -34,28 +35,8 @@ _TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _persistence_engine(tmp_path):
|
||||
"""Initialise a per-test SQLite engine + reset cached provider singletons.
|
||||
|
||||
The auth tests call real HTTP handlers that go through
|
||||
``SQLiteUserRepository`` → ``get_session_factory``. Each test gets
|
||||
a fresh DB plus a clean ``deps._cached_*`` so the cached provider
|
||||
does not hold a dangling reference to the previous test's engine.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from app.gateway import deps
|
||||
from deerflow.persistence.engine import close_engine, init_engine
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path}/auth_types.db"
|
||||
asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
|
||||
deps._cached_local_provider = None
|
||||
deps._cached_repo = None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
deps._cached_local_provider = None
|
||||
deps._cached_repo = None
|
||||
asyncio.run(close_engine())
|
||||
"""Per-test auth config fixture placeholder."""
|
||||
yield
|
||||
|
||||
|
||||
def _setup_config():
|
||||
@ -174,7 +155,7 @@ def test_decode_token_invalid_sig_maps_to_token_invalid_code():
|
||||
import jwt as pyjwt
|
||||
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
|
||||
token = pyjwt.encode(payload, "wrong-key-for-tests-minimum-32-chars", algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.INVALID_SIGNATURE
|
||||
|
||||
@ -197,7 +178,7 @@ def test_decode_token_malformed_maps_to_token_invalid_code():
|
||||
|
||||
def test_login_response_model_has_no_access_token():
|
||||
"""LoginResponse should NOT contain access_token field (RFC-001)."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
from app.plugins.auth.api.schemas import LoginResponse
|
||||
|
||||
resp = LoginResponse(expires_in=604800)
|
||||
d = resp.model_dump()
|
||||
@ -208,7 +189,7 @@ def test_login_response_model_has_no_access_token():
|
||||
|
||||
def test_login_response_model_fields():
|
||||
"""LoginResponse has expires_in and needs_setup."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
from app.plugins.auth.api.schemas import LoginResponse
|
||||
|
||||
fields = set(LoginResponse.model_fields.keys())
|
||||
assert fields == {"expires_in", "needs_setup"}
|
||||
@ -219,7 +200,7 @@ def test_login_response_model_fields():
|
||||
|
||||
def test_auth_config_token_expiry_used_in_login_response():
|
||||
"""LoginResponse.expires_in should come from config.token_expiry_days."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
from app.plugins.auth.api.schemas import LoginResponse
|
||||
|
||||
expected_seconds = 14 * 24 * 3600
|
||||
resp = LoginResponse(expires_in=expected_seconds)
|
||||
@ -231,7 +212,7 @@ def test_auth_config_token_expiry_used_in_login_response():
|
||||
|
||||
def test_user_response_system_role_literal():
|
||||
"""UserResponse.system_role should only accept 'admin' or 'user'."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
from app.plugins.auth.domain.models import UserResponse
|
||||
|
||||
# Valid roles
|
||||
resp = UserResponse(id="1", email="a@b.com", system_role="admin")
|
||||
@ -243,7 +224,7 @@ def test_user_response_system_role_literal():
|
||||
|
||||
def test_user_response_rejects_invalid_role():
|
||||
"""UserResponse should reject invalid system_role values."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
from app.plugins.auth.domain.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1", email="a@b.com", system_role="superadmin")
|
||||
@ -263,7 +244,7 @@ def test_get_current_user_no_cookie_returns_not_authenticated():
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
from app.plugins.auth.security.dependencies import get_current_user_from_request
|
||||
|
||||
mock_request = type("MockRequest", (), {"cookies": {}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@ -279,7 +260,7 @@ def test_get_current_user_expired_token_returns_token_expired():
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
from app.plugins.auth.security.dependencies import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
@ -299,11 +280,11 @@ def test_get_current_user_invalid_token_returns_token_invalid():
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
from app.plugins.auth.security.dependencies import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||
token = pyjwt.encode(payload, "wrong-secret-key-for-tests-minimum-32", algorithm="HS256")
|
||||
|
||||
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@ -319,7 +300,7 @@ def test_get_current_user_malformed_token_returns_token_invalid():
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
from app.plugins.auth.security.dependencies import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})()
|
||||
@ -380,19 +361,19 @@ def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
|
||||
"""get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset."""
|
||||
import logging
|
||||
|
||||
import app.gateway.auth.config as cfg
|
||||
import app.plugins.auth.runtime.config_state as cfg
|
||||
from app.plugins.auth.runtime.config_state import reset_auth_config
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
reset_auth_config()
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
reset_auth_config()
|
||||
|
||||
|
||||
# ── CSRF middleware integration (unhappy paths) ──────────────────────
|
||||
@ -485,7 +466,7 @@ def test_csrf_middleware_sets_cookie_on_auth_endpoint():
|
||||
|
||||
def test_user_response_missing_required_fields():
|
||||
"""UserResponse with missing fields → ValidationError."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
from app.plugins.auth.domain.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1") # missing email, system_role
|
||||
@ -496,7 +477,7 @@ def test_user_response_missing_required_fields():
|
||||
|
||||
def test_user_response_empty_string_role_rejected():
|
||||
"""Empty string is not a valid role."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
from app.plugins.auth.domain.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1", email="a@b.com", system_role="")
|
||||
@ -514,20 +495,15 @@ def _make_auth_app():
|
||||
return create_app()
|
||||
|
||||
|
||||
def _get_auth_client():
|
||||
"""Get TestClient for auth API contract tests."""
|
||||
return TestClient(_make_auth_app())
|
||||
|
||||
|
||||
def test_api_auth_me_no_cookie_returns_structured_401():
|
||||
"""/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "not_authenticated"
|
||||
assert "message" in body["detail"]
|
||||
with TestClient(_make_auth_app()) as client:
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "not_authenticated"
|
||||
assert "message" in body["detail"]
|
||||
|
||||
|
||||
def test_api_auth_me_expired_token_returns_structured_401():
|
||||
@ -536,75 +512,70 @@ def test_api_auth_me_expired_token_returns_structured_401():
|
||||
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||
|
||||
client = _get_auth_client()
|
||||
client.cookies.set("access_token", token)
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "token_expired"
|
||||
with TestClient(_make_auth_app()) as client:
|
||||
client.cookies.set("access_token", token)
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "token_expired"
|
||||
|
||||
|
||||
def test_api_auth_me_invalid_sig_returns_structured_401():
|
||||
"""/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}."""
|
||||
_setup_config()
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
|
||||
token = pyjwt.encode(payload, "wrong-key-for-tests-minimum-32-chars", algorithm="HS256")
|
||||
|
||||
client = _get_auth_client()
|
||||
client.cookies.set("access_token", token)
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "token_invalid"
|
||||
with TestClient(_make_auth_app()) as client:
|
||||
client.cookies.set("access_token", token)
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "token_invalid"
|
||||
|
||||
|
||||
def test_api_login_bad_credentials_returns_structured_401():
|
||||
"""Login with wrong password → 401 with {code: 'invalid_credentials'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": "nonexistent@test.com", "password": "wrongpassword"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "invalid_credentials"
|
||||
with TestClient(_make_auth_app()) as client:
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": "nonexistent@test.com", "password": "wrongpassword"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "invalid_credentials"
|
||||
|
||||
|
||||
def test_api_login_success_no_token_in_body():
|
||||
"""Successful login → response body has expires_in but NOT access_token."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
# Register first
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "contract-test@test.com", "password": "securepassword123"},
|
||||
)
|
||||
# Login
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": "contract-test@test.com", "password": "securepassword123"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "expires_in" in body
|
||||
assert "access_token" not in body
|
||||
# Token should be in cookie, not body
|
||||
assert "access_token" in resp.cookies
|
||||
with TestClient(_make_auth_app()) as client:
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "contract-test@test.com", "password": "securepassword123"},
|
||||
)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": "contract-test@test.com", "password": "securepassword123"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "expires_in" in body
|
||||
assert "access_token" not in body
|
||||
assert "access_token" in resp.cookies
|
||||
|
||||
|
||||
def test_api_register_duplicate_returns_structured_400():
|
||||
"""Register with duplicate email → 400 with {code: 'email_already_exists'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
email = "dup-contract-test@test.com"
|
||||
# First register
|
||||
client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"})
|
||||
# Duplicate
|
||||
resp = client.post("/api/v1/auth/register", json={"email": email, "password": "AnotherStr0ngPwd!"})
|
||||
assert resp.status_code == 400
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "email_already_exists"
|
||||
with TestClient(_make_auth_app()) as client:
|
||||
email = "dup-contract-test@test.com"
|
||||
client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"})
|
||||
resp = client.post("/api/v1/auth/register", json={"email": email, "password": "AnotherStr0ngPwd!"})
|
||||
assert resp.status_code == 400
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "email_already_exists"
|
||||
|
||||
|
||||
# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────
|
||||
@ -622,80 +593,80 @@ def _get_set_cookie_headers(resp) -> list[str]:
|
||||
def test_register_http_cookie_httponly_true_secure_false():
|
||||
"""HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("http-cookie"), "password": "Tr0ub4dor3a"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" not in cookie_header.lower().replace("samesite", "")
|
||||
with TestClient(_make_auth_app()) as client:
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("http-cookie"), "password": "Tr0ub4dor3a"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" not in cookie_header.lower().replace("samesite", "")
|
||||
|
||||
|
||||
def test_register_https_cookie_httponly_true_secure_true():
|
||||
"""HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("https-cookie"), "password": "Tr0ub4dor3a"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" in cookie_header.lower()
|
||||
assert "max-age" in cookie_header.lower()
|
||||
with TestClient(_make_auth_app()) as client:
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("https-cookie"), "password": "Tr0ub4dor3a"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" in cookie_header.lower()
|
||||
assert "max-age" in cookie_header.lower()
|
||||
|
||||
|
||||
def test_login_https_sets_secure_cookie():
|
||||
"""HTTPS login → access_token cookie has secure flag."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
email = _unique_email("https-login")
|
||||
client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"})
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": email, "password": "Tr0ub4dor3a"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" in cookie_header.lower()
|
||||
with TestClient(_make_auth_app()) as client:
|
||||
email = _unique_email("https-login")
|
||||
client.post("/api/v1/auth/register", json={"email": email, "password": "Tr0ub4dor3a"})
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": email, "password": "Tr0ub4dor3a"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" in cookie_header.lower()
|
||||
|
||||
|
||||
def test_csrf_cookie_secure_on_https():
|
||||
"""HTTPS register → csrf_token cookie has secure flag but NOT httponly."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("csrf-https"), "password": "Tr0ub4dor3a"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||
assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
|
||||
csrf_header = csrf_cookies[0]
|
||||
assert "secure" in csrf_header.lower()
|
||||
assert "httponly" not in csrf_header.lower()
|
||||
with TestClient(_make_auth_app()) as client:
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("csrf-https"), "password": "Tr0ub4dor3a"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||
assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
|
||||
csrf_header = csrf_cookies[0]
|
||||
assert "secure" in csrf_header.lower()
|
||||
assert "httponly" not in csrf_header.lower()
|
||||
|
||||
|
||||
def test_csrf_cookie_not_secure_on_http():
|
||||
"""HTTP register → csrf_token cookie does NOT have secure flag."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("csrf-http"), "password": "Tr0ub4dor3a"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||
assert csrf_cookies, "csrf_token cookie not set on HTTP register"
|
||||
csrf_header = csrf_cookies[0]
|
||||
assert "secure" not in csrf_header.lower().replace("samesite", "")
|
||||
with TestClient(_make_auth_app()) as client:
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("csrf-http"), "password": "Tr0ub4dor3a"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||
assert csrf_cookies, "csrf_token cookie not set on HTTP register"
|
||||
csrf_header = csrf_cookies[0]
|
||||
assert "secure" not in csrf_header.lower().replace("samesite", "")
|
||||
@ -38,6 +38,8 @@ def mock_app_config():
|
||||
|
||||
config = MagicMock()
|
||||
config.models = [model]
|
||||
config.database = None
|
||||
config.checkpointer = None
|
||||
return config
|
||||
|
||||
|
||||
@ -86,10 +88,12 @@ class TestClientInit:
|
||||
def test_custom_config_path(self, mock_app_config):
|
||||
with (
|
||||
patch("deerflow.client.reload_app_config") as mock_reload,
|
||||
patch("store.config.app_config.reload_app_config") as mock_storage_reload,
|
||||
patch("deerflow.client.get_app_config", return_value=mock_app_config),
|
||||
):
|
||||
DeerFlowClient(config_path="/tmp/custom.yaml")
|
||||
mock_reload.assert_called_once_with("/tmp/custom.yaml")
|
||||
mock_storage_reload.assert_called_once_with("/tmp/custom.yaml")
|
||||
|
||||
def test_checkpointer_stored(self, mock_app_config):
|
||||
cp = MagicMock()
|
||||
@ -97,6 +101,59 @@ class TestClientInit:
|
||||
c = DeerFlowClient(checkpointer=cp)
|
||||
assert c._checkpointer is cp
|
||||
|
||||
def test_resolve_checkpointer_config_reads_storage_sqlite(self, mock_app_config):
|
||||
storage_app_config = MagicMock()
|
||||
storage = MagicMock()
|
||||
storage.driver = "sqlite"
|
||||
storage.sqlite_storage_path = "/tmp/deerflow.db"
|
||||
storage_app_config.storage = storage
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
c = DeerFlowClient()
|
||||
|
||||
with patch("store.config.app_config.get_app_config", return_value=storage_app_config):
|
||||
resolved = c._resolve_checkpointer_config()
|
||||
|
||||
assert resolved == {
|
||||
"backend": "sqlite",
|
||||
"connection_string": "/tmp/deerflow.db",
|
||||
}
|
||||
|
||||
def test_resolve_checkpointer_config_reads_storage_postgres(self, mock_app_config):
|
||||
storage_app_config = MagicMock()
|
||||
storage = MagicMock()
|
||||
storage.driver = "postgres"
|
||||
storage.username = "user"
|
||||
storage.password = "pass"
|
||||
storage.host = "localhost"
|
||||
storage.port = 5432
|
||||
storage.db_name = "deerflow"
|
||||
storage_app_config.storage = storage
|
||||
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
c = DeerFlowClient()
|
||||
|
||||
with patch("store.config.app_config.get_app_config", return_value=storage_app_config):
|
||||
resolved = c._resolve_checkpointer_config()
|
||||
|
||||
assert resolved == {
|
||||
"backend": "postgres",
|
||||
"connection_string": "postgresql://user:pass@localhost:5432/deerflow",
|
||||
}
|
||||
|
||||
def test_resolve_checkpointer_config_rejects_mysql(self, mock_app_config):
|
||||
storage_app_config = MagicMock()
|
||||
storage = MagicMock()
|
||||
storage.driver = "mysql"
|
||||
storage_app_config.storage = storage
|
||||
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
c = DeerFlowClient()
|
||||
|
||||
with patch("store.config.app_config.get_app_config", return_value=storage_app_config):
|
||||
with pytest.raises(ValueError, match="does not support a MySQL checkpointer"):
|
||||
c._resolve_checkpointer_config()
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_models / list_skills / get_memory
|
||||
@ -817,7 +874,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch.object(client, "_get_active_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._agent_name = "custom-agent"
|
||||
client._available_skills = {"test_skill"}
|
||||
@ -842,7 +899,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer),
|
||||
patch.object(client, "_get_active_checkpointer", return_value=mock_checkpointer),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@ -867,7 +924,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch.object(client, "_get_active_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@ -886,7 +943,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=None),
|
||||
patch.object(client, "_get_active_checkpointer", return_value=None),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@ -1015,7 +1072,7 @@ class TestThreadQueries:
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.list.return_value = []
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
with patch.object(client, "_get_active_checkpointer", return_value=mock_checkpointer):
|
||||
# No internal checkpointer, should fetch from provider
|
||||
result = client.list_threads()
|
||||
|
||||
@ -1069,7 +1126,7 @@ class TestThreadQueries:
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.list.return_value = []
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
with patch.object(client, "_get_active_checkpointer", return_value=mock_checkpointer):
|
||||
result = client.get_thread("t99")
|
||||
|
||||
assert result["thread_id"] == "t99"
|
||||
@ -1490,7 +1547,7 @@ class TestUploads:
|
||||
|
||||
class TestArtifacts:
|
||||
def test_get_artifact(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
@ -1506,7 +1563,7 @@ class TestArtifacts:
|
||||
assert "text" in mime
|
||||
|
||||
def test_get_artifact_not_found(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
@ -1522,7 +1579,7 @@ class TestArtifacts:
|
||||
client.get_artifact("t1", "bad/path/file.txt")
|
||||
|
||||
def test_get_artifact_path_traversal(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
@ -1711,7 +1768,7 @@ class TestScenarioFileLifecycle:
|
||||
|
||||
def test_upload_then_read_artifact(self, client):
|
||||
"""Upload a file, simulate agent producing artifact, read it back."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
@ -1859,7 +1916,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch.object(client, "_get_active_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config_a)
|
||||
first_agent = client._agent
|
||||
@ -1887,7 +1944,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch.object(client, "_get_active_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
client._ensure_agent(config)
|
||||
@ -1912,7 +1969,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch.object(client, "_get_active_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
client.reset_agent()
|
||||
@ -1970,7 +2027,7 @@ class TestScenarioThreadIsolation:
|
||||
|
||||
def test_artifacts_isolated_per_thread(self, client):
|
||||
"""Artifacts in thread-A are not accessible from thread-B."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
@ -2882,7 +2939,7 @@ class TestUploadDeleteSymlink:
|
||||
class TestArtifactHardening:
|
||||
def test_artifact_directory_rejected(self, client):
|
||||
"""get_artifact rejects paths that resolve to a directory."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
@ -2896,7 +2953,7 @@ class TestArtifactHardening:
|
||||
|
||||
def test_artifact_leading_slash_stripped(self, client):
|
||||
"""Paths with leading slash are handled correctly."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
@ -3015,7 +3072,7 @@ class TestBugArtifactPrefixMatchTooLoose:
|
||||
|
||||
def test_exact_prefix_without_subpath_accepted(self, client):
|
||||
"""Bare 'mnt/user-data' is accepted (will later fail as directory, not at prefix)."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
@ -262,7 +262,7 @@ class TestFileUploadIntegration:
|
||||
|
||||
# Physically exists
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
assert (get_paths().sandbox_uploads_dir(tid, user_id=get_effective_user_id()) / "readme.txt").exists()
|
||||
|
||||
@ -473,7 +473,7 @@ class TestArtifactAccess:
|
||||
def test_get_artifact_happy_path(self, e2e_env):
|
||||
"""Write a file to outputs, then read it back via get_artifact()."""
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
tid = str(uuid.uuid4())
|
||||
@ -490,7 +490,7 @@ class TestArtifactAccess:
|
||||
def test_get_artifact_nested_path(self, e2e_env):
|
||||
"""Artifacts in subdirectories are accessible."""
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.runtime.actor_context import get_effective_user_id
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
tid = str(uuid.uuid4())
|
||||
@ -8,6 +8,7 @@ They are skipped in CI and must be run explicitly:
|
||||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@ -16,6 +17,12 @@ from deerflow.client import DeerFlowClient, StreamEvent
|
||||
from deerflow.sandbox.security import is_host_bash_allowed
|
||||
from deerflow.uploads.manager import PathTraversalError
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r"Pydantic serializer warnings:.*field_name='context'.*",
|
||||
category=UserWarning,
|
||||
)
|
||||
|
||||
# Skip entire module in CI or when no config.yaml exists
|
||||
_skip_reason = None
|
||||
if os.environ.get("CI"):
|
||||
392
backend/tests/unittest/test_feedback.py
Normal file
392
backend/tests/unittest/test_feedback.py
Normal file
@ -0,0 +1,392 @@
|
||||
"""Tests for current feedback storage adapters and follow-up association."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.infra.storage import AppRunEventStore, FeedbackStoreAdapter, RunStoreAdapter
|
||||
from store.persistence import MappedBase
|
||||
|
||||
|
||||
async def _make_feedback_repo(tmp_path):
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'test.db'}", future=True)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
class _FeedbackRepoCompat:
|
||||
def __init__(self, session_factory):
|
||||
self._repo = FeedbackStoreAdapter(session_factory)
|
||||
|
||||
async def create(self, **kwargs):
|
||||
return await self._repo.create(
|
||||
run_id=kwargs["run_id"],
|
||||
thread_id=kwargs["thread_id"],
|
||||
rating=kwargs["rating"],
|
||||
owner_id=kwargs.get("owner_id"),
|
||||
user_id=kwargs.get("user_id"),
|
||||
message_id=kwargs.get("message_id"),
|
||||
comment=kwargs.get("comment"),
|
||||
)
|
||||
|
||||
async def get(self, feedback_id):
|
||||
return await self._repo.get(feedback_id)
|
||||
|
||||
async def list_by_run(self, thread_id, run_id, user_id=None, limit=100):
|
||||
rows = await self._repo.list_by_run(thread_id, run_id, user_id=user_id, limit=limit)
|
||||
return rows
|
||||
|
||||
async def list_by_thread(self, thread_id, limit=100):
|
||||
return await self._repo.list_by_thread(thread_id, limit=limit)
|
||||
|
||||
async def delete(self, feedback_id):
|
||||
return await self._repo.delete(feedback_id)
|
||||
|
||||
async def aggregate_by_run(self, thread_id, run_id):
|
||||
return await self._repo.aggregate_by_run(thread_id, run_id)
|
||||
|
||||
async def upsert(self, **kwargs):
|
||||
return await self._repo.upsert(
|
||||
run_id=kwargs["run_id"],
|
||||
thread_id=kwargs["thread_id"],
|
||||
rating=kwargs["rating"],
|
||||
user_id=kwargs.get("user_id"),
|
||||
comment=kwargs.get("comment"),
|
||||
)
|
||||
|
||||
async def delete_by_run(self, *, thread_id, run_id, user_id):
|
||||
return await self._repo.delete_by_run(thread_id=thread_id, run_id=run_id, user_id=user_id)
|
||||
|
||||
async def list_by_thread_grouped(self, thread_id, user_id):
|
||||
return await self._repo.list_by_thread_grouped(thread_id, user_id=user_id)
|
||||
|
||||
return engine, session_factory, _FeedbackRepoCompat(session_factory)
|
||||
|
||||
|
||||
# -- FeedbackRepository --
|
||||
|
||||
|
||||
class TestFeedbackRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_create_positive(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
assert record["feedback_id"]
|
||||
assert record["rating"] == 1
|
||||
assert record["run_id"] == "r1"
|
||||
assert record["thread_id"] == "t1"
|
||||
assert "created_at" in record
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_negative_with_comment(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.create(
|
||||
run_id="r1",
|
||||
thread_id="t1",
|
||||
rating=-1,
|
||||
comment="Response was inaccurate",
|
||||
)
|
||||
assert record["rating"] == -1
|
||||
assert record["comment"] == "Response was inaccurate"
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_message_id(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, message_id="msg-42")
|
||||
assert record["message_id"] == "msg-42"
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_owner(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
assert record["user_id"] == "user-1"
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_uses_owner_id_fallback(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="owner-1")
|
||||
assert record["user_id"] == "owner-1"
|
||||
assert record["owner_id"] == "owner-1"
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_invalid_rating_zero(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
with pytest.raises(ValueError):
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=0)
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_invalid_rating_five(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
with pytest.raises(ValueError):
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=5)
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
created = await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
fetched = await repo.get(created["feedback_id"])
|
||||
assert fetched is not None
|
||||
assert fetched["feedback_id"] == created["feedback_id"]
|
||||
assert fetched["rating"] == 1
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_nonexistent(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
assert await repo.get("nonexistent") is None
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_run(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2")
|
||||
await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1")
|
||||
results = await repo.list_by_run("t1", "r1", user_id=None)
|
||||
assert len(results) == 2
|
||||
assert all(r["run_id"] == "r1" for r in results)
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_run_filters_thread_even_with_same_run_id(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
await repo.create(run_id="r1", thread_id="t2", rating=-1, user_id="user-2")
|
||||
results = await repo.list_by_run("t1", "r1", user_id=None)
|
||||
assert len(results) == 1
|
||||
assert results[0]["thread_id"] == "t1"
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_run_respects_limit(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="u2")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="u3")
|
||||
results = await repo.list_by_run("t1", "r1", user_id=None, limit=2)
|
||||
assert len(results) == 2
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
await repo.create(run_id="r2", thread_id="t1", rating=-1)
|
||||
await repo.create(run_id="r3", thread_id="t2", rating=1)
|
||||
results = await repo.list_by_thread("t1")
|
||||
assert len(results) == 2
|
||||
assert all(r["thread_id"] == "t1" for r in results)
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_respects_limit(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
await repo.create(run_id="r2", thread_id="t1", rating=-1)
|
||||
await repo.create(run_id="r3", thread_id="t1", rating=1)
|
||||
results = await repo.list_by_thread("t1", limit=2)
|
||||
assert len(results) == 2
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
created = await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
deleted = await repo.delete(created["feedback_id"])
|
||||
assert deleted is True
|
||||
assert await repo.get(created["feedback_id"]) is None
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
deleted = await repo.delete("nonexistent")
|
||||
assert deleted is False
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_by_run(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3")
|
||||
stats = await repo.aggregate_by_run("t1", "r1")
|
||||
assert stats["total"] == 3
|
||||
assert stats["positive"] == 2
|
||||
assert stats["negative"] == 1
|
||||
assert stats["run_id"] == "r1"
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_empty(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
stats = await repo.aggregate_by_run("t1", "r1")
|
||||
assert stats["total"] == 0
|
||||
assert stats["positive"] == 0
|
||||
assert stats["negative"] == 0
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_creates_new(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
assert record["rating"] == 1
|
||||
assert record["feedback_id"]
|
||||
assert record["user_id"] == "u1"
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_updates_existing(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind")
|
||||
assert second["feedback_id"] == first["feedback_id"]
|
||||
assert second["rating"] == -1
|
||||
assert second["comment"] == "changed my mind"
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_different_users_separate(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2")
|
||||
assert r1["feedback_id"] != r2["feedback_id"]
|
||||
assert r1["rating"] == 1
|
||||
assert r2["rating"] == -1
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_invalid_rating(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
with pytest.raises(ValueError):
|
||||
await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1")
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_run(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
||||
assert deleted is True
|
||||
results = await repo.list_by_run("t1", "r1", user_id="u1")
|
||||
assert len(results) == 0
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_run_nonexistent(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
||||
assert deleted is False
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_grouped(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1")
|
||||
await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1")
|
||||
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
||||
assert "r1" in grouped
|
||||
assert "r2" in grouped
|
||||
assert "r3" not in grouped
|
||||
assert grouped["r1"]["rating"] == 1
|
||||
assert grouped["r2"]["rating"] == -1
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_grouped_filters_by_user_when_same_run_id_exists(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="u1", comment="mine")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="u2", comment="other")
|
||||
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
||||
assert grouped["r1"]["user_id"] == "u1"
|
||||
assert grouped["r1"]["comment"] == "mine"
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_grouped_empty(self, tmp_path):
|
||||
engine, _, repo = await _make_feedback_repo(tmp_path)
|
||||
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
||||
assert grouped == {}
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
# -- Follow-up association --
|
||||
|
||||
|
||||
class TestFollowUpAssociation:
|
||||
@pytest.mark.anyio
|
||||
async def test_run_records_follow_up_via_memory_store(self):
|
||||
"""RunStoreAdapter persists follow_up_to_run_id as a first-class field."""
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
session_factory = async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False, autoflush=False)
|
||||
store = RunStoreAdapter(session_factory)
|
||||
await store.create("r1", thread_id="t1", status="success")
|
||||
await store.create("r2", thread_id="t1", follow_up_to_run_id="r1")
|
||||
run = await store.get("r2")
|
||||
assert run is not None
|
||||
assert run["follow_up_to_run_id"] == "r1"
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_human_message_has_follow_up_metadata(self):
|
||||
"""AppRunEventStore preserves follow_up_to_run_id in message metadata."""
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
session_factory = async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False, autoflush=False)
|
||||
event_store = AppRunEventStore(session_factory)
|
||||
await event_store.put_batch([
|
||||
{
|
||||
"thread_id": "t1",
|
||||
"run_id": "r2",
|
||||
"event_type": "human_message",
|
||||
"category": "message",
|
||||
"content": "Tell me more about that",
|
||||
"metadata": {"follow_up_to_run_id": "r1"},
|
||||
}
|
||||
])
|
||||
messages = await event_store.list_messages("t1")
|
||||
assert messages[0]["metadata"]["follow_up_to_run_id"] == "r1"
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_follow_up_auto_detection_logic(self):
|
||||
"""Simulate the auto-detection: latest successful run becomes follow_up_to."""
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
session_factory = async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False, autoflush=False)
|
||||
store = RunStoreAdapter(session_factory)
|
||||
await store.create("r1", thread_id="t1", status="success")
|
||||
await store.create("r2", thread_id="t1", status="error")
|
||||
|
||||
# Auto-detect: list_by_thread returns newest first
|
||||
recent = await store.list_by_thread("t1", limit=1)
|
||||
follow_up = None
|
||||
if recent and recent[0].get("status") == "success":
|
||||
follow_up = recent[0]["run_id"]
|
||||
# r2 (error) is newest, so no follow_up detected
|
||||
assert follow_up is None
|
||||
|
||||
# Now add a successful run
|
||||
await store.create("r3", thread_id="t1", status="success")
|
||||
recent = await store.list_by_thread("t1", limit=1)
|
||||
follow_up = None
|
||||
if recent and recent[0].get("status") == "success":
|
||||
follow_up = recent[0]["run_id"]
|
||||
assert follow_up == "r3"
|
||||
await engine.dispose()
|
||||
281
backend/tests/unittest/test_gateway_services.py
Normal file
281
backend/tests/unittest/test_gateway_services.py
Normal file
@ -0,0 +1,281 @@
|
||||
"""Tests for the current runs service modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from app.gateway.routers.langgraph.runs import RunCreateRequest, format_sse
|
||||
from app.gateway.services.runs.facade_factory import resolve_agent_factory
|
||||
from app.gateway.services.runs.input.request_adapter import (
|
||||
adapt_create_run_request,
|
||||
adapt_create_stream_request,
|
||||
adapt_create_wait_request,
|
||||
adapt_join_stream_request,
|
||||
adapt_join_wait_request,
|
||||
)
|
||||
from app.gateway.services.runs.input.spec_builder import RunSpecBuilder
|
||||
|
||||
|
||||
def _builder() -> RunSpecBuilder:
|
||||
return RunSpecBuilder()
|
||||
|
||||
|
||||
def _build_runnable_config(
|
||||
thread_id: str,
|
||||
request_config: dict | None,
|
||||
metadata: dict | None,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
context: dict | None = None,
|
||||
):
|
||||
return _builder()._build_runnable_config( # noqa: SLF001 - intentional unit coverage
|
||||
thread_id=thread_id,
|
||||
request_config=request_config,
|
||||
metadata=metadata,
|
||||
assistant_id=assistant_id,
|
||||
context=context,
|
||||
)
|
||||
|
||||
|
||||
def test_format_sse_basic():
|
||||
frame = format_sse("metadata", {"run_id": "abc"})
|
||||
assert frame.startswith("event: metadata\n")
|
||||
assert "data: " in frame
|
||||
parsed = json.loads(frame.split("data: ")[1].split("\n")[0])
|
||||
assert parsed["run_id"] == "abc"
|
||||
|
||||
|
||||
def test_format_sse_with_event_id():
|
||||
frame = format_sse("metadata", {"run_id": "abc"}, event_id="123-0")
|
||||
assert "id: 123-0" in frame
|
||||
|
||||
|
||||
def test_format_sse_end_event_null():
|
||||
frame = format_sse("end", None)
|
||||
assert "data: null" in frame
|
||||
|
||||
|
||||
def test_format_sse_no_event_id():
|
||||
frame = format_sse("values", {"x": 1})
|
||||
assert "id:" not in frame
|
||||
|
||||
|
||||
def test_normalize_stream_modes_none():
|
||||
assert _builder()._normalize_stream_modes(None) == ["values", "messages"] # noqa: SLF001
|
||||
|
||||
|
||||
def test_normalize_stream_modes_string():
|
||||
assert _builder()._normalize_stream_modes("messages-tuple") == ["messages"] # noqa: SLF001
|
||||
|
||||
|
||||
def test_normalize_stream_modes_list():
|
||||
assert _builder()._normalize_stream_modes(["values", "messages-tuple"]) == ["values", "messages"] # noqa: SLF001
|
||||
|
||||
|
||||
def test_normalize_stream_modes_empty_list():
|
||||
assert _builder()._normalize_stream_modes([]) == [] # noqa: SLF001
|
||||
|
||||
|
||||
def test_normalize_input_none():
|
||||
assert _builder()._normalize_input(None) is None # noqa: SLF001
|
||||
|
||||
|
||||
def test_normalize_input_with_messages():
|
||||
result = _builder()._normalize_input({"messages": [{"role": "user", "content": "hi"}]}) # noqa: SLF001
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0].content == "hi"
|
||||
|
||||
|
||||
def test_normalize_input_passthrough():
|
||||
result = _builder()._normalize_input({"custom_key": "value"}) # noqa: SLF001
|
||||
assert result == {"custom_key": "value"}
|
||||
|
||||
|
||||
def test_build_runnable_config_basic():
|
||||
config = _build_runnable_config("thread-1", None, None)
|
||||
assert config["configurable"]["thread_id"] == "thread-1"
|
||||
assert config["recursion_limit"] == 100
|
||||
|
||||
|
||||
def test_build_runnable_config_with_overrides():
|
||||
config = _build_runnable_config(
|
||||
"thread-1",
|
||||
{"configurable": {"model_name": "gpt-4"}, "tags": ["test"]},
|
||||
{"user": "alice"},
|
||||
)
|
||||
assert config["configurable"]["model_name"] == "gpt-4"
|
||||
assert config["tags"] == ["test"]
|
||||
assert config["metadata"]["user"] == "alice"
|
||||
|
||||
|
||||
def test_build_runnable_config_custom_agent_injects_agent_name():
|
||||
config = _build_runnable_config("thread-1", None, None, assistant_id="finalis")
|
||||
assert config["configurable"]["agent_name"] == "finalis"
|
||||
|
||||
|
||||
def test_build_runnable_config_lead_agent_no_agent_name():
|
||||
config = _build_runnable_config("thread-1", None, None, assistant_id="lead_agent")
|
||||
assert "agent_name" not in config["configurable"]
|
||||
|
||||
|
||||
def test_build_runnable_config_none_assistant_id_no_agent_name():
|
||||
config = _build_runnable_config("thread-1", None, None, assistant_id=None)
|
||||
assert "agent_name" not in config["configurable"]
|
||||
|
||||
|
||||
def test_build_runnable_config_explicit_agent_name_not_overwritten():
|
||||
config = _build_runnable_config(
|
||||
"thread-1",
|
||||
{"configurable": {"agent_name": "explicit-agent"}},
|
||||
None,
|
||||
assistant_id="other-agent",
|
||||
)
|
||||
assert config["configurable"]["agent_name"] == "explicit-agent"
|
||||
|
||||
|
||||
def test_resolve_agent_factory_returns_make_lead_agent():
|
||||
from deerflow.agents.lead_agent.agent import make_lead_agent
|
||||
|
||||
assert resolve_agent_factory(None) is make_lead_agent
|
||||
assert resolve_agent_factory("lead_agent") is make_lead_agent
|
||||
assert resolve_agent_factory("finalis") is make_lead_agent
|
||||
assert resolve_agent_factory("custom-agent-123") is make_lead_agent
|
||||
|
||||
|
||||
def test_run_create_request_accepts_context():
|
||||
body = RunCreateRequest(
|
||||
input={"messages": [{"role": "user", "content": "hi"}]},
|
||||
context={
|
||||
"model_name": "deepseek-v3",
|
||||
"thinking_enabled": True,
|
||||
"is_plan_mode": True,
|
||||
"subagent_enabled": True,
|
||||
"thread_id": "some-thread-id",
|
||||
},
|
||||
)
|
||||
assert body.context is not None
|
||||
assert body.context["model_name"] == "deepseek-v3"
|
||||
assert body.context["is_plan_mode"] is True
|
||||
assert body.context["subagent_enabled"] is True
|
||||
|
||||
|
||||
def test_run_create_request_context_defaults_to_none():
|
||||
body = RunCreateRequest(input=None)
|
||||
assert body.context is None
|
||||
|
||||
|
||||
def test_context_merges_into_configurable():
|
||||
config = _build_runnable_config(
|
||||
"thread-1",
|
||||
None,
|
||||
None,
|
||||
context={
|
||||
"model_name": "deepseek-v3",
|
||||
"mode": "ultra",
|
||||
"reasoning_effort": "high",
|
||||
"thinking_enabled": True,
|
||||
"is_plan_mode": True,
|
||||
"subagent_enabled": True,
|
||||
"max_concurrent_subagents": 5,
|
||||
"thread_id": "should-be-ignored",
|
||||
},
|
||||
)
|
||||
assert config["configurable"]["model_name"] == "deepseek-v3"
|
||||
assert config["configurable"]["thinking_enabled"] is True
|
||||
assert config["configurable"]["is_plan_mode"] is True
|
||||
assert config["configurable"]["subagent_enabled"] is True
|
||||
assert config["configurable"]["max_concurrent_subagents"] == 5
|
||||
assert config["configurable"]["reasoning_effort"] == "high"
|
||||
assert config["configurable"]["mode"] == "ultra"
|
||||
assert config["configurable"]["thread_id"] == "thread-1"
|
||||
|
||||
|
||||
def test_context_does_not_override_existing_configurable():
|
||||
config = _build_runnable_config(
|
||||
"thread-1",
|
||||
{"configurable": {"model_name": "gpt-4", "is_plan_mode": False}},
|
||||
None,
|
||||
context={
|
||||
"model_name": "deepseek-v3",
|
||||
"is_plan_mode": True,
|
||||
"subagent_enabled": True,
|
||||
},
|
||||
)
|
||||
assert config["configurable"]["model_name"] == "gpt-4"
|
||||
assert config["configurable"]["is_plan_mode"] is False
|
||||
assert config["configurable"]["subagent_enabled"] is True
|
||||
|
||||
|
||||
def test_build_runnable_config_with_context_wrapper_in_request_config():
|
||||
config = _build_runnable_config(
|
||||
"thread-1",
|
||||
{"context": {"user_id": "u-42", "thread_id": "thread-1"}},
|
||||
None,
|
||||
)
|
||||
assert "context" in config
|
||||
assert config["context"]["user_id"] == "u-42"
|
||||
assert "configurable" not in config
|
||||
assert config["recursion_limit"] == 100
|
||||
|
||||
|
||||
def test_build_runnable_config_context_plus_configurable_prefers_context():
|
||||
config = _build_runnable_config(
|
||||
"thread-1",
|
||||
{
|
||||
"context": {"user_id": "u-42"},
|
||||
"configurable": {"model_name": "gpt-4"},
|
||||
},
|
||||
None,
|
||||
)
|
||||
assert "context" in config
|
||||
assert config["context"]["user_id"] == "u-42"
|
||||
assert "configurable" not in config
|
||||
|
||||
|
||||
def test_build_runnable_config_context_passthrough_other_keys():
|
||||
config = _build_runnable_config(
|
||||
"thread-1",
|
||||
{"context": {"thread_id": "thread-1"}, "tags": ["prod"]},
|
||||
None,
|
||||
)
|
||||
assert config["context"]["thread_id"] == "thread-1"
|
||||
assert "configurable" not in config
|
||||
assert config["tags"] == ["prod"]
|
||||
|
||||
|
||||
def test_build_runnable_config_no_request_config():
|
||||
config = _build_runnable_config("thread-abc", None, None)
|
||||
assert config["configurable"] == {"thread_id": "thread-abc"}
|
||||
assert "context" not in config
|
||||
|
||||
|
||||
def test_request_adapter_create_background():
|
||||
adapted = adapt_create_run_request(thread_id="thread-1", body={"input": {"x": 1}})
|
||||
assert adapted.intent == "create_background"
|
||||
assert adapted.thread_id == "thread-1"
|
||||
assert adapted.run_id is None
|
||||
|
||||
|
||||
def test_request_adapter_create_stream():
|
||||
adapted = adapt_create_stream_request(thread_id=None, body={"input": {"x": 1}})
|
||||
assert adapted.intent == "create_and_stream"
|
||||
assert adapted.thread_id is None
|
||||
assert adapted.is_stateless is True
|
||||
|
||||
|
||||
def test_request_adapter_create_wait():
|
||||
adapted = adapt_create_wait_request(thread_id="thread-1", body={})
|
||||
assert adapted.intent == "create_and_wait"
|
||||
assert adapted.thread_id == "thread-1"
|
||||
|
||||
|
||||
def test_request_adapter_join_stream():
|
||||
adapted = adapt_join_stream_request(thread_id="thread-1", run_id="run-1", headers={"Last-Event-ID": "123"})
|
||||
assert adapted.intent == "join_stream"
|
||||
assert adapted.last_event_id == "123"
|
||||
|
||||
|
||||
def test_request_adapter_join_wait():
|
||||
adapted = adapt_join_wait_request(thread_id="thread-1", run_id="run-1")
|
||||
assert adapted.intent == "join_wait"
|
||||
assert adapted.run_id == "run-1"
|
||||
@ -5,7 +5,6 @@ initialized, password strength validation,
|
||||
and public accessibility (no auth cookie required).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import pytest
|
||||
@ -13,41 +12,32 @@ from fastapi.testclient import TestClient
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-initialize-admin-min-32")
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from store.config.app_config import AppConfig, set_app_config
|
||||
from store.config.storage_config import StorageConfig
|
||||
from app.plugins.auth.domain.config import AuthConfig
|
||||
from app.plugins.auth.runtime.config_state import set_auth_config
|
||||
|
||||
_TEST_SECRET = "test-secret-key-initialize-admin-min-32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_auth(tmp_path):
|
||||
"""Fresh SQLite engine + auth config per test."""
|
||||
from app.gateway import deps
|
||||
from deerflow.persistence.engine import close_engine, init_engine
|
||||
|
||||
"""Fresh SQLite app config + auth config per test."""
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
url = f"sqlite+aiosqlite:///{tmp_path}/init_admin.db"
|
||||
asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
|
||||
deps._cached_local_provider = None
|
||||
deps._cached_repo = None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
deps._cached_local_provider = None
|
||||
deps._cached_repo = None
|
||||
asyncio.run(close_engine())
|
||||
set_app_config(AppConfig(storage=StorageConfig(driver="sqlite", sqlite_dir=str(tmp_path))))
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(_setup_auth):
|
||||
from app.gateway.app import create_app
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.plugins.auth.domain.config import AuthConfig
|
||||
from app.plugins.auth.runtime.config_state import set_auth_config
|
||||
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
app = create_app()
|
||||
# Do NOT use TestClient as a context manager — that would trigger the
|
||||
# full lifespan which requires config.yaml. The auth endpoints work
|
||||
# without the lifespan (persistence engine is set up by _setup_auth).
|
||||
yield TestClient(app)
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
def _init_payload(**extra):
|
||||
@ -110,11 +100,7 @@ def test_initialize_register_does_not_block_initialization(client):
|
||||
|
||||
def test_initialize_accessible_without_cookie(client):
|
||||
"""No access_token cookie needed for /initialize."""
|
||||
resp = client.post(
|
||||
"/api/v1/auth/initialize",
|
||||
json=_init_payload(),
|
||||
cookies={},
|
||||
)
|
||||
resp = client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||
assert resp.status_code == 201
|
||||
|
||||
|
||||
@ -152,7 +152,7 @@ def test_get_work_dir_uses_base_dir_when_no_thread_id(monkeypatch, tmp_path):
|
||||
def test_get_work_dir_uses_per_thread_path_when_thread_id_given(monkeypatch, tmp_path):
|
||||
"""P1.1: _get_work_dir(thread_id) uses {base_dir}/threads/{thread_id}/acp-workspace/."""
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.runtime import user_context as uc_module
|
||||
from deerflow.runtime import actor_context as uc_module
|
||||
|
||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||
@ -312,7 +312,7 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
|
||||
async def test_invoke_acp_agent_uses_per_thread_workspace_when_thread_id_in_config(monkeypatch, tmp_path):
|
||||
"""P1.1: When thread_id is in the RunnableConfig, ACP agent uses per-thread workspace."""
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.runtime import user_context as uc_module
|
||||
from deerflow.runtime import actor_context as uc_module
|
||||
|
||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||
@ -9,19 +9,24 @@ import os
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32")
|
||||
|
||||
from langgraph_sdk import Auth
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.jwt import create_access_token, decode_token
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.langgraph_auth import add_owner_filter, authenticate
|
||||
from app.plugins.auth.domain.config import AuthConfig
|
||||
from app.plugins.auth.domain.jwt import create_access_token, decode_token
|
||||
from app.plugins.auth.security.langgraph import add_owner_filter, authenticate
|
||||
from app.plugins.auth.domain.models import User as AuthUser
|
||||
from app.plugins.auth.runtime.config_state import set_auth_config
|
||||
from app.plugins.auth.storage import DbUserRepository, UserCreate
|
||||
from store.persistence import MappedBase
|
||||
from app.plugins.auth.storage.models import User as UserModel # noqa: F401
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
@ -40,13 +45,40 @@ def _req(cookies=None, method="GET", headers=None):
|
||||
|
||||
|
||||
def _user(user_id=None, token_version=0):
|
||||
return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)
|
||||
return AuthUser(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)
|
||||
|
||||
|
||||
def _mock_provider(user=None):
|
||||
p = AsyncMock()
|
||||
p.get_user = AsyncMock(return_value=user)
|
||||
return p
|
||||
async def _attach_auth_session(request, tmp_path, user: AuthUser | None = None):
|
||||
engine = create_async_engine(
|
||||
f"sqlite+aiosqlite:///{tmp_path / 'langgraph-auth.db'}",
|
||||
future=True,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
session = session_factory()
|
||||
if user is not None:
|
||||
repo = DbUserRepository(session)
|
||||
await repo.create_user(
|
||||
UserCreate(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
password_hash=user.password_hash,
|
||||
system_role=user.system_role,
|
||||
oauth_provider=user.oauth_provider,
|
||||
oauth_id=user.oauth_id,
|
||||
needs_setup=user.needs_setup,
|
||||
token_version=user.token_version,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
request._auth_session = session
|
||||
return engine, session
|
||||
|
||||
|
||||
# ── @auth.authenticate ───────────────────────────────────────────────────
|
||||
@ -73,77 +105,103 @@ def test_expired_jwt_raises_401():
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
def test_user_not_found_raises_401():
|
||||
@pytest.mark.anyio
|
||||
async def test_user_not_found_raises_401(tmp_path):
|
||||
token = create_access_token("ghost")
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)):
|
||||
request = _req({"access_token": token})
|
||||
engine, session = await _attach_auth_session(request, tmp_path)
|
||||
try:
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": token})))
|
||||
await authenticate(request)
|
||||
assert exc.value.status_code == 401
|
||||
assert "User not found" in str(exc.value.detail)
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def test_token_version_mismatch_raises_401():
|
||||
@pytest.mark.anyio
|
||||
async def test_token_version_mismatch_raises_401(tmp_path):
|
||||
user = _user(token_version=2)
|
||||
token = create_access_token(str(user.id), token_version=1)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
request = _req({"access_token": token})
|
||||
engine, session = await _attach_auth_session(request, tmp_path, user)
|
||||
try:
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": token})))
|
||||
await authenticate(request)
|
||||
assert exc.value.status_code == 401
|
||||
assert "revoked" in str(exc.value.detail).lower()
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def test_valid_token_returns_user_id():
|
||||
@pytest.mark.anyio
|
||||
async def test_valid_token_returns_user_id(tmp_path):
|
||||
user = _user(token_version=0)
|
||||
token = create_access_token(str(user.id), token_version=0)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
result = asyncio.run(authenticate(_req({"access_token": token})))
|
||||
request = _req({"access_token": token})
|
||||
engine, session = await _attach_auth_session(request, tmp_path, user)
|
||||
try:
|
||||
result = await authenticate(request)
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
assert result == str(user.id)
|
||||
|
||||
|
||||
def test_valid_token_matching_version():
|
||||
@pytest.mark.anyio
|
||||
async def test_valid_token_matching_version(tmp_path):
|
||||
user = _user(token_version=5)
|
||||
token = create_access_token(str(user.id), token_version=5)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
result = asyncio.run(authenticate(_req({"access_token": token})))
|
||||
request = _req({"access_token": token})
|
||||
engine, session = await _attach_auth_session(request, tmp_path, user)
|
||||
try:
|
||||
result = await authenticate(request)
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
assert result == str(user.id)
|
||||
|
||||
|
||||
# ── @auth.authenticate edge cases ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_provider_exception_propagates():
|
||||
"""Provider raises → should not be swallowed silently."""
|
||||
token = create_access_token("user-1")
|
||||
p = AsyncMock()
|
||||
p.get_user = AsyncMock(side_effect=RuntimeError("DB down"))
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p):
|
||||
with pytest.raises(RuntimeError, match="DB down"):
|
||||
asyncio.run(authenticate(_req({"access_token": token})))
|
||||
|
||||
|
||||
def test_jwt_missing_ver_defaults_to_zero():
|
||||
@pytest.mark.anyio
|
||||
async def test_jwt_missing_ver_defaults_to_zero(tmp_path):
|
||||
"""JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0."""
|
||||
import jwt as pyjwt
|
||||
|
||||
uid = str(uuid4())
|
||||
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
|
||||
user = _user(user_id=uid, token_version=0)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
result = asyncio.run(authenticate(_req({"access_token": raw})))
|
||||
request = _req({"access_token": raw})
|
||||
engine, session = await _attach_auth_session(request, tmp_path, user)
|
||||
try:
|
||||
result = await authenticate(request)
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
assert result == uid
|
||||
|
||||
|
||||
def test_jwt_missing_ver_rejected_when_user_version_nonzero():
|
||||
@pytest.mark.anyio
|
||||
async def test_jwt_missing_ver_rejected_when_user_version_nonzero(tmp_path):
|
||||
"""JWT without 'ver' (defaults 0) vs user with token_version=1 → 401."""
|
||||
import jwt as pyjwt
|
||||
|
||||
uid = str(uuid4())
|
||||
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
|
||||
user = _user(user_id=uid, token_version=1)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
request = _req({"access_token": raw})
|
||||
engine, session = await _attach_auth_session(request, tmp_path, user)
|
||||
try:
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": raw})))
|
||||
await authenticate(request)
|
||||
assert exc.value.status_code == 401
|
||||
finally:
|
||||
await session.close()
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def test_wrong_secret_raises_401():
|
||||
@ -223,7 +281,7 @@ def test_filter_with_empty_metadata():
|
||||
def test_shared_jwt_secret():
|
||||
token = create_access_token("user-1", token_version=3)
|
||||
payload = decode_token(token)
|
||||
from app.gateway.auth.errors import TokenError
|
||||
from app.plugins.auth.domain.errors import TokenError
|
||||
|
||||
assert not isinstance(payload, TokenError)
|
||||
assert payload.sub == "user-1"
|
||||
@ -233,13 +291,13 @@ def test_shared_jwt_secret():
|
||||
def test_langgraph_json_has_auth_path():
|
||||
import json
|
||||
|
||||
config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text())
|
||||
assert "auth" in config
|
||||
assert "langgraph_auth" in config["auth"]["path"]
|
||||
config = json.loads((Path(__file__).resolve().parents[2] / "langgraph.json").read_text())
|
||||
assert "graphs" in config
|
||||
assert "lead_agent" in config["graphs"]
|
||||
|
||||
|
||||
def test_auth_handler_has_both_layers():
|
||||
from app.gateway.langgraph_auth import auth
|
||||
from app.plugins.auth.security.langgraph import auth
|
||||
|
||||
assert auth._authenticate_handler is not None
|
||||
assert len(auth._global_handlers) == 1
|
||||
236
backend/tests/unittest/test_owner_isolation.py
Normal file
236
backend/tests/unittest/test_owner_isolation.py
Normal file
@ -0,0 +1,236 @@
|
||||
"""Cross-user isolation tests for current app-owned storage adapters.
|
||||
|
||||
These tests exercise isolation by binding different ``ActorContext``
|
||||
values around the app-layer storage adapters. The safety property is:
|
||||
|
||||
data written under user A is not visible to user B through the same
|
||||
adapter surface unless a call explicitly opts out with ``user_id=None``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.infra.storage import AppRunEventStore, FeedbackStoreAdapter, RunStoreAdapter, ThreadMetaStorage, ThreadMetaStoreAdapter
|
||||
from deerflow.runtime.actor_context import AUTO, ActorContext, bind_actor_context, reset_actor_context
|
||||
from store.persistence import MappedBase
|
||||
|
||||
|
||||
USER_A = "user-a"
|
||||
USER_B = "user-b"
|
||||
|
||||
|
||||
async def _make_components(tmp_path):
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_path / 'isolation.db'}", future=True)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
thread_store = ThreadMetaStorage(ThreadMetaStoreAdapter(session_factory))
|
||||
return (
|
||||
engine,
|
||||
thread_store,
|
||||
RunStoreAdapter(session_factory),
|
||||
FeedbackStoreAdapter(session_factory),
|
||||
AppRunEventStore(session_factory),
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _as_user(user_id: str):
|
||||
token = bind_actor_context(ActorContext(user_id=user_id))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
reset_actor_context(token)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_thread_meta_cross_user_isolation(tmp_path):
|
||||
engine, thread_store, _, _, _ = await _make_components(tmp_path)
|
||||
try:
|
||||
with _as_user(USER_A):
|
||||
await thread_store.ensure_thread(thread_id="t-alpha")
|
||||
with _as_user(USER_B):
|
||||
await thread_store.ensure_thread(thread_id="t-beta")
|
||||
|
||||
with _as_user(USER_A):
|
||||
assert (await thread_store.get_thread("t-alpha")) is not None
|
||||
assert await thread_store.get_thread("t-beta") is None
|
||||
rows = await thread_store.search_threads()
|
||||
assert [row.thread_id for row in rows] == ["t-alpha"]
|
||||
|
||||
with _as_user(USER_B):
|
||||
assert (await thread_store.get_thread("t-beta")) is not None
|
||||
assert await thread_store.get_thread("t-alpha") is None
|
||||
rows = await thread_store.search_threads()
|
||||
assert [row.thread_id for row in rows] == ["t-beta"]
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_runs_cross_user_isolation(tmp_path):
|
||||
engine, thread_store, run_store, _, _ = await _make_components(tmp_path)
|
||||
try:
|
||||
with _as_user(USER_A):
|
||||
await thread_store.ensure_thread(thread_id="t-alpha")
|
||||
await run_store.create("run-a1", "t-alpha")
|
||||
await run_store.create("run-a2", "t-alpha")
|
||||
|
||||
with _as_user(USER_B):
|
||||
await thread_store.ensure_thread(thread_id="t-beta")
|
||||
await run_store.create("run-b1", "t-beta")
|
||||
|
||||
with _as_user(USER_A):
|
||||
assert (await run_store.get("run-a1")) is not None
|
||||
assert await run_store.get("run-b1") is None
|
||||
rows = await run_store.list_by_thread("t-alpha")
|
||||
assert {row["run_id"] for row in rows} == {"run-a1", "run-a2"}
|
||||
assert await run_store.list_by_thread("t-beta") == []
|
||||
|
||||
with _as_user(USER_B):
|
||||
assert await run_store.get("run-a1") is None
|
||||
rows = await run_store.list_by_thread("t-beta")
|
||||
assert [row["run_id"] for row in rows] == ["run-b1"]
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_run_events_cross_user_isolation(tmp_path):
|
||||
engine, thread_store, _, _, event_store = await _make_components(tmp_path)
|
||||
try:
|
||||
with _as_user(USER_A):
|
||||
await thread_store.ensure_thread(thread_id="t-alpha")
|
||||
await event_store.put_batch(
|
||||
[
|
||||
{
|
||||
"thread_id": "t-alpha",
|
||||
"run_id": "run-a1",
|
||||
"event_type": "human_message",
|
||||
"category": "message",
|
||||
"content": "User A private question",
|
||||
},
|
||||
{
|
||||
"thread_id": "t-alpha",
|
||||
"run_id": "run-a1",
|
||||
"event_type": "ai_message",
|
||||
"category": "message",
|
||||
"content": "User A private answer",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
with _as_user(USER_B):
|
||||
await thread_store.ensure_thread(thread_id="t-beta")
|
||||
await event_store.put_batch(
|
||||
[
|
||||
{
|
||||
"thread_id": "t-beta",
|
||||
"run_id": "run-b1",
|
||||
"event_type": "human_message",
|
||||
"category": "message",
|
||||
"content": "User B private question",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
with _as_user(USER_A):
|
||||
msgs = await event_store.list_messages("t-alpha")
|
||||
contents = [msg["content"] for msg in msgs]
|
||||
assert "User A private question" in contents
|
||||
assert "User A private answer" in contents
|
||||
assert "User B private question" not in contents
|
||||
assert await event_store.list_messages("t-beta") == []
|
||||
assert await event_store.list_events("t-beta", "run-b1") == []
|
||||
assert await event_store.count_messages("t-beta") == 0
|
||||
|
||||
with _as_user(USER_B):
|
||||
msgs = await event_store.list_messages("t-beta")
|
||||
contents = [msg["content"] for msg in msgs]
|
||||
assert "User B private question" in contents
|
||||
assert "User A private question" not in contents
|
||||
assert await event_store.count_messages("t-alpha") == 0
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_feedback_cross_user_isolation(tmp_path):
|
||||
engine, thread_store, _, feedback_store, _ = await _make_components(tmp_path)
|
||||
try:
|
||||
with _as_user(USER_A):
|
||||
await thread_store.ensure_thread(thread_id="t-alpha")
|
||||
a_feedback = await feedback_store.create(
|
||||
run_id="run-a1",
|
||||
thread_id="t-alpha",
|
||||
rating=1,
|
||||
user_id=USER_A,
|
||||
comment="A liked this",
|
||||
)
|
||||
|
||||
with _as_user(USER_B):
|
||||
await thread_store.ensure_thread(thread_id="t-beta")
|
||||
b_feedback = await feedback_store.create(
|
||||
run_id="run-b1",
|
||||
thread_id="t-beta",
|
||||
rating=-1,
|
||||
user_id=USER_B,
|
||||
comment="B disliked this",
|
||||
)
|
||||
|
||||
with _as_user(USER_A):
|
||||
assert (await feedback_store.get(a_feedback["feedback_id"])) is not None
|
||||
assert await feedback_store.get(b_feedback["feedback_id"]) is not None
|
||||
assert await feedback_store.list_by_run("t-beta", "run-b1", user_id=USER_A) == []
|
||||
|
||||
with _as_user(USER_B):
|
||||
assert await feedback_store.list_by_run("t-alpha", "run-a1", user_id=USER_B) == []
|
||||
rows = await feedback_store.list_by_run("t-beta", "run-b1", user_id=USER_B)
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["comment"] == "B disliked this"
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_repository_without_context_raises(tmp_path):
|
||||
engine, thread_store, _, _, _ = await _make_components(tmp_path)
|
||||
try:
|
||||
with pytest.raises(RuntimeError, match="no actor context is set"):
|
||||
await thread_store.search_threads(user_id=AUTO)
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_explicit_none_bypasses_filter(tmp_path):
|
||||
engine, thread_store, _, _, _ = await _make_components(tmp_path)
|
||||
try:
|
||||
with _as_user(USER_A):
|
||||
await thread_store.ensure_thread(thread_id="t-alpha")
|
||||
with _as_user(USER_B):
|
||||
await thread_store.ensure_thread(thread_id="t-beta")
|
||||
|
||||
rows = await thread_store.search_threads(user_id=None)
|
||||
assert {row.thread_id for row in rows} == {"t-alpha", "t-beta"}
|
||||
assert await thread_store.get_thread("t-alpha", user_id=None) is not None
|
||||
assert await thread_store.get_thread("t-beta", user_id=None) is not None
|
||||
finally:
|
||||
await engine.dispose()
|
||||
66
backend/tests/unittest/test_run_callbacks_builder.py
Normal file
66
backend/tests/unittest/test_run_callbacks_builder.py
Normal file
@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from deerflow.runtime.runs.callbacks.builder import build_run_callbacks
|
||||
from deerflow.runtime.runs.types import RunRecord, RunStatus
|
||||
|
||||
|
||||
def _record() -> RunRecord:
|
||||
return RunRecord(
|
||||
run_id="run-1",
|
||||
thread_id="thread-1",
|
||||
assistant_id=None,
|
||||
status=RunStatus.pending,
|
||||
temporary=False,
|
||||
multitask_strategy="reject",
|
||||
metadata={},
|
||||
created_at="",
|
||||
updated_at="",
|
||||
)
|
||||
|
||||
|
||||
def test_build_run_callbacks_sets_first_human_message_from_string_content():
|
||||
artifacts = build_run_callbacks(
|
||||
record=_record(),
|
||||
graph_input={"messages": [HumanMessage(content="hello world")]},
|
||||
event_store=None,
|
||||
)
|
||||
|
||||
assert artifacts.completion_data().first_human_message == "hello world"
|
||||
|
||||
|
||||
def test_build_run_callbacks_sets_first_human_message_from_content_blocks():
|
||||
artifacts = build_run_callbacks(
|
||||
record=_record(),
|
||||
graph_input={
|
||||
"messages": [
|
||||
HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "hello "},
|
||||
{"type": "text", "text": "world"},
|
||||
]
|
||||
)
|
||||
]
|
||||
},
|
||||
event_store=None,
|
||||
)
|
||||
|
||||
assert artifacts.completion_data().first_human_message == "hello world"
|
||||
|
||||
|
||||
def test_build_run_callbacks_sets_first_human_message_from_dict_payload():
|
||||
artifacts = build_run_callbacks(
|
||||
record=_record(),
|
||||
graph_input={
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "hi from dict"}],
|
||||
}
|
||||
]
|
||||
},
|
||||
event_store=None,
|
||||
)
|
||||
|
||||
assert artifacts.completion_data().first_human_message == "hi from dict"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user