From 14892e14639229790b615ff3b81e06b84567d8e1 Mon Sep 17 00:00:00 2001 From: rayhpeng Date: Wed, 22 Apr 2026 11:26:19 +0800 Subject: [PATCH] refactor(gateway): remove old auth system and middleware Remove deprecated authentication and authorization modules: - app/gateway/auth/ - auth providers, JWT, password handling, repositories - app/gateway/auth_middleware.py - authentication middleware - app/gateway/authz.py - authorization module - app/gateway/csrf_middleware.py - CSRF protection middleware - app/gateway/deps.py - old dependency injection - app/gateway/langgraph_auth.py - LangGraph authentication - app/gateway/routers/auth.py - auth API endpoints - app/gateway/routers/assistants_compat.py - assistants compatibility layer These are replaced by the new auth system in packages/storage/. Co-Authored-By: Claude Opus 4.5 --- backend/app/gateway/auth/__init__.py | 42 -- backend/app/gateway/auth/config.py | 57 --- backend/app/gateway/auth/credential_file.py | 48 -- backend/app/gateway/auth/errors.py | 45 -- backend/app/gateway/auth/jwt.py | 55 --- backend/app/gateway/auth/local_provider.py | 91 ---- backend/app/gateway/auth/models.py | 41 -- backend/app/gateway/auth/password.py | 33 -- backend/app/gateway/auth/providers.py | 24 - .../app/gateway/auth/repositories/__init__.py | 0 backend/app/gateway/auth/repositories/base.py | 102 ---- .../app/gateway/auth/repositories/sqlite.py | 127 ----- backend/app/gateway/auth/reset_admin.py | 91 ---- backend/app/gateway/auth_middleware.py | 118 ----- backend/app/gateway/authz.py | 262 ---------- backend/app/gateway/csrf_middleware.py | 113 ----- backend/app/gateway/deps.py | 234 --------- backend/app/gateway/langgraph_auth.py | 106 ---- .../app/gateway/routers/assistants_compat.py | 149 ------ backend/app/gateway/routers/auth.py | 459 ------------------ 20 files changed, 2197 deletions(-) delete mode 100644 backend/app/gateway/auth/__init__.py delete mode 100644 backend/app/gateway/auth/config.py delete mode 100644 backend/app/gateway/auth/credential_file.py delete mode 100644 backend/app/gateway/auth/errors.py delete mode 100644 backend/app/gateway/auth/jwt.py delete mode 100644 backend/app/gateway/auth/local_provider.py delete mode 100644 backend/app/gateway/auth/models.py delete mode 100644 backend/app/gateway/auth/password.py delete mode 100644 backend/app/gateway/auth/providers.py delete mode 100644 backend/app/gateway/auth/repositories/__init__.py delete mode 100644 backend/app/gateway/auth/repositories/base.py delete mode 100644 backend/app/gateway/auth/repositories/sqlite.py delete mode 100644 backend/app/gateway/auth/reset_admin.py delete mode 100644 backend/app/gateway/auth_middleware.py delete mode 100644 backend/app/gateway/authz.py delete mode 100644 backend/app/gateway/csrf_middleware.py delete mode 100644 backend/app/gateway/deps.py delete mode 100644 backend/app/gateway/langgraph_auth.py delete mode 100644 backend/app/gateway/routers/assistants_compat.py delete mode 100644 backend/app/gateway/routers/auth.py diff --git a/backend/app/gateway/auth/__init__.py b/backend/app/gateway/auth/__init__.py deleted file mode 100644 index 4e9b71c42..000000000 --- a/backend/app/gateway/auth/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Authentication module for DeerFlow. - -This module provides: -- JWT-based authentication -- Provider Factory pattern for extensible auth methods -- UserRepository interface for storage backends (SQLite) -""" - -from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config -from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError -from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token -from app.gateway.auth.local_provider import LocalAuthProvider -from app.gateway.auth.models import User, UserResponse -from app.gateway.auth.password import hash_password, verify_password -from app.gateway.auth.providers import AuthProvider -from app.gateway.auth.repositories.base import UserRepository - -__all__ = [ - # Config - "AuthConfig", - "get_auth_config", - "set_auth_config", - # Errors - "AuthErrorCode", - "AuthErrorResponse", - "TokenError", - # JWT - "TokenPayload", - "create_access_token", - "decode_token", - # Password - "hash_password", - "verify_password", - # Models - "User", - "UserResponse", - # Providers - "AuthProvider", - "LocalAuthProvider", - # Repository - "UserRepository", -] diff --git a/backend/app/gateway/auth/config.py b/backend/app/gateway/auth/config.py deleted file mode 100644 index 01f0870fd..000000000 --- a/backend/app/gateway/auth/config.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Authentication configuration for DeerFlow.""" - -import logging -import os -import secrets - -from dotenv import load_dotenv -from pydantic import BaseModel, Field - -load_dotenv() - -logger = logging.getLogger(__name__) - - -class AuthConfig(BaseModel): - """JWT and auth-related configuration. Parsed once at startup. - - Note: the ``users`` table now lives in the shared persistence - database managed by ``deerflow.persistence.engine``. The old - ``users_db_path`` config key has been removed — user storage is - configured through ``config.database`` like every other table. - """ - - jwt_secret: str = Field( - ..., - description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.", - ) - token_expiry_days: int = Field(default=7, ge=1, le=30) - oauth_github_client_id: str | None = Field(default=None) - oauth_github_client_secret: str | None = Field(default=None) - - -_auth_config: AuthConfig | None = None - - -def get_auth_config() -> AuthConfig: - """Get the global AuthConfig instance. Parses from env on first call.""" - global _auth_config - if _auth_config is None: - jwt_secret = os.environ.get("AUTH_JWT_SECRET") - if not jwt_secret: - jwt_secret = secrets.token_urlsafe(32) - os.environ["AUTH_JWT_SECRET"] = jwt_secret - logger.warning( - "⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. " - "Sessions will be invalidated on restart. " - "For production, add AUTH_JWT_SECRET to your .env file: " - 'python -c "import secrets; print(secrets.token_urlsafe(32))"' - ) - _auth_config = AuthConfig(jwt_secret=jwt_secret) - return _auth_config - - -def set_auth_config(config: AuthConfig) -> None: - """Set the global AuthConfig instance (for testing).""" - global _auth_config - _auth_config = config diff --git a/backend/app/gateway/auth/credential_file.py b/backend/app/gateway/auth/credential_file.py deleted file mode 100644 index 100ca3b04..000000000 --- a/backend/app/gateway/auth/credential_file.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Write initial admin credentials to a restricted file instead of logs. - -Logging secrets to stdout/stderr is a well-known CodeQL finding -(py/clear-text-logging-sensitive-data) — in production those logs -get collected into ELK/Splunk/etc and become a secret sprawl -source. This helper writes the credential to a 0600 file that only -the process user can read, and returns the path so the caller can -log **the path** (not the password) for the operator to pick up. -""" - -from __future__ import annotations - -import os -from pathlib import Path - -from deerflow.config.paths import get_paths - -_CREDENTIAL_FILENAME = "admin_initial_credentials.txt" - - -def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path: - """Write the admin email + password to ``{base_dir}/admin_initial_credentials.txt``. - - The file is created **atomically** with mode 0600 via ``os.open`` - so the password is never world-readable, even for the single syscall - window between ``write_text`` and ``chmod``. - - ``label`` distinguishes "initial" (fresh creation) from "reset" - (password reset) in the file header so an operator picking up the - file after a restart can tell which event produced it. - - Returns the absolute :class:`Path` to the file. - """ - target = get_paths().base_dir / _CREDENTIAL_FILENAME - target.parent.mkdir(parents=True, exist_ok=True) - - content = ( - f"# DeerFlow admin {label} credentials\n# This file is generated on first boot or password reset.\n# Change the password after login via Settings -> Account,\n# then delete this file.\n#\nemail: {email}\npassword: {password}\n" - ) - - # Atomic 0600 create-or-truncate. O_TRUNC (not O_EXCL) so the - # reset-password path can rewrite an existing file without a - # separate unlink-then-create dance. - fd = os.open(target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) - with os.fdopen(fd, "w", encoding="utf-8") as fh: - fh.write(content) - - return target.resolve() diff --git a/backend/app/gateway/auth/errors.py b/backend/app/gateway/auth/errors.py deleted file mode 100644 index b5899ebd8..000000000 --- a/backend/app/gateway/auth/errors.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Typed error definitions for auth module. - -AuthErrorCode: exhaustive enum of all auth failure conditions. -TokenError: exhaustive enum of JWT decode failures. -AuthErrorResponse: structured error payload for HTTP responses. -""" - -from enum import StrEnum - -from pydantic import BaseModel - - -class AuthErrorCode(StrEnum): - """Exhaustive list of auth error conditions.""" - - INVALID_CREDENTIALS = "invalid_credentials" - TOKEN_EXPIRED = "token_expired" - TOKEN_INVALID = "token_invalid" - USER_NOT_FOUND = "user_not_found" - EMAIL_ALREADY_EXISTS = "email_already_exists" - PROVIDER_NOT_FOUND = "provider_not_found" - NOT_AUTHENTICATED = "not_authenticated" - SYSTEM_ALREADY_INITIALIZED = "system_already_initialized" - - -class TokenError(StrEnum): - """Exhaustive list of JWT decode failure reasons.""" - - EXPIRED = "expired" - INVALID_SIGNATURE = "invalid_signature" - MALFORMED = "malformed" - - -class AuthErrorResponse(BaseModel): - """Structured error response — replaces bare `detail` strings.""" - - code: AuthErrorCode - message: str - - -def token_error_to_code(err: TokenError) -> AuthErrorCode: - """Map TokenError to AuthErrorCode — single source of truth.""" - if err == TokenError.EXPIRED: - return AuthErrorCode.TOKEN_EXPIRED - return AuthErrorCode.TOKEN_INVALID diff --git a/backend/app/gateway/auth/jwt.py b/backend/app/gateway/auth/jwt.py deleted file mode 100644 index 3853692b7..000000000 --- a/backend/app/gateway/auth/jwt.py +++ /dev/null @@ -1,55 +0,0 @@ -"""JWT token creation and verification.""" - -from datetime import UTC, datetime, timedelta - -import jwt -from pydantic import BaseModel - -from app.gateway.auth.config import get_auth_config -from app.gateway.auth.errors import TokenError - - -class TokenPayload(BaseModel): - """JWT token payload.""" - - sub: str # user_id - exp: datetime - iat: datetime | None = None - ver: int = 0 # token_version — must match User.token_version - - -def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str: - """Create a JWT access token. - - Args: - user_id: The user's UUID as string - expires_delta: Optional custom expiry, defaults to 7 days - token_version: User's current token_version for invalidation - - Returns: - Encoded JWT string - """ - config = get_auth_config() - expiry = expires_delta or timedelta(days=config.token_expiry_days) - - now = datetime.now(UTC) - payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version} - return jwt.encode(payload, config.jwt_secret, algorithm="HS256") - - -def decode_token(token: str) -> TokenPayload | TokenError: - """Decode and validate a JWT token. - - Returns: - TokenPayload if valid, or a specific TokenError variant. - """ - config = get_auth_config() - try: - payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"]) - return TokenPayload(**payload) - except jwt.ExpiredSignatureError: - return TokenError.EXPIRED - except jwt.InvalidSignatureError: - return TokenError.INVALID_SIGNATURE - except jwt.PyJWTError: - return TokenError.MALFORMED diff --git a/backend/app/gateway/auth/local_provider.py b/backend/app/gateway/auth/local_provider.py deleted file mode 100644 index 8bfd15e59..000000000 --- a/backend/app/gateway/auth/local_provider.py +++ /dev/null @@ -1,91 +0,0 @@ -"""Local email/password authentication provider.""" - -from app.gateway.auth.models import User -from app.gateway.auth.password import hash_password_async, verify_password_async -from app.gateway.auth.providers import AuthProvider -from app.gateway.auth.repositories.base import UserRepository - - -class LocalAuthProvider(AuthProvider): - """Email/password authentication provider using local database.""" - - def __init__(self, repository: UserRepository): - """Initialize with a UserRepository. - - Args: - repository: UserRepository implementation (SQLite) - """ - self._repo = repository - - async def authenticate(self, credentials: dict) -> User | None: - """Authenticate with email and password. - - Args: - credentials: dict with 'email' and 'password' keys - - Returns: - User if authentication succeeds, None otherwise - """ - email = credentials.get("email") - password = credentials.get("password") - - if not email or not password: - return None - - user = await self._repo.get_user_by_email(email) - if user is None: - return None - - if user.password_hash is None: - # OAuth user without local password - return None - - if not await verify_password_async(password, user.password_hash): - return None - - return user - - async def get_user(self, user_id: str) -> User | None: - """Get user by ID.""" - return await self._repo.get_user_by_id(user_id) - - async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User: - """Create a new local user. - - Args: - email: User email address - password: Plain text password (will be hashed) - system_role: Role to assign ("admin" or "user") - needs_setup: If True, user must complete setup on first login - - Returns: - Created User instance - """ - password_hash = await hash_password_async(password) if password else None - user = User( - email=email, - password_hash=password_hash, - system_role=system_role, - needs_setup=needs_setup, - ) - return await self._repo.create_user(user) - - async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: - """Get user by OAuth provider and ID.""" - return await self._repo.get_user_by_oauth(provider, oauth_id) - - async def count_users(self) -> int: - """Return total number of registered users.""" - return await self._repo.count_users() - - async def count_admin_users(self) -> int: - """Return number of admin users.""" - return await self._repo.count_admin_users() - - async def update_user(self, user: User) -> User: - """Update an existing user.""" - return await self._repo.update_user(user) - - async def get_user_by_email(self, email: str) -> User | None: - """Get user by email.""" - return await self._repo.get_user_by_email(email) diff --git a/backend/app/gateway/auth/models.py b/backend/app/gateway/auth/models.py deleted file mode 100644 index d8f9b954a..000000000 --- a/backend/app/gateway/auth/models.py +++ /dev/null @@ -1,41 +0,0 @@ -"""User Pydantic models for authentication.""" - -from datetime import UTC, datetime -from typing import Literal -from uuid import UUID, uuid4 - -from pydantic import BaseModel, ConfigDict, EmailStr, Field - - -def _utc_now() -> datetime: - """Return current UTC time (timezone-aware).""" - return datetime.now(UTC) - - -class User(BaseModel): - """Internal user representation.""" - - model_config = ConfigDict(from_attributes=True) - - id: UUID = Field(default_factory=uuid4, description="Primary key") - email: EmailStr = Field(..., description="Unique email address") - password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users") - system_role: Literal["admin", "user"] = Field(default="user") - created_at: datetime = Field(default_factory=_utc_now) - - # OAuth linkage (optional) - oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'") - oauth_id: str | None = Field(None, description="User ID from OAuth provider") - - # Auth lifecycle - needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes") - token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs") - - -class UserResponse(BaseModel): - """Response model for user info endpoint.""" - - id: str - email: str - system_role: Literal["admin", "user"] - needs_setup: bool = False diff --git a/backend/app/gateway/auth/password.py b/backend/app/gateway/auth/password.py deleted file mode 100644 index 588b7a643..000000000 --- a/backend/app/gateway/auth/password.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Password hashing utilities using bcrypt directly.""" - -import asyncio - -import bcrypt - - -def hash_password(password: str) -> str: - """Hash a password using bcrypt.""" - return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") - - -def verify_password(plain_password: str, hashed_password: str) -> bool: - """Verify a password against its hash.""" - return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8")) - - -async def hash_password_async(password: str) -> str: - """Hash a password using bcrypt (non-blocking). - - Wraps the blocking bcrypt operation in a thread pool to avoid - blocking the event loop during password hashing. - """ - return await asyncio.to_thread(hash_password, password) - - -async def verify_password_async(plain_password: str, hashed_password: str) -> bool: - """Verify a password against its hash (non-blocking). - - Wraps the blocking bcrypt operation in a thread pool to avoid - blocking the event loop during password verification. - """ - return await asyncio.to_thread(verify_password, plain_password, hashed_password) diff --git a/backend/app/gateway/auth/providers.py b/backend/app/gateway/auth/providers.py deleted file mode 100644 index 25e782ce3..000000000 --- a/backend/app/gateway/auth/providers.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Auth provider abstraction.""" - -from abc import ABC, abstractmethod - - -class AuthProvider(ABC): - """Abstract base class for authentication providers.""" - - @abstractmethod - async def authenticate(self, credentials: dict) -> "User | None": - """Authenticate user with given credentials. - - Returns User if authentication succeeds, None otherwise. - """ - ... - - @abstractmethod - async def get_user(self, user_id: str) -> "User | None": - """Retrieve user by ID.""" - ... - - -# Import User at runtime to avoid circular imports -from app.gateway.auth.models import User # noqa: E402 diff --git a/backend/app/gateway/auth/repositories/__init__.py b/backend/app/gateway/auth/repositories/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/app/gateway/auth/repositories/base.py b/backend/app/gateway/auth/repositories/base.py deleted file mode 100644 index d96753171..000000000 --- a/backend/app/gateway/auth/repositories/base.py +++ /dev/null @@ -1,102 +0,0 @@ -"""User repository interface for abstracting database operations.""" - -from abc import ABC, abstractmethod - -from app.gateway.auth.models import User - - -class UserNotFoundError(LookupError): - """Raised when a user repository operation targets a non-existent row. - - Subclass of :class:`LookupError` so callers that already catch - ``LookupError`` for "missing entity" can keep working unchanged, - while specific call sites can pin to this class to distinguish - "concurrent delete during update" from other lookups. - """ - - -class UserRepository(ABC): - """Abstract interface for user data storage. - - Implement this interface to support different storage backends - (SQLite) - """ - - @abstractmethod - async def create_user(self, user: User) -> User: - """Create a new user. - - Args: - user: User object to create - - Returns: - Created User with ID assigned - - Raises: - ValueError: If email already exists - """ - ... - - @abstractmethod - async def get_user_by_id(self, user_id: str) -> User | None: - """Get user by ID. - - Args: - user_id: User UUID as string - - Returns: - User if found, None otherwise - """ - ... - - @abstractmethod - async def get_user_by_email(self, email: str) -> User | None: - """Get user by email. - - Args: - email: User email address - - Returns: - User if found, None otherwise - """ - ... - - @abstractmethod - async def update_user(self, user: User) -> User: - """Update an existing user. - - Args: - user: User object with updated fields - - Returns: - Updated User - - Raises: - UserNotFoundError: If no row exists for ``user.id``. This is - a hard failure (not a no-op) so callers cannot mistake a - concurrent-delete race for a successful update. - """ - ... - - @abstractmethod - async def count_users(self) -> int: - """Return total number of registered users.""" - ... - - @abstractmethod - async def count_admin_users(self) -> int: - """Return number of users with system_role == 'admin'.""" - ... - - @abstractmethod - async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: - """Get user by OAuth provider and ID. - - Args: - provider: OAuth provider name (e.g. 'github', 'google') - oauth_id: User ID from the OAuth provider - - Returns: - User if found, None otherwise - """ - ... diff --git a/backend/app/gateway/auth/repositories/sqlite.py b/backend/app/gateway/auth/repositories/sqlite.py deleted file mode 100644 index 3ee3978e3..000000000 --- a/backend/app/gateway/auth/repositories/sqlite.py +++ /dev/null @@ -1,127 +0,0 @@ -"""SQLAlchemy-backed UserRepository implementation. - -Uses the shared async session factory from -``deerflow.persistence.engine`` — the ``users`` table lives in the -same database as ``threads_meta``, ``runs``, ``run_events``, and -``feedback``. - -Constructor takes the session factory directly (same pattern as the -other four repositories in ``deerflow.persistence.*``). Callers -construct this after ``init_engine_from_config()`` has run. -""" - -from __future__ import annotations - -from datetime import UTC -from uuid import UUID - -from sqlalchemy import func, select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker - -from app.gateway.auth.models import User -from app.gateway.auth.repositories.base import UserNotFoundError, UserRepository -from deerflow.persistence.user.model import UserRow - - -class SQLiteUserRepository(UserRepository): - """Async user repository backed by the shared SQLAlchemy engine.""" - - def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: - self._sf = session_factory - - # ── Converters ──────────────────────────────────────────────────── - - @staticmethod - def _row_to_user(row: UserRow) -> User: - return User( - id=UUID(row.id), - email=row.email, - password_hash=row.password_hash, - system_role=row.system_role, # type: ignore[arg-type] - # SQLite loses tzinfo on read; reattach UTC so downstream - # code can compare timestamps reliably. - created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC), - oauth_provider=row.oauth_provider, - oauth_id=row.oauth_id, - needs_setup=row.needs_setup, - token_version=row.token_version, - ) - - @staticmethod - def _user_to_row(user: User) -> UserRow: - return UserRow( - id=str(user.id), - email=user.email, - password_hash=user.password_hash, - system_role=user.system_role, - created_at=user.created_at, - oauth_provider=user.oauth_provider, - oauth_id=user.oauth_id, - needs_setup=user.needs_setup, - token_version=user.token_version, - ) - - # ── CRUD ────────────────────────────────────────────────────────── - - async def create_user(self, user: User) -> User: - """Insert a new user. Raises ``ValueError`` on duplicate email.""" - row = self._user_to_row(user) - async with self._sf() as session: - session.add(row) - try: - await session.commit() - except IntegrityError as exc: - await session.rollback() - raise ValueError(f"Email already registered: {user.email}") from exc - return user - - async def get_user_by_id(self, user_id: str) -> User | None: - async with self._sf() as session: - row = await session.get(UserRow, user_id) - return self._row_to_user(row) if row is not None else None - - async def get_user_by_email(self, email: str) -> User | None: - stmt = select(UserRow).where(UserRow.email == email) - async with self._sf() as session: - result = await session.execute(stmt) - row = result.scalar_one_or_none() - return self._row_to_user(row) if row is not None else None - - async def update_user(self, user: User) -> User: - async with self._sf() as session: - row = await session.get(UserRow, str(user.id)) - if row is None: - # Hard fail on concurrent delete: callers (reset_admin, - # password change handlers, _ensure_admin_user) all - # fetched the user just before this call, so a missing - # row here means the row vanished underneath us. Silent - # success would let the caller log "password reset" for - # a row that no longer exists. - raise UserNotFoundError(f"User {user.id} no longer exists") - row.email = user.email - row.password_hash = user.password_hash - row.system_role = user.system_role - row.oauth_provider = user.oauth_provider - row.oauth_id = user.oauth_id - row.needs_setup = user.needs_setup - row.token_version = user.token_version - await session.commit() - return user - - async def count_users(self) -> int: - stmt = select(func.count()).select_from(UserRow) - async with self._sf() as session: - return await session.scalar(stmt) or 0 - - async def count_admin_users(self) -> int: - stmt = select(func.count()).select_from(UserRow).where(UserRow.system_role == "admin") - async with self._sf() as session: - return await session.scalar(stmt) or 0 - - async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: - stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id) - async with self._sf() as session: - result = await session.execute(stmt) - row = result.scalar_one_or_none() - return self._row_to_user(row) if row is not None else None diff --git a/backend/app/gateway/auth/reset_admin.py b/backend/app/gateway/auth/reset_admin.py deleted file mode 100644 index 7b7da74d0..000000000 --- a/backend/app/gateway/auth/reset_admin.py +++ /dev/null @@ -1,91 +0,0 @@ -"""CLI tool to reset an admin password. - -Usage: - python -m app.gateway.auth.reset_admin - python -m app.gateway.auth.reset_admin --email admin@example.com - -Writes the new password to ``.deer-flow/admin_initial_credentials.txt`` -(mode 0600) instead of printing it, so CI / log aggregators never see -the cleartext secret. -""" - -from __future__ import annotations - -import argparse -import asyncio -import secrets -import sys - -from sqlalchemy import select - -from app.gateway.auth.credential_file import write_initial_credentials -from app.gateway.auth.password import hash_password -from app.gateway.auth.repositories.sqlite import SQLiteUserRepository -from deerflow.persistence.user.model import UserRow - - -async def _run(email: str | None) -> int: - from deerflow.config import get_app_config - from deerflow.persistence.engine import ( - close_engine, - get_session_factory, - init_engine_from_config, - ) - - config = get_app_config() - await init_engine_from_config(config.database) - try: - sf = get_session_factory() - if sf is None: - print("Error: persistence engine not available (check config.database).", file=sys.stderr) - return 1 - - repo = SQLiteUserRepository(sf) - - if email: - user = await repo.get_user_by_email(email) - else: - # Find first admin via direct SELECT — repository does not - # expose a "first admin" helper and we do not want to add - # one just for this CLI. - async with sf() as session: - stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1) - row = (await session.execute(stmt)).scalar_one_or_none() - if row is None: - user = None - else: - user = await repo.get_user_by_id(row.id) - - if user is None: - if email: - print(f"Error: user '{email}' not found.", file=sys.stderr) - else: - print("Error: no admin user found.", file=sys.stderr) - return 1 - - new_password = secrets.token_urlsafe(16) - user.password_hash = hash_password(new_password) - user.token_version += 1 - user.needs_setup = True - await repo.update_user(user) - - cred_path = write_initial_credentials(user.email, new_password, label="reset") - print(f"Password reset for: {user.email}") - print(f"Credentials written to: {cred_path} (mode 0600)") - print("Next login will require setup (new email + password).") - return 0 - finally: - await close_engine() - - -def main() -> None: - parser = argparse.ArgumentParser(description="Reset admin password") - parser.add_argument("--email", help="Admin email (default: first admin found)") - args = parser.parse_args() - - exit_code = asyncio.run(_run(args.email)) - sys.exit(exit_code) - - -if __name__ == "__main__": - main() diff --git a/backend/app/gateway/auth_middleware.py b/backend/app/gateway/auth_middleware.py deleted file mode 100644 index fd982cd79..000000000 --- a/backend/app/gateway/auth_middleware.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Global authentication middleware — fail-closed safety net. - -Rejects unauthenticated requests to non-public paths with 401. When a -request passes the cookie check, resolves the JWT payload to a real -``User`` object and stamps it into both ``request.state.user`` and the -``deerflow.runtime.user_context`` contextvar so that repository-layer -owner filtering works automatically via the sentinel pattern. - -Fine-grained permission checks remain in authz.py decorators. -""" - -from collections.abc import Callable - -from fastapi import HTTPException, Request, Response -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import JSONResponse -from starlette.types import ASGIApp - -from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse -from app.gateway.authz import _ALL_PERMISSIONS, AuthContext -from deerflow.runtime.user_context import reset_current_user, set_current_user - -# Paths that never require authentication. -_PUBLIC_PATH_PREFIXES: tuple[str, ...] = ( - "/health", - "/docs", - "/redoc", - "/openapi.json", -) - -# Exact auth paths that are public (login/register/status check). -# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public. -_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset( - { - "/api/v1/auth/login/local", - "/api/v1/auth/register", - "/api/v1/auth/logout", - "/api/v1/auth/setup-status", - "/api/v1/auth/initialize", - } -) - - -def _is_public(path: str) -> bool: - stripped = path.rstrip("/") - if stripped in _PUBLIC_EXACT_PATHS: - return True - return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES) - - -class AuthMiddleware(BaseHTTPMiddleware): - """Strict auth gate: reject requests without a valid session. - - Two-stage check for non-public paths: - - 1. Cookie presence — return 401 NOT_AUTHENTICATED if missing - 2. JWT validation via ``get_optional_user_from_request`` — return 401 - TOKEN_INVALID if the token is absent, malformed, expired, or the - signed user does not exist / is stale - - On success, stamps ``request.state.user`` and the - ``deerflow.runtime.user_context`` contextvar so that repository-layer - owner filters work downstream without every route needing a - ``@require_auth`` decorator. Routes that need per-resource - authorization (e.g. "user A cannot read user B's thread by guessing - the URL") should additionally use ``@require_permission(..., - owner_check=True)`` for explicit enforcement — but authentication - itself is fully handled here. - """ - - def __init__(self, app: ASGIApp) -> None: - super().__init__(app) - - async def dispatch(self, request: Request, call_next: Callable) -> Response: - if _is_public(request.url.path): - return await call_next(request) - - # Non-public path: require session cookie - if not request.cookies.get("access_token"): - return JSONResponse( - status_code=401, - content={ - "detail": AuthErrorResponse( - code=AuthErrorCode.NOT_AUTHENTICATED, - message="Authentication required", - ).model_dump() - }, - ) - - # Strict JWT validation: reject junk/expired tokens with 401 - # right here instead of silently passing through. This closes - # the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8): - # without this, non-isolation routes like /api/models would - # accept any cookie-shaped string as authentication. - # - # We call the *strict* resolver so that fine-grained error - # codes (token_expired, token_invalid, user_not_found, …) - # propagate from AuthErrorCode, not get flattened into one - # generic code. BaseHTTPMiddleware doesn't let HTTPException - # bubble up, so we catch and render it as JSONResponse here. - from app.gateway.deps import get_current_user_from_request - - try: - user = await get_current_user_from_request(request) - except HTTPException as exc: - return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) - - # Stamp both request.state.user (for the contextvar pattern) - # and request.state.auth (so @require_permission's "auth is - # None" branch short-circuits instead of running the entire - # JWT-decode + DB-lookup pipeline a second time per request). - request.state.user = user - request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS) - token = set_current_user(user) - try: - return await call_next(request) - finally: - reset_current_user(token) diff --git a/backend/app/gateway/authz.py b/backend/app/gateway/authz.py deleted file mode 100644 index 5842a24c7..000000000 --- a/backend/app/gateway/authz.py +++ /dev/null @@ -1,262 +0,0 @@ -"""Authorization decorators and context for DeerFlow. - -Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py - -**Usage:** - -1. Use ``@require_auth`` on routes that need authentication -2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks -3. The decorator chain processes from bottom to top - -**Example:** - - @router.get("/{thread_id}") - @require_auth - @require_permission("threads", "read", owner_check=True) - async def get_thread(thread_id: str, request: Request): - # User is authenticated and has threads:read permission - ... - -**Permission Model:** - -- threads:read - View thread -- threads:write - Create/update thread -- threads:delete - Delete thread -- runs:create - Run agent -- runs:read - View run -- runs:cancel - Cancel run -""" - -from __future__ import annotations - -import functools -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar - -from fastapi import HTTPException, Request - -if TYPE_CHECKING: - from app.gateway.auth.models import User - -P = ParamSpec("P") -T = TypeVar("T") - - -# Permission constants -class Permissions: - """Permission constants for resource:action format.""" - - # Threads - THREADS_READ = "threads:read" - THREADS_WRITE = "threads:write" - THREADS_DELETE = "threads:delete" - - # Runs - RUNS_CREATE = "runs:create" - RUNS_READ = "runs:read" - RUNS_CANCEL = "runs:cancel" - - -class AuthContext: - """Authentication context for the current request. - - Stored in request.state.auth after require_auth decoration. - - Attributes: - user: The authenticated user, or None if anonymous - permissions: List of permission strings (e.g., "threads:read") - """ - - __slots__ = ("user", "permissions") - - def __init__(self, user: User | None = None, permissions: list[str] | None = None): - self.user = user - self.permissions = permissions or [] - - @property - def is_authenticated(self) -> bool: - """Check if user is authenticated.""" - return self.user is not None - - def has_permission(self, resource: str, action: str) -> bool: - """Check if context has permission for resource:action. - - Args: - resource: Resource name (e.g., "threads") - action: Action name (e.g., "read") - - Returns: - True if user has permission - """ - permission = f"{resource}:{action}" - return permission in self.permissions - - def require_user(self) -> User: - """Get user or raise 401. - - Raises: - HTTPException 401 if not authenticated - """ - if not self.user: - raise HTTPException(status_code=401, detail="Authentication required") - return self.user - - -def get_auth_context(request: Request) -> AuthContext | None: - """Get AuthContext from request state.""" - return getattr(request.state, "auth", None) - - -_ALL_PERMISSIONS: list[str] = [ - Permissions.THREADS_READ, - Permissions.THREADS_WRITE, - Permissions.THREADS_DELETE, - Permissions.RUNS_CREATE, - Permissions.RUNS_READ, - Permissions.RUNS_CANCEL, -] - - -async def _authenticate(request: Request) -> AuthContext: - """Authenticate request and return AuthContext. - - Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline. - Returns AuthContext with user=None for anonymous requests. - """ - from app.gateway.deps import get_optional_user_from_request - - user = await get_optional_user_from_request(request) - if user is None: - return AuthContext(user=None, permissions=[]) - - # In future, permissions could be stored in user record - return AuthContext(user=user, permissions=_ALL_PERMISSIONS) - - -def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]: - """Decorator that authenticates the request and sets AuthContext. - - Must be placed ABOVE other decorators (executes after them). - - Usage: - @router.get("/{thread_id}") - @require_auth # Bottom decorator (executes first after permission check) - @require_permission("threads", "read") - async def get_thread(thread_id: str, request: Request): - auth: AuthContext = request.state.auth - ... - - Raises: - ValueError: If 'request' parameter is missing - """ - - @functools.wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> Any: - request = kwargs.get("request") - if request is None: - raise ValueError("require_auth decorator requires 'request' parameter") - - # Authenticate and set context - auth_context = await _authenticate(request) - request.state.auth = auth_context - - return await func(*args, **kwargs) - - return wrapper - - -def require_permission( - resource: str, - action: str, - owner_check: bool = False, - require_existing: bool = False, -) -> Callable[[Callable[P, T]], Callable[P, T]]: - """Decorator that checks permission for resource:action. - - Must be used AFTER @require_auth. - - Args: - resource: Resource name (e.g., "threads", "runs") - action: Action name (e.g., "read", "write", "delete") - owner_check: If True, validates that the current user owns the resource. - Requires 'thread_id' path parameter and performs ownership check. - require_existing: Only meaningful with ``owner_check=True``. If True, a - missing ``threads_meta`` row counts as a denial (404) - instead of "untracked legacy thread, allow". Use on - **destructive / mutating** routes (DELETE, PATCH, - state-update) so a deleted thread can't be re-targeted - by another user via the missing-row code path. - - Usage: - # Read-style: legacy untracked threads are allowed - @require_permission("threads", "read", owner_check=True) - async def get_thread(thread_id: str, request: Request): - ... - - # Destructive: thread row MUST exist and be owned by caller - @require_permission("threads", "delete", owner_check=True, require_existing=True) - async def delete_thread(thread_id: str, request: Request): - ... - - Raises: - HTTPException 401: If authentication required but user is anonymous - HTTPException 403: If user lacks permission - HTTPException 404: If owner_check=True but user doesn't own the thread - ValueError: If owner_check=True but 'thread_id' parameter is missing - """ - - def decorator(func: Callable[P, T]) -> Callable[P, T]: - @functools.wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> Any: - request = kwargs.get("request") - if request is None: - raise ValueError("require_permission decorator requires 'request' parameter") - - auth: AuthContext = getattr(request.state, "auth", None) - if auth is None: - auth = await _authenticate(request) - request.state.auth = auth - - if not auth.is_authenticated: - raise HTTPException(status_code=401, detail="Authentication required") - - # Check permission - if not auth.has_permission(resource, action): - raise HTTPException( - status_code=403, - detail=f"Permission denied: {resource}:{action}", - ) - - # Owner check for thread-specific resources. - # - # 2.0-rc moved thread metadata into the SQL persistence layer - # (``threads_meta`` table). We verify ownership via - # ``ThreadMetaStore.check_access``: it returns True for - # missing rows (untracked legacy thread) and for rows whose - # ``user_id`` is NULL (shared / pre-auth data), so this is - # strict-deny rather than strict-allow — only an *existing* - # row with a *different* user_id triggers 404. - if owner_check: - thread_id = kwargs.get("thread_id") - if thread_id is None: - raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter") - - from app.gateway.deps import get_thread_store - - thread_store = get_thread_store(request) - allowed = await thread_store.check_access( - thread_id, - str(auth.user.id), - require_existing=require_existing, - ) - if not allowed: - raise HTTPException( - status_code=404, - detail=f"Thread {thread_id} not found", - ) - - return await func(*args, **kwargs) - - return wrapper - - return decorator diff --git a/backend/app/gateway/csrf_middleware.py b/backend/app/gateway/csrf_middleware.py deleted file mode 100644 index 4c9b0f36a..000000000 --- a/backend/app/gateway/csrf_middleware.py +++ /dev/null @@ -1,113 +0,0 @@ -"""CSRF protection middleware for FastAPI. - -Per RFC-001: -State-changing operations require CSRF protection. -""" - -import secrets -from collections.abc import Callable - -from fastapi import Request, Response -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import JSONResponse -from starlette.types import ASGIApp - -CSRF_COOKIE_NAME = "csrf_token" -CSRF_HEADER_NAME = "X-CSRF-Token" -CSRF_TOKEN_LENGTH = 64 # bytes - - -def is_secure_request(request: Request) -> bool: - """Detect whether the original client request was made over HTTPS.""" - return request.headers.get("x-forwarded-proto", request.url.scheme) == "https" - - -def generate_csrf_token() -> str: - """Generate a secure random CSRF token.""" - return secrets.token_urlsafe(CSRF_TOKEN_LENGTH) - - -def should_check_csrf(request: Request) -> bool: - """Determine if a request needs CSRF validation. - - CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH). - GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231. - """ - if request.method not in ("POST", "PUT", "DELETE", "PATCH"): - return False - - path = request.url.path.rstrip("/") - # Exempt /api/v1/auth/me endpoint - if path == "/api/v1/auth/me": - return False - return True - - -_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset( - { - "/api/v1/auth/login/local", - "/api/v1/auth/logout", - "/api/v1/auth/register", - "/api/v1/auth/initialize", - } -) - - -def is_auth_endpoint(request: Request) -> bool: - """Check if the request is to an auth endpoint. - - Auth endpoints don't need CSRF validation on first call (no token). - """ - return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS - - -class CSRFMiddleware(BaseHTTPMiddleware): - """Middleware that implements CSRF protection using Double Submit Cookie pattern.""" - - def __init__(self, app: ASGIApp) -> None: - super().__init__(app) - - async def dispatch(self, request: Request, call_next: Callable) -> Response: - _is_auth = is_auth_endpoint(request) - - if should_check_csrf(request) and not _is_auth: - cookie_token = request.cookies.get(CSRF_COOKIE_NAME) - header_token = request.headers.get(CSRF_HEADER_NAME) - - if not cookie_token or not header_token: - return JSONResponse( - status_code=403, - content={"detail": "CSRF token missing. Include X-CSRF-Token header."}, - ) - - if not secrets.compare_digest(cookie_token, header_token): - return JSONResponse( - status_code=403, - content={"detail": "CSRF token mismatch."}, - ) - - response = await call_next(request) - - # For auth endpoints that set up session, also set CSRF cookie - if _is_auth and request.method == "POST": - # Generate a new CSRF token for the session - csrf_token = generate_csrf_token() - is_https = is_secure_request(request) - response.set_cookie( - key=CSRF_COOKIE_NAME, - value=csrf_token, - httponly=False, # Must be JS-readable for Double Submit Cookie pattern - secure=is_https, - samesite="strict", - ) - - return response - - -def get_csrf_token(request: Request) -> str | None: - """Get the CSRF token from the current request's cookies. - - This is useful for server-side rendering where you need to embed - token in forms or headers. - """ - return request.cookies.get(CSRF_COOKIE_NAME) diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py deleted file mode 100644 index 20da78af9..000000000 --- a/backend/app/gateway/deps.py +++ /dev/null @@ -1,234 +0,0 @@ -"""Centralized accessors for singleton objects stored on ``app.state``. - -**Getters** (used by routers): raise 503 when a required dependency is -missing, except ``get_store`` which returns ``None``. - -Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`. -""" - -from __future__ import annotations - -from collections.abc import AsyncGenerator, Callable -from contextlib import AsyncExitStack, asynccontextmanager -from typing import TYPE_CHECKING, TypeVar, cast - -from fastapi import FastAPI, HTTPException, Request -from langgraph.types import Checkpointer - -from deerflow.persistence.feedback import FeedbackRepository -from deerflow.runtime import RunContext, RunManager, StreamBridge -from deerflow.runtime.events.store.base import RunEventStore -from deerflow.runtime.runs.store.base import RunStore - -if TYPE_CHECKING: - from app.gateway.auth.local_provider import LocalAuthProvider - from app.gateway.auth.repositories.sqlite import SQLiteUserRepository - from deerflow.persistence.thread_meta.base import ThreadMetaStore - - -T = TypeVar("T") - - -@asynccontextmanager -async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: - """Bootstrap and tear down all LangGraph runtime singletons. - - Usage in ``app.py``:: - - async with langgraph_runtime(app): - yield - """ - from deerflow.config import get_app_config - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config - from deerflow.runtime import make_store, make_stream_bridge - from deerflow.runtime.checkpointer.async_provider import make_checkpointer - from deerflow.runtime.events.store import make_run_event_store - - async with AsyncExitStack() as stack: - app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge()) - - # Initialize persistence engine BEFORE checkpointer so that - # auto-create-database logic runs first (postgres backend). - config = get_app_config() - await init_engine_from_config(config.database) - - app.state.checkpointer = await stack.enter_async_context(make_checkpointer()) - app.state.store = await stack.enter_async_context(make_store()) - - # Initialize repositories — one get_session_factory() call for all. - sf = get_session_factory() - if sf is not None: - from deerflow.persistence.feedback import FeedbackRepository - from deerflow.persistence.run import RunRepository - - app.state.run_store = RunRepository(sf) - app.state.feedback_repo = FeedbackRepository(sf) - else: - from deerflow.runtime.runs.store.memory import MemoryRunStore - - app.state.run_store = MemoryRunStore() - app.state.feedback_repo = None - - from deerflow.persistence.thread_meta import make_thread_store - - app.state.thread_store = make_thread_store(sf, app.state.store) - - # Run event store (has its own factory with config-driven backend selection) - run_events_config = getattr(config, "run_events", None) - app.state.run_event_store = make_run_event_store(run_events_config) - - # RunManager with store backing for persistence - app.state.run_manager = RunManager(store=app.state.run_store) - - try: - yield - finally: - await close_engine() - - -# --------------------------------------------------------------------------- -# Getters – called by routers per-request -# --------------------------------------------------------------------------- - - -def _require(attr: str, label: str) -> Callable[[Request], T]: - """Create a FastAPI dependency that returns ``app.state.`` or 503.""" - - def dep(request: Request) -> T: - val = getattr(request.app.state, attr, None) - if val is None: - raise HTTPException(status_code=503, detail=f"{label} not available") - return cast(T, val) - - dep.__name__ = dep.__qualname__ = f"get_{attr}" - return dep - - -get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge") -get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager") -get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer") -get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store") -get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback") -get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store") - - -def get_store(request: Request): - """Return the global store (may be ``None`` if not configured).""" - return getattr(request.app.state, "store", None) - - -def get_thread_store(request: Request) -> ThreadMetaStore: - """Return the thread metadata store (SQL or memory-backed).""" - val = getattr(request.app.state, "thread_store", None) - if val is None: - raise HTTPException(status_code=503, detail="Thread metadata store not available") - return val - - -def get_run_context(request: Request) -> RunContext: - """Build a :class:`RunContext` from ``app.state`` singletons. - - Returns a *base* context with infrastructure dependencies. - """ - from deerflow.config import get_app_config - - return RunContext( - checkpointer=get_checkpointer(request), - store=get_store(request), - event_store=get_run_event_store(request), - run_events_config=getattr(get_app_config(), "run_events", None), - thread_store=get_thread_store(request), - ) - - -# --------------------------------------------------------------------------- -# Auth helpers (used by authz.py and auth middleware) -# --------------------------------------------------------------------------- - -# Cached singletons to avoid repeated instantiation per request -_cached_local_provider: LocalAuthProvider | None = None -_cached_repo: SQLiteUserRepository | None = None - - -def get_local_provider() -> LocalAuthProvider: - """Get or create the cached LocalAuthProvider singleton. - - Must be called after ``init_engine_from_config()`` — the shared - session factory is required to construct the user repository. - """ - global _cached_local_provider, _cached_repo - if _cached_repo is None: - from app.gateway.auth.repositories.sqlite import SQLiteUserRepository - from deerflow.persistence.engine import get_session_factory - - sf = get_session_factory() - if sf is None: - raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table") - _cached_repo = SQLiteUserRepository(sf) - if _cached_local_provider is None: - from app.gateway.auth.local_provider import LocalAuthProvider - - _cached_local_provider = LocalAuthProvider(repository=_cached_repo) - return _cached_local_provider - - -async def get_current_user_from_request(request: Request): - """Get the current authenticated user from the request cookie. - - Raises HTTPException 401 if not authenticated. - """ - from app.gateway.auth import decode_token - from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code - - access_token = request.cookies.get("access_token") - if not access_token: - raise HTTPException( - status_code=401, - detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(), - ) - - payload = decode_token(access_token) - if isinstance(payload, TokenError): - raise HTTPException( - status_code=401, - detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(), - ) - - provider = get_local_provider() - user = await provider.get_user(payload.sub) - if user is None: - raise HTTPException( - status_code=401, - detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(), - ) - - # Token version mismatch → password was changed, token is stale - if user.token_version != payload.ver: - raise HTTPException( - status_code=401, - detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(), - ) - - return user - - -async def get_optional_user_from_request(request: Request): - """Get optional authenticated user from request. - - Returns None if not authenticated. - """ - try: - return await get_current_user_from_request(request) - except HTTPException: - return None - - -async def get_current_user(request: Request) -> str | None: - """Extract user_id from request cookie, or None if not authenticated. - - Thin adapter that returns the string id for callers that only need - identification (e.g., ``feedback.py``). Full-user callers should use - ``get_current_user_from_request`` or ``get_optional_user_from_request``. - """ - user = await get_optional_user_from_request(request) - return str(user.id) if user else None diff --git a/backend/app/gateway/langgraph_auth.py b/backend/app/gateway/langgraph_auth.py deleted file mode 100644 index 06074b9b8..000000000 --- a/backend/app/gateway/langgraph_auth.py +++ /dev/null @@ -1,106 +0,0 @@ -"""LangGraph Server auth handler — shares JWT logic with Gateway. - -Loaded by LangGraph Server via langgraph.json ``auth.path``. -Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway, -so both modes validate tokens with the same secret and rules. - -Two layers: - 1. @auth.authenticate — validates JWT cookie, extracts user_id, - and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH) - 2. @auth.on — returns metadata filter so each user only sees own threads -""" - -import secrets - -from langgraph_sdk import Auth - -from app.gateway.auth.errors import TokenError -from app.gateway.auth.jwt import decode_token -from app.gateway.deps import get_local_provider - -auth = Auth() - -# Methods that require CSRF validation (state-changing per RFC 7231). -_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"}) - - -def _check_csrf(request) -> None: - """Enforce Double Submit Cookie CSRF check for state-changing requests. - - Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes - proxied directly by nginx have the same CSRF protection. - """ - method = getattr(request, "method", "") or "" - if method.upper() not in _CSRF_METHODS: - return - - cookie_token = request.cookies.get("csrf_token") - header_token = request.headers.get("x-csrf-token") - - if not cookie_token or not header_token: - raise Auth.exceptions.HTTPException( - status_code=403, - detail="CSRF token missing. Include X-CSRF-Token header.", - ) - - if not secrets.compare_digest(cookie_token, header_token): - raise Auth.exceptions.HTTPException( - status_code=403, - detail="CSRF token mismatch.", - ) - - -@auth.authenticate -async def authenticate(request): - """Validate the session cookie, decode JWT, and check token_version. - - Same validation chain as Gateway's get_current_user_from_request: - cookie → decode JWT → DB lookup → token_version match - Also enforces CSRF on state-changing methods. - """ - # CSRF check before authentication so forged cross-site requests - # are rejected early, even if the cookie carries a valid JWT. - _check_csrf(request) - - token = request.cookies.get("access_token") - if not token: - raise Auth.exceptions.HTTPException( - status_code=401, - detail="Not authenticated", - ) - - payload = decode_token(token) - if isinstance(payload, TokenError): - raise Auth.exceptions.HTTPException( - status_code=401, - detail=f"Token error: {payload.value}", - ) - - user = await get_local_provider().get_user(payload.sub) - if user is None: - raise Auth.exceptions.HTTPException( - status_code=401, - detail="User not found", - ) - if user.token_version != payload.ver: - raise Auth.exceptions.HTTPException( - status_code=401, - detail="Token revoked (password changed)", - ) - - return payload.sub - - -@auth.on -async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict): - """Inject user_id metadata on writes; filter by user_id on reads. - - Gateway stores thread ownership as ``metadata.user_id``. - This handler ensures LangGraph Server enforces the same isolation. - """ - # On create/update: stamp user_id into metadata - metadata = value.setdefault("metadata", {}) - metadata["user_id"] = ctx.user.identity - - # Return filter dict — LangGraph applies it to search/read/delete - return {"user_id": ctx.user.identity} diff --git a/backend/app/gateway/routers/assistants_compat.py b/backend/app/gateway/routers/assistants_compat.py deleted file mode 100644 index 83708747c..000000000 --- a/backend/app/gateway/routers/assistants_compat.py +++ /dev/null @@ -1,149 +0,0 @@ -"""Assistants compatibility endpoints. - -Provides LangGraph Platform-compatible assistants API backed by the -``langgraph.json`` graph registry and ``config.yaml`` agent definitions. - -This is a minimal stub that satisfies the ``useStream`` React hook's -initialization requirements (``assistants.search()`` and ``assistants.get()``). -""" - -from __future__ import annotations - -import logging -from datetime import UTC, datetime -from typing import Any - -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, Field - -logger = logging.getLogger(__name__) -router = APIRouter(prefix="/api/assistants", tags=["assistants-compat"]) - - -class AssistantResponse(BaseModel): - assistant_id: str - graph_id: str - name: str - config: dict[str, Any] = Field(default_factory=dict) - metadata: dict[str, Any] = Field(default_factory=dict) - description: str | None = None - created_at: str = "" - updated_at: str = "" - version: int = 1 - - -class AssistantSearchRequest(BaseModel): - graph_id: str | None = None - name: str | None = None - metadata: dict[str, Any] | None = None - limit: int = 10 - offset: int = 0 - - -def _get_default_assistant() -> AssistantResponse: - """Return the default lead_agent assistant.""" - now = datetime.now(UTC).isoformat() - return AssistantResponse( - assistant_id="lead_agent", - graph_id="lead_agent", - name="lead_agent", - config={}, - metadata={"created_by": "system"}, - description="DeerFlow lead agent", - created_at=now, - updated_at=now, - version=1, - ) - - -def _list_assistants() -> list[AssistantResponse]: - """List all available assistants from config.""" - assistants = [_get_default_assistant()] - - # Also include custom agents from config.yaml agents directory - try: - from deerflow.config.agents_config import list_custom_agents - - for agent_cfg in list_custom_agents(): - now = datetime.now(UTC).isoformat() - assistants.append( - AssistantResponse( - assistant_id=agent_cfg.name, - graph_id="lead_agent", # All agents use the same graph - name=agent_cfg.name, - config={}, - metadata={"created_by": "user"}, - description=agent_cfg.description or "", - created_at=now, - updated_at=now, - version=1, - ) - ) - except Exception: - logger.debug("Could not load custom agents for assistants list") - - return assistants - - -@router.post("/search", response_model=list[AssistantResponse]) -async def search_assistants(body: AssistantSearchRequest | None = None) -> list[AssistantResponse]: - """Search assistants. - - Returns all registered assistants (lead_agent + custom agents from config). - """ - assistants = _list_assistants() - - if body and body.graph_id: - assistants = [a for a in assistants if a.graph_id == body.graph_id] - if body and body.name: - assistants = [a for a in assistants if body.name.lower() in a.name.lower()] - - offset = body.offset if body else 0 - limit = body.limit if body else 10 - return assistants[offset : offset + limit] - - -@router.get("/{assistant_id}", response_model=AssistantResponse) -async def get_assistant_compat(assistant_id: str) -> AssistantResponse: - """Get an assistant by ID.""" - for a in _list_assistants(): - if a.assistant_id == assistant_id: - return a - raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found") - - -@router.get("/{assistant_id}/graph") -async def get_assistant_graph(assistant_id: str) -> dict: - """Get the graph structure for an assistant. - - Returns a minimal graph description. Full graph introspection is - not supported in the Gateway — this stub satisfies SDK validation. - """ - found = any(a.assistant_id == assistant_id for a in _list_assistants()) - if not found: - raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found") - - return { - "graph_id": "lead_agent", - "nodes": [], - "edges": [], - } - - -@router.get("/{assistant_id}/schemas") -async def get_assistant_schemas(assistant_id: str) -> dict: - """Get JSON schemas for an assistant's input/output/state. - - Returns empty schemas — full introspection not supported in Gateway. - """ - found = any(a.assistant_id == assistant_id for a in _list_assistants()) - if not found: - raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found") - - return { - "graph_id": "lead_agent", - "input_schema": {}, - "output_schema": {}, - "state_schema": {}, - "config_schema": {}, - } diff --git a/backend/app/gateway/routers/auth.py b/backend/app/gateway/routers/auth.py deleted file mode 100644 index 44b996331..000000000 --- a/backend/app/gateway/routers/auth.py +++ /dev/null @@ -1,459 +0,0 @@ -"""Authentication endpoints.""" - -import logging -import os -import time -from ipaddress import ip_address, ip_network - -from fastapi import APIRouter, Depends, HTTPException, Request, Response, status -from fastapi.security import OAuth2PasswordRequestForm -from pydantic import BaseModel, EmailStr, Field, field_validator - -from app.gateway.auth import ( - UserResponse, - create_access_token, -) -from app.gateway.auth.config import get_auth_config -from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse -from app.gateway.csrf_middleware import is_secure_request -from app.gateway.deps import get_current_user_from_request, get_local_provider - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) - - -# ── Request/Response Models ────────────────────────────────────────────── - - -class LoginResponse(BaseModel): - """Response model for login — token only lives in HttpOnly cookie.""" - - expires_in: int # seconds - needs_setup: bool = False - - -# Top common-password blocklist. Drawn from the public SecLists "10k worst -# passwords" set, lowercased + length>=8 only (shorter ones already fail -# the min_length check). Kept tight on purpose: this is the **lower bound** -# defense, not a full HIBP / passlib check, and runs in-process per request. -_COMMON_PASSWORDS: frozenset[str] = frozenset( - { - "password", - "password1", - "password12", - "password123", - "password1234", - "12345678", - "123456789", - "1234567890", - "qwerty12", - "qwertyui", - "qwerty123", - "abc12345", - "abcd1234", - "iloveyou", - "letmein1", - "welcome1", - "welcome123", - "admin123", - "administrator", - "passw0rd", - "p@ssw0rd", - "monkey12", - "trustno1", - "sunshine", - "princess", - "football", - "baseball", - "superman", - "batman123", - "starwars", - "dragon123", - "master123", - "shadow12", - "michael1", - "jennifer", - "computer", - } -) - - -def _password_is_common(password: str) -> bool: - """Case-insensitive blocklist check. - - Lowercases the input so trivial mutations like ``Password`` / - ``PASSWORD`` are also rejected. Does not normalize digit substitutions - (``p@ssw0rd`` is included as a literal entry instead) — keeping the - rule cheap and predictable. - """ - return password.lower() in _COMMON_PASSWORDS - - -def _validate_strong_password(value: str) -> str: - """Pydantic field-validator body shared by Register + ChangePassword. - - Constraint = function, not type-level mixin. The two request models - have no "is-a" relationship; they only share the password-strength - rule. Lifting it into a free function lets each model bind it via - ``@field_validator(field_name)`` without inheritance gymnastics. - """ - if _password_is_common(value): - raise ValueError("Password is too common; choose a stronger password.") - return value - - -class RegisterRequest(BaseModel): - """Request model for user registration.""" - - email: EmailStr - password: str = Field(..., min_length=8) - - _strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v))) - - -class ChangePasswordRequest(BaseModel): - """Request model for password change (also handles setup flow).""" - - current_password: str - new_password: str = Field(..., min_length=8) - new_email: EmailStr | None = None - - _strong_password = field_validator("new_password")(classmethod(lambda cls, v: _validate_strong_password(v))) - - -class MessageResponse(BaseModel): - """Generic message response.""" - - message: str - - -# ── Helpers ─────────────────────────────────────────────────────────────── - - -def _set_session_cookie(response: Response, token: str, request: Request) -> None: - """Set the access_token HttpOnly cookie on the response.""" - config = get_auth_config() - is_https = is_secure_request(request) - response.set_cookie( - key="access_token", - value=token, - httponly=True, - secure=is_https, - samesite="lax", - max_age=config.token_expiry_days * 24 * 3600 if is_https else None, - ) - - -# ── Rate Limiting ──────────────────────────────────────────────────────── -# In-process dict — not shared across workers. Sufficient for single-worker deployments. - -_MAX_LOGIN_ATTEMPTS = 5 -_LOCKOUT_SECONDS = 300 # 5 minutes - -# ip → (fail_count, lock_until_timestamp) -_login_attempts: dict[str, tuple[int, float]] = {} - - -def _trusted_proxies() -> list: - """Parse ``AUTH_TRUSTED_PROXIES`` env var into a list of ip_network objects. - - Comma-separated CIDR or single-IP entries. Empty / unset = no proxy is - trusted (direct mode). Invalid entries are skipped with a logger warning. - Read live so env-var overrides take effect immediately and tests can - ``monkeypatch.setenv`` without poking a module-level cache. - """ - raw = os.getenv("AUTH_TRUSTED_PROXIES", "").strip() - if not raw: - return [] - nets = [] - for entry in raw.split(","): - entry = entry.strip() - if not entry: - continue - try: - nets.append(ip_network(entry, strict=False)) - except ValueError: - logger.warning("AUTH_TRUSTED_PROXIES: ignoring invalid entry %r", entry) - return nets - - -def _get_client_ip(request: Request) -> str: - """Extract the real client IP for rate limiting. - - Trust model: - - - The TCP peer (``request.client.host``) is always the baseline. It is - whatever the kernel reports as the connecting socket — unforgeable - by the client itself. - - ``X-Real-IP`` is **only** honored if the TCP peer is in the - ``AUTH_TRUSTED_PROXIES`` allowlist (set via env var, comma-separated - CIDR or single IPs). When set, the gateway is assumed to be behind a - reverse proxy (nginx, Cloudflare, ALB, …) that overwrites - ``X-Real-IP`` with the original client address. - - With no ``AUTH_TRUSTED_PROXIES`` set, ``X-Real-IP`` is silently - ignored — closing the bypass where any client could rotate the - header to dodge per-IP rate limits in dev / direct-gateway mode. - - ``X-Forwarded-For`` is intentionally NOT used because it is naturally - client-controlled at the *first* hop and the trust chain is harder to - audit per-request. - """ - peer_host = request.client.host if request.client else None - - trusted = _trusted_proxies() - if trusted and peer_host: - try: - peer_ip = ip_address(peer_host) - if any(peer_ip in net for net in trusted): - real_ip = request.headers.get("x-real-ip", "").strip() - if real_ip: - return real_ip - except ValueError: - # peer_host wasn't a parseable IP (e.g. "unknown") — fall through - pass - - return peer_host or "unknown" - - -def _check_rate_limit(ip: str) -> None: - """Raise 429 if the IP is currently locked out.""" - record = _login_attempts.get(ip) - if record is None: - return - fail_count, lock_until = record - if fail_count >= _MAX_LOGIN_ATTEMPTS: - if time.time() < lock_until: - raise HTTPException( - status_code=429, - detail="Too many login attempts. Try again later.", - ) - del _login_attempts[ip] - - -_MAX_TRACKED_IPS = 10000 - - -def _record_login_failure(ip: str) -> None: - """Record a failed login attempt for the given IP.""" - # Evict expired lockouts when dict grows too large - if len(_login_attempts) >= _MAX_TRACKED_IPS: - now = time.time() - expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t] - for k in expired: - del _login_attempts[k] - # If still too large, evict cheapest-to-lose half: below-threshold - # IPs (lock_until=0.0) sort first, then earliest-expiring lockouts. - if len(_login_attempts) >= _MAX_TRACKED_IPS: - by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1]) - for k, _ in by_time[: len(by_time) // 2]: - del _login_attempts[k] - - record = _login_attempts.get(ip) - if record is None: - _login_attempts[ip] = (1, 0.0) - else: - new_count = record[0] + 1 - lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0 - _login_attempts[ip] = (new_count, lock_until) - - -def _record_login_success(ip: str) -> None: - """Clear failure counter for the given IP on successful login.""" - _login_attempts.pop(ip, None) - - -# ── Endpoints ───────────────────────────────────────────────────────────── - - -@router.post("/login/local", response_model=LoginResponse) -async def login_local( - request: Request, - response: Response, - form_data: OAuth2PasswordRequestForm = Depends(), -): - """Local email/password login.""" - client_ip = _get_client_ip(request) - _check_rate_limit(client_ip) - - user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password}) - - if user is None: - _record_login_failure(client_ip) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(), - ) - - _record_login_success(client_ip) - token = create_access_token(str(user.id), token_version=user.token_version) - _set_session_cookie(response, token, request) - - return LoginResponse( - expires_in=get_auth_config().token_expiry_days * 24 * 3600, - needs_setup=user.needs_setup, - ) - - -@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) -async def register(request: Request, response: Response, body: RegisterRequest): - """Register a new user account (always 'user' role). - - Admin is auto-created on first boot. This endpoint creates regular users. - Auto-login by setting the session cookie. - """ - try: - user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user") - except ValueError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(), - ) - - token = create_access_token(str(user.id), token_version=user.token_version) - _set_session_cookie(response, token, request) - - return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role) - - -@router.post("/logout", response_model=MessageResponse) -async def logout(request: Request, response: Response): - """Logout current user by clearing the cookie.""" - response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax") - return MessageResponse(message="Successfully logged out") - - -@router.post("/change-password", response_model=MessageResponse) -async def change_password(request: Request, response: Response, body: ChangePasswordRequest): - """Change password for the currently authenticated user. - - Also handles the first-boot setup flow: - - If new_email is provided, updates email (checks uniqueness) - - If user.needs_setup is True and new_email is given, clears needs_setup - - Always increments token_version to invalidate old sessions - - Re-issues session cookie with new token_version - """ - from app.gateway.auth.password import hash_password_async, verify_password_async - - user = await get_current_user_from_request(request) - - if user.password_hash is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump()) - - if not await verify_password_async(body.current_password, user.password_hash): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump()) - - provider = get_local_provider() - - # Update email if provided - if body.new_email is not None: - existing = await provider.get_user_by_email(body.new_email) - if existing and str(existing.id) != str(user.id): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump()) - user.email = body.new_email - - # Update password + bump version - user.password_hash = await hash_password_async(body.new_password) - user.token_version += 1 - - # Clear setup flag if this is the setup flow - if user.needs_setup and body.new_email is not None: - user.needs_setup = False - - await provider.update_user(user) - - # Re-issue cookie with new token_version - token = create_access_token(str(user.id), token_version=user.token_version) - _set_session_cookie(response, token, request) - - return MessageResponse(message="Password changed successfully") - - -@router.get("/me", response_model=UserResponse) -async def get_me(request: Request): - """Get current authenticated user info.""" - user = await get_current_user_from_request(request) - return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup) - - -@router.get("/setup-status") -async def setup_status(): - """Check if an admin account exists. Returns needs_setup=True when no admin exists.""" - admin_count = await get_local_provider().count_admin_users() - return {"needs_setup": admin_count == 0} - - -class InitializeAdminRequest(BaseModel): - """Request model for first-boot admin account creation.""" - - email: EmailStr - password: str = Field(..., min_length=8) - - _strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v))) - - -@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED) -async def initialize_admin(request: Request, response: Response, body: InitializeAdminRequest): - """Create the first admin account on initial system setup. - - Only callable when no admin exists. Returns 409 Conflict if an admin - already exists. - - On success, the admin account is created with ``needs_setup=False`` and - the session cookie is set. - """ - admin_count = await get_local_provider().count_admin_users() - if admin_count > 0: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(), - ) - - try: - user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="admin", needs_setup=False) - except ValueError: - # DB unique-constraint race: another concurrent request beat us. - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(), - ) - - token = create_access_token(str(user.id), token_version=user.token_version) - _set_session_cookie(response, token, request) - - return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role) - - -# ── OAuth Endpoints (Future/Placeholder) ───────────────────────────────── - - -@router.get("/oauth/{provider}") -async def oauth_login(provider: str): - """Initiate OAuth login flow. - - Redirects to the OAuth provider's authorization URL. - Currently a placeholder - requires OAuth provider implementation. - """ - if provider not in ["github", "google"]: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Unsupported OAuth provider: {provider}", - ) - - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="OAuth login not yet implemented", - ) - - -@router.get("/callback/{provider}") -async def oauth_callback(provider: str, code: str, state: str): - """OAuth callback endpoint. - - Handles the OAuth provider's callback after user authorization. - Currently a placeholder. - """ - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, - detail="OAuth callback not yet implemented", - )