From ceeccabc989a4c0d6c259d9552b1d32475656bf5 Mon Sep 17 00:00:00 2001 From: greatmengqi Date: Wed, 8 Apr 2026 11:49:24 +0800 Subject: [PATCH] refactor(auth): migrate user repository to SQLAlchemy ORM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move the users table into the shared persistence engine so auth matches the pattern of threads_meta, runs, run_events, and feedback — one engine, one session factory, one schema init codepath. New files --------- - persistence/user/__init__.py, persistence/user/model.py: UserRow ORM class with partial unique index on (oauth_provider, oauth_id) - Registered in persistence/models/__init__.py so Base.metadata.create_all() picks it up Modified -------- - auth/repositories/sqlite.py: rewritten as async SQLAlchemy, identical constructor pattern to the other four repositories (def __init__(self, session_factory) + self._sf = session_factory) - auth/config.py: drop users_db_path field — storage is configured through config.database like every other table - deps.py/get_local_provider: construct SQLiteUserRepository with the shared session factory, fail fast if engine is not initialised - tests/test_auth.py: rewrite test_sqlite_round_trip_new_fields to use the shared engine (init_engine + close_engine in a tempdir) - tests/test_auth_type_system.py: add per-test autouse fixture that spins up a scratch engine and resets deps._cached_* singletons --- backend/app/gateway/auth/config.py | 12 +- .../app/gateway/auth/repositories/sqlite.py | 266 ++++++------------ backend/app/gateway/deps.py | 12 +- .../deerflow/persistence/models/__init__.py | 4 +- .../deerflow/persistence/user/__init__.py | 12 + .../deerflow/persistence/user/model.py | 59 ++++ backend/tests/test_auth.py | 79 +++--- backend/tests/test_auth_type_system.py | 26 ++ 8 files changed, 254 insertions(+), 216 deletions(-) create mode 100644 backend/packages/harness/deerflow/persistence/user/__init__.py create mode 100644 backend/packages/harness/deerflow/persistence/user/model.py diff --git a/backend/app/gateway/auth/config.py b/backend/app/gateway/auth/config.py index ca10acc20..01f0870fd 100644 --- a/backend/app/gateway/auth/config.py +++ b/backend/app/gateway/auth/config.py @@ -13,17 +13,19 @@ logger = logging.getLogger(__name__) class AuthConfig(BaseModel): - """JWT and auth-related configuration. Parsed once at startup.""" + """JWT and auth-related configuration. Parsed once at startup. + + Note: the ``users`` table now lives in the shared persistence + database managed by ``deerflow.persistence.engine``. The old + ``users_db_path`` config key has been removed — user storage is + configured through ``config.database`` like every other table. + """ jwt_secret: str = Field( ..., description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.", ) token_expiry_days: int = Field(default=7, ge=1, le=30) - users_db_path: str | None = Field( - default=None, - description="Path to users SQLite DB. Defaults to .deer-flow/users.db", - ) oauth_github_client_id: str | None = Field(default=None) oauth_github_client_secret: str | None = Field(default=None) diff --git a/backend/app/gateway/auth/repositories/sqlite.py b/backend/app/gateway/auth/repositories/sqlite.py index 93b768cfe..bbe3712ff 100644 --- a/backend/app/gateway/auth/repositories/sqlite.py +++ b/backend/app/gateway/auth/repositories/sqlite.py @@ -1,196 +1,116 @@ -"""SQLite implementation of UserRepository.""" +"""SQLAlchemy-backed UserRepository implementation. -import asyncio -import sqlite3 -from contextlib import contextmanager -from datetime import UTC, datetime -from pathlib import Path -from typing import Any +Uses the shared async session factory from +``deerflow.persistence.engine`` — the ``users`` table lives in the +same database as ``threads_meta``, ``runs``, ``run_events``, and +``feedback``. + +Constructor takes the session factory directly (same pattern as the +other four repositories in ``deerflow.persistence.*``). Callers +construct this after ``init_engine_from_config()`` has run. +""" + +from __future__ import annotations + +from datetime import UTC from uuid import UUID -from app.gateway.auth.config import get_auth_config +from sqlalchemy import func, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + from app.gateway.auth.models import User from app.gateway.auth.repositories.base import UserRepository - -_resolved_db_path: Path | None = None -_table_initialized: bool = False - - -def _get_users_db_path() -> Path: - """Get the users database path (resolved and cached once).""" - global _resolved_db_path - if _resolved_db_path is not None: - return _resolved_db_path - config = get_auth_config() - if config.users_db_path: - _resolved_db_path = Path(config.users_db_path) - else: - _resolved_db_path = Path(".deer-flow/users.db") - _resolved_db_path.parent.mkdir(parents=True, exist_ok=True) - return _resolved_db_path - - -def _get_connection() -> sqlite3.Connection: - """Get a SQLite connection for the users database.""" - db_path = _get_users_db_path() - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - return conn - - -def _init_users_table(conn: sqlite3.Connection) -> None: - """Initialize the users table if it doesn't exist.""" - conn.execute("PRAGMA journal_mode=WAL") - conn.execute( - """ - CREATE TABLE IF NOT EXISTS users ( - id TEXT PRIMARY KEY, - email TEXT UNIQUE NOT NULL, - password_hash TEXT, - system_role TEXT NOT NULL DEFAULT 'user', - created_at REAL NOT NULL, - oauth_provider TEXT, - oauth_id TEXT, - needs_setup INTEGER NOT NULL DEFAULT 0, - token_version INTEGER NOT NULL DEFAULT 0 - ) - """ - ) - # Add unique constraint for OAuth identity to prevent duplicate social logins - conn.execute( - """ - CREATE UNIQUE INDEX IF NOT EXISTS idx_users_oauth_identity - ON users(oauth_provider, oauth_id) - WHERE oauth_provider IS NOT NULL AND oauth_id IS NOT NULL - """ - ) - conn.commit() - - -@contextmanager -def _get_users_conn(): - """Context manager for users database connection.""" - global _table_initialized - conn = _get_connection() - try: - if not _table_initialized: - _init_users_table(conn) - _table_initialized = True - yield conn - finally: - conn.close() +from deerflow.persistence.user.model import UserRow class SQLiteUserRepository(UserRepository): - """SQLite implementation of UserRepository.""" + """Async user repository backed by the shared SQLAlchemy engine.""" + + def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: + self._sf = session_factory + + # ── Converters ──────────────────────────────────────────────────── + + @staticmethod + def _row_to_user(row: UserRow) -> User: + return User( + id=UUID(row.id), + email=row.email, + password_hash=row.password_hash, + system_role=row.system_role, # type: ignore[arg-type] + # SQLite loses tzinfo on read; reattach UTC so downstream + # code can compare timestamps reliably. + created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC), + oauth_provider=row.oauth_provider, + oauth_id=row.oauth_id, + needs_setup=row.needs_setup, + token_version=row.token_version, + ) + + @staticmethod + def _user_to_row(user: User) -> UserRow: + return UserRow( + id=str(user.id), + email=user.email, + password_hash=user.password_hash, + system_role=user.system_role, + created_at=user.created_at, + oauth_provider=user.oauth_provider, + oauth_id=user.oauth_id, + needs_setup=user.needs_setup, + token_version=user.token_version, + ) + + # ── CRUD ────────────────────────────────────────────────────────── async def create_user(self, user: User) -> User: - """Create a new user in SQLite.""" - return await asyncio.to_thread(self._create_user_sync, user) - - def _create_user_sync(self, user: User) -> User: - """Synchronous user creation (runs in thread pool).""" - with _get_users_conn() as conn: + """Insert a new user. Raises ``ValueError`` on duplicate email.""" + row = self._user_to_row(user) + async with self._sf() as session: + session.add(row) try: - conn.execute( - """ - INSERT INTO users (id, email, password_hash, system_role, created_at, oauth_provider, oauth_id, needs_setup, token_version) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - str(user.id), - user.email, - user.password_hash, - user.system_role, - datetime.now(UTC).timestamp(), - user.oauth_provider, - user.oauth_id, - int(user.needs_setup), - user.token_version, - ), - ) - conn.commit() - except sqlite3.IntegrityError as e: - if "UNIQUE constraint failed: users.email" in str(e): - raise ValueError(f"Email already registered: {user.email}") from e - raise + await session.commit() + except IntegrityError as exc: + await session.rollback() + raise ValueError(f"Email already registered: {user.email}") from exc return user async def get_user_by_id(self, user_id: str) -> User | None: - """Get user by ID from SQLite.""" - return await asyncio.to_thread(self._get_user_by_id_sync, user_id) - - def _get_user_by_id_sync(self, user_id: str) -> User | None: - """Synchronous get by ID (runs in thread pool).""" - with _get_users_conn() as conn: - cursor = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)) - row = cursor.fetchone() - if row is None: - return None - return self._row_to_user(dict(row)) + async with self._sf() as session: + row = await session.get(UserRow, user_id) + return self._row_to_user(row) if row is not None else None async def get_user_by_email(self, email: str) -> User | None: - """Get user by email from SQLite.""" - return await asyncio.to_thread(self._get_user_by_email_sync, email) - - def _get_user_by_email_sync(self, email: str) -> User | None: - """Synchronous get by email (runs in thread pool).""" - with _get_users_conn() as conn: - cursor = conn.execute("SELECT * FROM users WHERE email = ?", (email,)) - row = cursor.fetchone() - if row is None: - return None - return self._row_to_user(dict(row)) + stmt = select(UserRow).where(UserRow.email == email) + async with self._sf() as session: + result = await session.execute(stmt) + row = result.scalar_one_or_none() + return self._row_to_user(row) if row is not None else None async def update_user(self, user: User) -> User: - """Update an existing user in SQLite.""" - return await asyncio.to_thread(self._update_user_sync, user) - - def _update_user_sync(self, user: User) -> User: - with _get_users_conn() as conn: - conn.execute( - "UPDATE users SET email = ?, password_hash = ?, system_role = ?, oauth_provider = ?, oauth_id = ?, needs_setup = ?, token_version = ? WHERE id = ?", - (user.email, user.password_hash, user.system_role, user.oauth_provider, user.oauth_id, int(user.needs_setup), user.token_version, str(user.id)), - ) - conn.commit() + async with self._sf() as session: + row = await session.get(UserRow, str(user.id)) + if row is None: + return user + row.email = user.email + row.password_hash = user.password_hash + row.system_role = user.system_role + row.oauth_provider = user.oauth_provider + row.oauth_id = user.oauth_id + row.needs_setup = user.needs_setup + row.token_version = user.token_version + await session.commit() return user async def count_users(self) -> int: - """Return total number of registered users.""" - return await asyncio.to_thread(self._count_users_sync) - - def _count_users_sync(self) -> int: - with _get_users_conn() as conn: - cursor = conn.execute("SELECT COUNT(*) FROM users") - return cursor.fetchone()[0] + stmt = select(func.count()).select_from(UserRow) + async with self._sf() as session: + return await session.scalar(stmt) or 0 async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: - """Get user by OAuth provider and ID from SQLite.""" - return await asyncio.to_thread(self._get_user_by_oauth_sync, provider, oauth_id) - - def _get_user_by_oauth_sync(self, provider: str, oauth_id: str) -> User | None: - """Synchronous get by OAuth (runs in thread pool).""" - with _get_users_conn() as conn: - cursor = conn.execute( - "SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?", - (provider, oauth_id), - ) - row = cursor.fetchone() - if row is None: - return None - return self._row_to_user(dict(row)) - - @staticmethod - def _row_to_user(row: dict[str, Any]) -> User: - """Convert a database row to a User model.""" - return User( - id=UUID(row["id"]), - email=row["email"], - password_hash=row["password_hash"], - system_role=row["system_role"], - created_at=datetime.fromtimestamp(row["created_at"], tz=UTC), - oauth_provider=row.get("oauth_provider"), - oauth_id=row.get("oauth_id"), - needs_setup=bool(row["needs_setup"]), - token_version=int(row["token_version"]), - ) + stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id) + async with self._sf() as session: + result = await session.execute(stmt) + row = result.scalar_one_or_none() + return self._row_to_user(row) if row is not None else None diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index b6fa9c975..5ea7f6751 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -142,12 +142,20 @@ _cached_repo: SQLiteUserRepository | None = None def get_local_provider() -> LocalAuthProvider: - """Get or create the cached LocalAuthProvider singleton.""" + """Get or create the cached LocalAuthProvider singleton. + + Must be called after ``init_engine_from_config()`` — the shared + session factory is required to construct the user repository. + """ global _cached_local_provider, _cached_repo if _cached_repo is None: from app.gateway.auth.repositories.sqlite import SQLiteUserRepository + from deerflow.persistence.engine import get_session_factory - _cached_repo = SQLiteUserRepository() + sf = get_session_factory() + if sf is None: + raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table") + _cached_repo = SQLiteUserRepository(sf) if _cached_local_provider is None: from app.gateway.auth.local_provider import LocalAuthProvider diff --git a/backend/packages/harness/deerflow/persistence/models/__init__.py b/backend/packages/harness/deerflow/persistence/models/__init__.py index 659ac07f9..ab29a3536 100644 --- a/backend/packages/harness/deerflow/persistence/models/__init__.py +++ b/backend/packages/harness/deerflow/persistence/models/__init__.py @@ -7,6 +7,7 @@ The actual ORM classes have moved to entity-specific subpackages: - ``deerflow.persistence.thread_meta`` - ``deerflow.persistence.run`` - ``deerflow.persistence.feedback`` +- ``deerflow.persistence.user`` ``RunEventRow`` remains in ``deerflow.persistence.models.run_event`` because its storage implementation lives in ``deerflow.runtime.events.store.db`` and @@ -17,5 +18,6 @@ from deerflow.persistence.feedback.model import FeedbackRow from deerflow.persistence.models.run_event import RunEventRow from deerflow.persistence.run.model import RunRow from deerflow.persistence.thread_meta.model import ThreadMetaRow +from deerflow.persistence.user.model import UserRow -__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow"] +__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"] diff --git a/backend/packages/harness/deerflow/persistence/user/__init__.py b/backend/packages/harness/deerflow/persistence/user/__init__.py new file mode 100644 index 000000000..a60eeef2c --- /dev/null +++ b/backend/packages/harness/deerflow/persistence/user/__init__.py @@ -0,0 +1,12 @@ +"""User storage subpackage. + +Holds the ORM model for the ``users`` table. The concrete repository +implementation (``SQLiteUserRepository``) lives in the app layer +(``app.gateway.auth.repositories.sqlite``) because it converts +between the ORM row and the auth module's pydantic ``User`` class. +This keeps the harness package free of any dependency on app code. +""" + +from deerflow.persistence.user.model import UserRow + +__all__ = ["UserRow"] diff --git a/backend/packages/harness/deerflow/persistence/user/model.py b/backend/packages/harness/deerflow/persistence/user/model.py new file mode 100644 index 000000000..130d4bfcb --- /dev/null +++ b/backend/packages/harness/deerflow/persistence/user/model.py @@ -0,0 +1,59 @@ +"""ORM model for the users table. + +Lives in the harness persistence package so it is picked up by +``Base.metadata.create_all()`` alongside ``threads_meta``, ``runs``, +``run_events``, and ``feedback``. Using the shared engine means: + +- One SQLite/Postgres database, one connection pool +- One schema initialisation codepath +- Consistent async sessions across auth and persistence reads +""" + +from __future__ import annotations + +from datetime import UTC, datetime + +from sqlalchemy import Boolean, DateTime, Index, String, text +from sqlalchemy.orm import Mapped, mapped_column + +from deerflow.persistence.base import Base + + +class UserRow(Base): + __tablename__ = "users" + + # UUIDs are stored as 36-char strings for cross-backend portability. + id: Mapped[str] = mapped_column(String(36), primary_key=True) + + email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True) + password_hash: Mapped[str | None] = mapped_column(String(128), nullable=True) + + # "admin" | "user" — kept as plain string to avoid ALTER TABLE pain + # when new roles are introduced. + system_role: Mapped[str] = mapped_column(String(16), nullable=False, default="user") + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + ) + + # OAuth linkage (optional). A partial unique index enforces one + # account per (provider, oauth_id) pair, leaving NULL/NULL rows + # unconstrained so plain password accounts can coexist. + oauth_provider: Mapped[str | None] = mapped_column(String(32), nullable=True) + oauth_id: Mapped[str | None] = mapped_column(String(128), nullable=True) + + # Auth lifecycle flags + needs_setup: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + token_version: Mapped[int] = mapped_column(nullable=False, default=0) + + __table_args__ = ( + Index( + "idx_users_oauth_identity", + "oauth_provider", + "oauth_id", + unique=True, + sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"), + ), + ) diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py index d73a6925f..ca03528bc 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/test_auth.py @@ -262,47 +262,56 @@ def test_user_model_needs_setup_true(): def test_sqlite_round_trip_new_fields(): - """needs_setup and token_version survive create → read round-trip.""" + """needs_setup and token_version survive create → read round-trip. + + Uses the shared persistence engine (same one threads_meta, runs, + run_events, and feedback use). The old separate .deer-flow/users.db + file is gone. + """ import asyncio - import os import tempfile - from pathlib import Path - from app.gateway.auth.repositories import sqlite as sqlite_mod + from app.gateway.auth.repositories.sqlite import SQLiteUserRepository - with tempfile.TemporaryDirectory() as tmpdir: - db_path = os.path.join(tmpdir, "test_users.db") - old_path = sqlite_mod._resolved_db_path - old_init = sqlite_mod._table_initialized - sqlite_mod._resolved_db_path = Path(db_path) - sqlite_mod._table_initialized = False - try: - repo = sqlite_mod.SQLiteUserRepository() - user = User( - email="setup@test.com", - password_hash="fakehash", - system_role="admin", - needs_setup=True, - token_version=3, - ) - created = asyncio.run(repo.create_user(user)) - assert created.needs_setup is True - assert created.token_version == 3 + async def _run() -> None: + from deerflow.persistence.engine import ( + close_engine, + get_session_factory, + init_engine, + ) - fetched = asyncio.run(repo.get_user_by_email("setup@test.com")) - assert fetched is not None - assert fetched.needs_setup is True - assert fetched.token_version == 3 + with tempfile.TemporaryDirectory() as tmpdir: + url = f"sqlite+aiosqlite:///{tmpdir}/scratch.db" + await init_engine("sqlite", url=url, sqlite_dir=tmpdir) + try: + repo = SQLiteUserRepository(get_session_factory()) + user = User( + email="setup@test.com", + password_hash="fakehash", + system_role="admin", + needs_setup=True, + token_version=3, + ) + created = await repo.create_user(user) + assert created.needs_setup is True + assert created.token_version == 3 - fetched.needs_setup = False - fetched.token_version = 4 - asyncio.run(repo.update_user(fetched)) - refetched = asyncio.run(repo.get_user_by_id(str(fetched.id))) - assert refetched.needs_setup is False - assert refetched.token_version == 4 - finally: - sqlite_mod._resolved_db_path = old_path - sqlite_mod._table_initialized = old_init + fetched = await repo.get_user_by_email("setup@test.com") + assert fetched is not None + assert fetched.needs_setup is True + assert fetched.token_version == 3 + + fetched.needs_setup = False + fetched.token_version = 4 + await repo.update_user(fetched) + refetched = await repo.get_user_by_id(str(fetched.id)) + assert refetched is not None + assert refetched.needs_setup is False + assert refetched.token_version == 4 + finally: + await close_engine() + + asyncio.run(_run()) # ── Token Versioning ─────────────────────────────────────────────────────── diff --git a/backend/tests/test_auth_type_system.py b/backend/tests/test_auth_type_system.py index 18b4542d0..81b1d5523 100644 --- a/backend/tests/test_auth_type_system.py +++ b/backend/tests/test_auth_type_system.py @@ -32,6 +32,32 @@ from app.gateway.csrf_middleware import ( _TEST_SECRET = "test-secret-for-auth-type-system-tests-min32" +@pytest.fixture(autouse=True) +def _persistence_engine(tmp_path): + """Initialise a per-test SQLite engine + reset cached provider singletons. + + The auth tests call real HTTP handlers that go through + ``SQLiteUserRepository`` → ``get_session_factory``. Each test gets + a fresh DB plus a clean ``deps._cached_*`` so the cached provider + does not hold a dangling reference to the previous test's engine. + """ + import asyncio + + from app.gateway import deps + from deerflow.persistence.engine import close_engine, init_engine + + url = f"sqlite+aiosqlite:///{tmp_path}/auth_types.db" + asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))) + deps._cached_local_provider = None + deps._cached_repo = None + try: + yield + finally: + deps._cached_local_provider = None + deps._cached_repo = None + asyncio.run(close_engine()) + + def _setup_config(): set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))