diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index 01606a8cb..faec6290d 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -1,15 +1,21 @@ import logging +import os from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from datetime import UTC from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from app.gateway.auth_middleware import AuthMiddleware from app.gateway.config import get_gateway_config +from app.gateway.csrf_middleware import CSRFMiddleware from app.gateway.deps import langgraph_runtime from app.gateway.routers import ( agents, artifacts, assistants_compat, + auth, channels, feedback, mcp, @@ -34,6 +40,92 @@ logging.basicConfig( logger = logging.getLogger(__name__) +async def _ensure_admin_user(app: FastAPI) -> None: + """Auto-create the admin user on first boot if no users exist. + + Prints the generated password to stdout so the operator can log in. + On subsequent boots, warns if any user still needs setup. + + Multi-worker safe: relies on SQLite UNIQUE constraint to resolve races. + Only the worker that successfully creates/updates the admin prints the + password; losers silently skip. + """ + import secrets + + from app.gateway.deps import get_local_provider + + provider = get_local_provider() + user_count = await provider.count_users() + + if user_count == 0: + password = secrets.token_urlsafe(16) + try: + admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True) + except ValueError: + return # Another worker already created the admin. + + # Migrate orphaned threads (no owner_id) to this admin + store = getattr(app.state, "store", None) + if store is not None: + await _migrate_orphaned_threads(store, str(admin.id)) + + logger.info("=" * 60) + logger.info(" Admin account created on first boot") + logger.info(" Email: %s", admin.email) + logger.info(" Password: %s", password) + logger.info(" Change it after login: Settings -> Account") + logger.info("=" * 60) + return + + # Admin exists but setup never completed — reset password so operator + # can always find it in the console without needing the CLI. + # Multi-worker guard: if admin was created less than 30s ago, another + # worker just created it and will print the password — skip reset. + admin = await provider.get_user_by_email("admin@deerflow.dev") + if admin and admin.needs_setup: + import time + + age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp() + if age < 30: + return # Just created by another worker in this startup; its password is still valid. + + from app.gateway.auth.password import hash_password_async + + password = secrets.token_urlsafe(16) + admin.password_hash = await hash_password_async(password) + admin.token_version += 1 + await provider.update_user(admin) + + logger.info("=" * 60) + logger.info(" Admin account setup incomplete — password reset") + logger.info(" Email: %s", admin.email) + logger.info(" Password: %s", password) + logger.info(" Change it after login: Settings -> Account") + logger.info("=" * 60) + + +async def _migrate_orphaned_threads(store, admin_user_id: str) -> None: + """Migrate threads with no owner_id to the given admin. + + NOTE: This is the initial port. Commit 5 will replace the hardcoded + limit=1000 with cursor pagination and extend to SQL persistence tables. + """ + try: + migrated = 0 + results = await store.asearch(("threads",), limit=1000) + for item in results: + metadata = item.value.get("metadata", {}) + if not metadata.get("owner_id"): + metadata["owner_id"] = admin_user_id + item.value["metadata"] = metadata + await store.aput(("threads",), item.key, item.value) + migrated += 1 + if migrated: + logger.info("Migrated %d orphaned thread(s) to admin", migrated) + except Exception: + logger.exception("Thread migration failed (non-fatal)") + + @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """Application lifespan handler.""" @@ -53,6 +145,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: async with langgraph_runtime(app): logger.info("LangGraph runtime initialised") + # Ensure admin user exists (auto-create on first boot) + # Must run AFTER langgraph_runtime so app.state.store is available for thread migration + await _ensure_admin_user(app) + # Start IM channel service if any channels are configured try: from app.channels.service import start_channel_service @@ -164,7 +260,35 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an ], ) - # CORS is handled by nginx - no need for FastAPI middleware + # Auth: reject unauthenticated requests to non-public paths (fail-closed safety net) + app.add_middleware(AuthMiddleware) + + # CSRF: Double Submit Cookie pattern for state-changing requests + app.add_middleware(CSRFMiddleware) + + # CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware. + # In production, nginx handles CORS and no middleware is needed. + cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "") + if cors_origins_env: + cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()] + # Validate: wildcard origin with credentials is a security misconfiguration + for origin in cors_origins: + if origin == "*": + logger.error( + "GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. " + "This is a security misconfiguration — browsers will reject the response. " + "Use explicit scheme://host:port origins instead." + ) + cors_origins = [o for o in cors_origins if o != "*"] + break + if cors_origins: + app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) # Include routers # Models API is mounted at /api/models @@ -200,6 +324,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an # Assistants compatibility API (LangGraph Platform stub) app.include_router(assistants_compat.router) + # Auth API is mounted at /api/v1/auth + app.include_router(auth.router) + # Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback app.include_router(feedback.router) diff --git a/backend/app/gateway/auth_middleware.py b/backend/app/gateway/auth_middleware.py new file mode 100644 index 000000000..cca505688 --- /dev/null +++ b/backend/app/gateway/auth_middleware.py @@ -0,0 +1,71 @@ +"""Global authentication middleware — fail-closed safety net. + +Rejects unauthenticated requests to non-public paths with 401. +Fine-grained permission checks remain in authz.py decorators. +""" + +from collections.abc import Callable + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import JSONResponse +from starlette.types import ASGIApp + +from app.gateway.auth.errors import AuthErrorCode + +# Paths that never require authentication. +_PUBLIC_PATH_PREFIXES: tuple[str, ...] = ( + "/health", + "/docs", + "/redoc", + "/openapi.json", +) + +# Exact auth paths that are public (login/register/status check). +# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public. +_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset( + { + "/api/v1/auth/login/local", + "/api/v1/auth/register", + "/api/v1/auth/logout", + "/api/v1/auth/setup-status", + } +) + + +def _is_public(path: str) -> bool: + stripped = path.rstrip("/") + if stripped in _PUBLIC_EXACT_PATHS: + return True + return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES) + + +class AuthMiddleware(BaseHTTPMiddleware): + """Coarse-grained auth gate: reject requests without a valid session cookie. + + This does NOT verify JWT signature or user existence — that is the job of + ``get_current_user_from_request`` in deps.py (called by ``@require_auth``). + The middleware only checks *presence* of the cookie so that new endpoints + that forget ``@require_auth`` are not completely exposed. + """ + + def __init__(self, app: ASGIApp) -> None: + super().__init__(app) + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + if _is_public(request.url.path): + return await call_next(request) + + # Non-public path: require session cookie + if not request.cookies.get("access_token"): + return JSONResponse( + status_code=401, + content={ + "detail": { + "code": AuthErrorCode.NOT_AUTHENTICATED, + "message": "Authentication required", + } + }, + ) + + return await call_next(request) diff --git a/backend/app/gateway/authz.py b/backend/app/gateway/authz.py new file mode 100644 index 000000000..015f747c3 --- /dev/null +++ b/backend/app/gateway/authz.py @@ -0,0 +1,261 @@ +"""Authorization decorators and context for DeerFlow. + +Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py + +**Usage:** + +1. Use ``@require_auth`` on routes that need authentication +2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks +3. The decorator chain processes from bottom to top + +**Example:** + + @router.get("/{thread_id}") + @require_auth + @require_permission("threads", "read", owner_check=True) + async def get_thread(thread_id: str, request: Request): + # User is authenticated and has threads:read permission + ... + +**Permission Model:** + +- threads:read - View thread +- threads:write - Create/update thread +- threads:delete - Delete thread +- runs:create - Run agent +- runs:read - View run +- runs:cancel - Cancel run +""" + +from __future__ import annotations + +import functools +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar + +from fastapi import HTTPException, Request + +if TYPE_CHECKING: + from app.gateway.auth.models import User + +P = ParamSpec("P") +T = TypeVar("T") + + +# Permission constants +class Permissions: + """Permission constants for resource:action format.""" + + # Threads + THREADS_READ = "threads:read" + THREADS_WRITE = "threads:write" + THREADS_DELETE = "threads:delete" + + # Runs + RUNS_CREATE = "runs:create" + RUNS_READ = "runs:read" + RUNS_CANCEL = "runs:cancel" + + +class AuthContext: + """Authentication context for the current request. + + Stored in request.state.auth after require_auth decoration. + + Attributes: + user: The authenticated user, or None if anonymous + permissions: List of permission strings (e.g., "threads:read") + """ + + __slots__ = ("user", "permissions") + + def __init__(self, user: User | None = None, permissions: list[str] | None = None): + self.user = user + self.permissions = permissions or [] + + @property + def is_authenticated(self) -> bool: + """Check if user is authenticated.""" + return self.user is not None + + def has_permission(self, resource: str, action: str) -> bool: + """Check if context has permission for resource:action. + + Args: + resource: Resource name (e.g., "threads") + action: Action name (e.g., "read") + + Returns: + True if user has permission + """ + permission = f"{resource}:{action}" + return permission in self.permissions + + def require_user(self) -> User: + """Get user or raise 401. + + Raises: + HTTPException 401 if not authenticated + """ + if not self.user: + raise HTTPException(status_code=401, detail="Authentication required") + return self.user + + +def get_auth_context(request: Request) -> AuthContext | None: + """Get AuthContext from request state.""" + return getattr(request.state, "auth", None) + + +_ALL_PERMISSIONS: list[str] = [ + Permissions.THREADS_READ, + Permissions.THREADS_WRITE, + Permissions.THREADS_DELETE, + Permissions.RUNS_CREATE, + Permissions.RUNS_READ, + Permissions.RUNS_CANCEL, +] + + +async def _authenticate(request: Request) -> AuthContext: + """Authenticate request and return AuthContext. + + Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline. + Returns AuthContext with user=None for anonymous requests. + """ + from app.gateway.deps import get_optional_user_from_request + + user = await get_optional_user_from_request(request) + if user is None: + return AuthContext(user=None, permissions=[]) + + # In future, permissions could be stored in user record + return AuthContext(user=user, permissions=_ALL_PERMISSIONS) + + +def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]: + """Decorator that authenticates the request and sets AuthContext. + + Must be placed ABOVE other decorators (executes after them). + + Usage: + @router.get("/{thread_id}") + @require_auth # Bottom decorator (executes first after permission check) + @require_permission("threads", "read") + async def get_thread(thread_id: str, request: Request): + auth: AuthContext = request.state.auth + ... + + Raises: + ValueError: If 'request' parameter is missing + """ + + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + request = kwargs.get("request") + if request is None: + raise ValueError("require_auth decorator requires 'request' parameter") + + # Authenticate and set context + auth_context = await _authenticate(request) + request.state.auth = auth_context + + return await func(*args, **kwargs) + + return wrapper + + +def require_permission( + resource: str, + action: str, + owner_check: bool = False, + owner_filter_key: str = "owner_id", + inject_record: bool = False, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + """Decorator that checks permission for resource:action. + + Must be used AFTER @require_auth. + + Args: + resource: Resource name (e.g., "threads", "runs") + action: Action name (e.g., "read", "write", "delete") + owner_check: If True, validates that the current user owns the resource. + Requires 'thread_id' path parameter and performs ownership check. + owner_filter_key: Field name for ownership filter (default: "owner_id") + inject_record: If True and owner_check is True, injects the thread record + into kwargs['thread_record'] for use in the handler. + + Usage: + # Simple permission check + @require_permission("threads", "read") + async def get_thread(thread_id: str, request: Request): + ... + + # With ownership check (for /threads/{thread_id} endpoints) + @require_permission("threads", "delete", owner_check=True) + async def delete_thread(thread_id: str, request: Request): + ... + + # With ownership check and record injection + @require_permission("threads", "delete", owner_check=True, inject_record=True) + async def delete_thread(thread_id: str, request: Request, thread_record: dict = None): + # thread_record is injected if found + ... + + Raises: + HTTPException 401: If authentication required but user is anonymous + HTTPException 403: If user lacks permission + HTTPException 404: If owner_check=True but user doesn't own the thread + ValueError: If owner_check=True but 'thread_id' parameter is missing + """ + + def decorator(func: Callable[P, T]) -> Callable[P, T]: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + request = kwargs.get("request") + if request is None: + raise ValueError("require_permission decorator requires 'request' parameter") + + auth: AuthContext = getattr(request.state, "auth", None) + if auth is None: + auth = await _authenticate(request) + request.state.auth = auth + + if not auth.is_authenticated: + raise HTTPException(status_code=401, detail="Authentication required") + + # Check permission + if not auth.has_permission(resource, action): + raise HTTPException( + status_code=403, + detail=f"Permission denied: {resource}:{action}", + ) + + # Owner check for thread-specific resources + if owner_check: + thread_id = kwargs.get("thread_id") + if thread_id is None: + raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter") + + # Get thread and verify ownership + from app.gateway.routers.threads import _store_get, get_store + + store = get_store(request) + if store is not None: + record = await _store_get(store, thread_id) + if record: + owner_id = record.get("metadata", {}).get(owner_filter_key) + if owner_id and owner_id != str(auth.user.id): + raise HTTPException( + status_code=404, + detail=f"Thread {thread_id} not found", + ) + # Inject record if requested + if inject_record: + kwargs["thread_record"] = record + + return await func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/backend/app/gateway/csrf_middleware.py b/backend/app/gateway/csrf_middleware.py new file mode 100644 index 000000000..fc96878b6 --- /dev/null +++ b/backend/app/gateway/csrf_middleware.py @@ -0,0 +1,112 @@ +"""CSRF protection middleware for FastAPI. + +Per RFC-001: +State-changing operations require CSRF protection. +""" + +import secrets +from collections.abc import Callable + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import JSONResponse +from starlette.types import ASGIApp + +CSRF_COOKIE_NAME = "csrf_token" +CSRF_HEADER_NAME = "X-CSRF-Token" +CSRF_TOKEN_LENGTH = 64 # bytes + + +def is_secure_request(request: Request) -> bool: + """Detect whether the original client request was made over HTTPS.""" + return request.headers.get("x-forwarded-proto", request.url.scheme) == "https" + + +def generate_csrf_token() -> str: + """Generate a secure random CSRF token.""" + return secrets.token_urlsafe(CSRF_TOKEN_LENGTH) + + +def should_check_csrf(request: Request) -> bool: + """Determine if a request needs CSRF validation. + + CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH). + GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231. + """ + if request.method not in ("POST", "PUT", "DELETE", "PATCH"): + return False + + path = request.url.path.rstrip("/") + # Exempt /api/v1/auth/me endpoint + if path == "/api/v1/auth/me": + return False + return True + + +_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset( + { + "/api/v1/auth/login/local", + "/api/v1/auth/logout", + "/api/v1/auth/register", + } +) + + +def is_auth_endpoint(request: Request) -> bool: + """Check if the request is to an auth endpoint. + + Auth endpoints don't need CSRF validation on first call (no token). + """ + return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS + + +class CSRFMiddleware(BaseHTTPMiddleware): + """Middleware that implements CSRF protection using Double Submit Cookie pattern.""" + + def __init__(self, app: ASGIApp) -> None: + super().__init__(app) + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + _is_auth = is_auth_endpoint(request) + + if should_check_csrf(request) and not _is_auth: + cookie_token = request.cookies.get(CSRF_COOKIE_NAME) + header_token = request.headers.get(CSRF_HEADER_NAME) + + if not cookie_token or not header_token: + return JSONResponse( + status_code=403, + content={"detail": "CSRF token missing. Include X-CSRF-Token header."}, + ) + + if not secrets.compare_digest(cookie_token, header_token): + return JSONResponse( + status_code=403, + content={"detail": "CSRF token mismatch."}, + ) + + response = await call_next(request) + + # For auth endpoints that set up session, also set CSRF cookie + if _is_auth and request.method == "POST": + # Generate a new CSRF token for the session + csrf_token = generate_csrf_token() + is_https = is_secure_request(request) + response.set_cookie( + key=CSRF_COOKIE_NAME, + value=csrf_token, + httponly=False, # Must be JS-readable for Double Submit Cookie pattern + secure=is_https, + samesite="strict", + ) + + return response + + +def get_csrf_token(request: Request) -> str | None: + """Get the CSRF token from the current request's cookies. + + This is useful for server-side rendering where you need to embed + token in forms or headers. + """ + return request.cookies.get(CSRF_COOKIE_NAME) diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index bdcea365c..b6fa9c975 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -11,11 +11,16 @@ from __future__ import annotations from collections.abc import AsyncGenerator from contextlib import AsyncExitStack, asynccontextmanager +from typing import TYPE_CHECKING from fastapi import FastAPI, HTTPException, Request from deerflow.runtime import RunContext, RunManager +if TYPE_CHECKING: + from app.gateway.auth.local_provider import LocalAuthProvider + from app.gateway.auth.repositories.sqlite import SQLiteUserRepository + @asynccontextmanager async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: @@ -127,10 +132,86 @@ def get_run_context(request: Request) -> RunContext: ) -async def get_current_user(request: Request) -> str | None: - """Extract user identity from request. +# --------------------------------------------------------------------------- +# Auth helpers (used by authz.py and auth middleware) +# --------------------------------------------------------------------------- - Phase 2: always returns None (no authentication). - Phase 3: extract user_id from JWT / session / API key header. +# Cached singletons to avoid repeated instantiation per request +_cached_local_provider: LocalAuthProvider | None = None +_cached_repo: SQLiteUserRepository | None = None + + +def get_local_provider() -> LocalAuthProvider: + """Get or create the cached LocalAuthProvider singleton.""" + global _cached_local_provider, _cached_repo + if _cached_repo is None: + from app.gateway.auth.repositories.sqlite import SQLiteUserRepository + + _cached_repo = SQLiteUserRepository() + if _cached_local_provider is None: + from app.gateway.auth.local_provider import LocalAuthProvider + + _cached_local_provider = LocalAuthProvider(repository=_cached_repo) + return _cached_local_provider + + +async def get_current_user_from_request(request: Request): + """Get the current authenticated user from the request cookie. + + Raises HTTPException 401 if not authenticated. """ - return None + from app.gateway.auth import decode_token + from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code + + access_token = request.cookies.get("access_token") + if not access_token: + raise HTTPException( + status_code=401, + detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(), + ) + + payload = decode_token(access_token) + if isinstance(payload, TokenError): + raise HTTPException( + status_code=401, + detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(), + ) + + provider = get_local_provider() + user = await provider.get_user(payload.sub) + if user is None: + raise HTTPException( + status_code=401, + detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(), + ) + + # Token version mismatch → password was changed, token is stale + if user.token_version != payload.ver: + raise HTTPException( + status_code=401, + detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(), + ) + + return user + + +async def get_optional_user_from_request(request: Request): + """Get optional authenticated user from request. + + Returns None if not authenticated. + """ + try: + return await get_current_user_from_request(request) + except HTTPException: + return None + + +async def get_current_user(request: Request) -> str | None: + """Extract user_id from request cookie, or None if not authenticated. + + Thin adapter that returns the string id for callers that only need + identification (e.g., ``feedback.py``). Full-user callers should use + ``get_current_user_from_request`` or ``get_optional_user_from_request``. + """ + user = await get_optional_user_from_request(request) + return str(user.id) if user else None diff --git a/backend/app/gateway/langgraph_auth.py b/backend/app/gateway/langgraph_auth.py new file mode 100644 index 000000000..25d3b434c --- /dev/null +++ b/backend/app/gateway/langgraph_auth.py @@ -0,0 +1,106 @@ +"""LangGraph Server auth handler — shares JWT logic with Gateway. + +Loaded by LangGraph Server via langgraph.json ``auth.path``. +Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway, +so both modes validate tokens with the same secret and rules. + +Two layers: + 1. @auth.authenticate — validates JWT cookie, extracts user_id, + and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH) + 2. @auth.on — returns metadata filter so each user only sees own threads +""" + +import secrets + +from langgraph_sdk import Auth + +from app.gateway.auth.errors import TokenError +from app.gateway.auth.jwt import decode_token +from app.gateway.deps import get_local_provider + +auth = Auth() + +# Methods that require CSRF validation (state-changing per RFC 7231). +_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"}) + + +def _check_csrf(request) -> None: + """Enforce Double Submit Cookie CSRF check for state-changing requests. + + Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes + proxied directly by nginx have the same CSRF protection. + """ + method = getattr(request, "method", "") or "" + if method.upper() not in _CSRF_METHODS: + return + + cookie_token = request.cookies.get("csrf_token") + header_token = request.headers.get("x-csrf-token") + + if not cookie_token or not header_token: + raise Auth.exceptions.HTTPException( + status_code=403, + detail="CSRF token missing. Include X-CSRF-Token header.", + ) + + if not secrets.compare_digest(cookie_token, header_token): + raise Auth.exceptions.HTTPException( + status_code=403, + detail="CSRF token mismatch.", + ) + + +@auth.authenticate +async def authenticate(request): + """Validate the session cookie, decode JWT, and check token_version. + + Same validation chain as Gateway's get_current_user_from_request: + cookie → decode JWT → DB lookup → token_version match + Also enforces CSRF on state-changing methods. + """ + # CSRF check before authentication so forged cross-site requests + # are rejected early, even if the cookie carries a valid JWT. + _check_csrf(request) + + token = request.cookies.get("access_token") + if not token: + raise Auth.exceptions.HTTPException( + status_code=401, + detail="Not authenticated", + ) + + payload = decode_token(token) + if isinstance(payload, TokenError): + raise Auth.exceptions.HTTPException( + status_code=401, + detail=f"Token error: {payload.value}", + ) + + user = await get_local_provider().get_user(payload.sub) + if user is None: + raise Auth.exceptions.HTTPException( + status_code=401, + detail="User not found", + ) + if user.token_version != payload.ver: + raise Auth.exceptions.HTTPException( + status_code=401, + detail="Token revoked (password changed)", + ) + + return payload.sub + + +@auth.on +async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict): + """Inject owner_id metadata on writes; filter by owner_id on reads. + + Gateway stores thread ownership as ``metadata.owner_id``. + This handler ensures LangGraph Server enforces the same isolation. + """ + # On create/update: stamp owner_id into metadata + metadata = value.setdefault("metadata", {}) + metadata["owner_id"] = ctx.user.identity + + # Return filter dict — LangGraph applies it to search/read/delete + return {"owner_id": ctx.user.identity} diff --git a/backend/app/gateway/routers/auth.py b/backend/app/gateway/routers/auth.py new file mode 100644 index 000000000..843dd7185 --- /dev/null +++ b/backend/app/gateway/routers/auth.py @@ -0,0 +1,303 @@ +"""Authentication endpoints.""" + +import logging +import time + +from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi.security import OAuth2PasswordRequestForm +from pydantic import BaseModel, EmailStr, Field + +from app.gateway.auth import ( + UserResponse, + create_access_token, +) +from app.gateway.auth.config import get_auth_config +from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse +from app.gateway.csrf_middleware import is_secure_request +from app.gateway.deps import get_current_user_from_request, get_local_provider + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) + + +# ── Request/Response Models ────────────────────────────────────────────── + + +class LoginResponse(BaseModel): + """Response model for login — token only lives in HttpOnly cookie.""" + + expires_in: int # seconds + needs_setup: bool = False + + +class RegisterRequest(BaseModel): + """Request model for user registration.""" + + email: EmailStr + password: str = Field(..., min_length=8) + + +class ChangePasswordRequest(BaseModel): + """Request model for password change (also handles setup flow).""" + + current_password: str + new_password: str = Field(..., min_length=8) + new_email: EmailStr | None = None + + +class MessageResponse(BaseModel): + """Generic message response.""" + + message: str + + +# ── Helpers ─────────────────────────────────────────────────────────────── + + +def _set_session_cookie(response: Response, token: str, request: Request) -> None: + """Set the access_token HttpOnly cookie on the response.""" + config = get_auth_config() + is_https = is_secure_request(request) + response.set_cookie( + key="access_token", + value=token, + httponly=True, + secure=is_https, + samesite="lax", + max_age=config.token_expiry_days * 24 * 3600 if is_https else None, + ) + + +# ── Rate Limiting ──────────────────────────────────────────────────────── +# In-process dict — not shared across workers. Sufficient for single-worker deployments. + +_MAX_LOGIN_ATTEMPTS = 5 +_LOCKOUT_SECONDS = 300 # 5 minutes + +# ip → (fail_count, lock_until_timestamp) +_login_attempts: dict[str, tuple[int, float]] = {} + + +def _get_client_ip(request: Request) -> str: + """Extract the real client IP for rate limiting. + + Uses ``X-Real-IP`` header set by nginx (``proxy_set_header X-Real-IP + $remote_addr``). Nginx unconditionally overwrites any client-supplied + ``X-Real-IP``, so the value seen by Gateway is always the TCP peer IP + that nginx observed — it cannot be spoofed by the client. + + ``request.client.host`` is NOT reliable because uvicorn's default + ``proxy_headers=True`` replaces it with the *first* entry from + ``X-Forwarded-For``, which IS client-spoofable. + + ``X-Forwarded-For`` is intentionally NOT used for the same reason. + """ + real_ip = request.headers.get("x-real-ip", "").strip() + if real_ip: + return real_ip + + # Fallback: direct connection without nginx (e.g. unit tests, dev). + return request.client.host if request.client else "unknown" + + +def _check_rate_limit(ip: str) -> None: + """Raise 429 if the IP is currently locked out.""" + record = _login_attempts.get(ip) + if record is None: + return + fail_count, lock_until = record + if fail_count >= _MAX_LOGIN_ATTEMPTS: + if time.time() < lock_until: + raise HTTPException( + status_code=429, + detail="Too many login attempts. Try again later.", + ) + del _login_attempts[ip] + + +_MAX_TRACKED_IPS = 10000 + + +def _record_login_failure(ip: str) -> None: + """Record a failed login attempt for the given IP.""" + # Evict expired lockouts when dict grows too large + if len(_login_attempts) >= _MAX_TRACKED_IPS: + now = time.time() + expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t] + for k in expired: + del _login_attempts[k] + # If still too large, evict cheapest-to-lose half: below-threshold + # IPs (lock_until=0.0) sort first, then earliest-expiring lockouts. + if len(_login_attempts) >= _MAX_TRACKED_IPS: + by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1]) + for k, _ in by_time[: len(by_time) // 2]: + del _login_attempts[k] + + record = _login_attempts.get(ip) + if record is None: + _login_attempts[ip] = (1, 0.0) + else: + new_count = record[0] + 1 + lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0 + _login_attempts[ip] = (new_count, lock_until) + + +def _record_login_success(ip: str) -> None: + """Clear failure counter for the given IP on successful login.""" + _login_attempts.pop(ip, None) + + +# ── Endpoints ───────────────────────────────────────────────────────────── + + +@router.post("/login/local", response_model=LoginResponse) +async def login_local( + request: Request, + response: Response, + form_data: OAuth2PasswordRequestForm = Depends(), +): + """Local email/password login.""" + client_ip = _get_client_ip(request) + _check_rate_limit(client_ip) + + user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password}) + + if user is None: + _record_login_failure(client_ip) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(), + ) + + _record_login_success(client_ip) + token = create_access_token(str(user.id), token_version=user.token_version) + _set_session_cookie(response, token, request) + + return LoginResponse( + expires_in=get_auth_config().token_expiry_days * 24 * 3600, + needs_setup=user.needs_setup, + ) + + +@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) +async def register(request: Request, response: Response, body: RegisterRequest): + """Register a new user account (always 'user' role). + + Admin is auto-created on first boot. This endpoint creates regular users. + Auto-login by setting the session cookie. + """ + try: + user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user") + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(), + ) + + token = create_access_token(str(user.id), token_version=user.token_version) + _set_session_cookie(response, token, request) + + return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role) + + +@router.post("/logout", response_model=MessageResponse) +async def logout(request: Request, response: Response): + """Logout current user by clearing the cookie.""" + response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax") + return MessageResponse(message="Successfully logged out") + + +@router.post("/change-password", response_model=MessageResponse) +async def change_password(request: Request, response: Response, body: ChangePasswordRequest): + """Change password for the currently authenticated user. + + Also handles the first-boot setup flow: + - If new_email is provided, updates email (checks uniqueness) + - If user.needs_setup is True and new_email is given, clears needs_setup + - Always increments token_version to invalidate old sessions + - Re-issues session cookie with new token_version + """ + from app.gateway.auth.password import hash_password_async, verify_password_async + + user = await get_current_user_from_request(request) + + if user.password_hash is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump()) + + if not await verify_password_async(body.current_password, user.password_hash): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump()) + + provider = get_local_provider() + + # Update email if provided + if body.new_email is not None: + existing = await provider.get_user_by_email(body.new_email) + if existing and str(existing.id) != str(user.id): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump()) + user.email = body.new_email + + # Update password + bump version + user.password_hash = await hash_password_async(body.new_password) + user.token_version += 1 + + # Clear setup flag if this is the setup flow + if user.needs_setup and body.new_email is not None: + user.needs_setup = False + + await provider.update_user(user) + + # Re-issue cookie with new token_version + token = create_access_token(str(user.id), token_version=user.token_version) + _set_session_cookie(response, token, request) + + return MessageResponse(message="Password changed successfully") + + +@router.get("/me", response_model=UserResponse) +async def get_me(request: Request): + """Get current authenticated user info.""" + user = await get_current_user_from_request(request) + return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup) + + +@router.get("/setup-status") +async def setup_status(): + """Check if admin account exists. Always False after first boot.""" + user_count = await get_local_provider().count_users() + return {"needs_setup": user_count == 0} + + +# ── OAuth Endpoints (Future/Placeholder) ───────────────────────────────── + + +@router.get("/oauth/{provider}") +async def oauth_login(provider: str): + """Initiate OAuth login flow. + + Redirects to the OAuth provider's authorization URL. + Currently a placeholder - requires OAuth provider implementation. + """ + if provider not in ["github", "google"]: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported OAuth provider: {provider}", + ) + + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="OAuth login not yet implemented", + ) + + +@router.get("/callback/{provider}") +async def oauth_callback(provider: str, code: str, state: str): + """OAuth callback endpoint. + + Handles the OAuth provider's callback after user authorization. + Currently a placeholder. + """ + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail="OAuth callback not yet implemented", + ) diff --git a/backend/langgraph.json b/backend/langgraph.json index 74f5c691d..28588c9f8 100644 --- a/backend/langgraph.json +++ b/backend/langgraph.json @@ -8,6 +8,9 @@ "graphs": { "lead_agent": "deerflow.agents:make_lead_agent" }, + "auth": { + "path": "./app/gateway/langgraph_auth.py:auth" + }, "checkpointer": { "path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer" } diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py new file mode 100644 index 000000000..d73a6925f --- /dev/null +++ b/backend/tests/test_auth.py @@ -0,0 +1,506 @@ +"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators.""" + +from datetime import timedelta +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from fastapi import FastAPI, HTTPException +from fastapi.testclient import TestClient + +from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password +from app.gateway.auth.models import User +from app.gateway.authz import ( + AuthContext, + Permissions, + get_auth_context, + require_auth, + require_permission, +) + +# ── Password Hashing ──────────────────────────────────────────────────────── + + +def test_hash_password_and_verify(): + """Hashing and verification round-trip.""" + password = "s3cr3tP@ssw0rd!" + hashed = hash_password(password) + assert hashed != password + assert verify_password(password, hashed) is True + assert verify_password("wrongpassword", hashed) is False + + +def test_hash_password_different_each_time(): + """bcrypt generates unique salts, so same password has different hashes.""" + password = "testpassword" + h1 = hash_password(password) + h2 = hash_password(password) + assert h1 != h2 # Different salts + # But both verify correctly + assert verify_password(password, h1) is True + assert verify_password(password, h2) is True + + +def test_verify_password_rejects_empty(): + """Empty password should not verify.""" + hashed = hash_password("nonempty") + assert verify_password("", hashed) is False + + +# ── JWT ───────────────────────────────────────────────────────────────────── + + +def test_create_and_decode_token(): + """JWT creation and decoding round-trip.""" + user_id = str(uuid4()) + # Set a valid JWT secret for this test + import os + + os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars" + token = create_access_token(user_id) + assert isinstance(token, str) + + payload = decode_token(token) + assert payload is not None + assert payload.sub == user_id + + +def test_decode_token_expired(): + """Expired token returns TokenError.EXPIRED.""" + from app.gateway.auth.errors import TokenError + + user_id = str(uuid4()) + # Create token that expires immediately + token = create_access_token(user_id, expires_delta=timedelta(seconds=-1)) + payload = decode_token(token) + assert payload == TokenError.EXPIRED + + +def test_decode_token_invalid(): + """Invalid token returns TokenError.""" + from app.gateway.auth.errors import TokenError + + assert isinstance(decode_token("not.a.valid.token"), TokenError) + assert isinstance(decode_token(""), TokenError) + assert isinstance(decode_token("completely-wrong"), TokenError) + + +def test_create_token_custom_expiry(): + """Custom expiry is respected.""" + user_id = str(uuid4()) + token = create_access_token(user_id, expires_delta=timedelta(hours=1)) + payload = decode_token(token) + assert payload is not None + assert payload.sub == user_id + + +# ── AuthContext ──────────────────────────────────────────────────────────── + + +def test_auth_context_unauthenticated(): + """AuthContext with no user.""" + ctx = AuthContext(user=None, permissions=[]) + assert ctx.is_authenticated is False + assert ctx.has_permission("threads", "read") is False + + +def test_auth_context_authenticated_no_perms(): + """AuthContext with user but no permissions.""" + user = User(id=uuid4(), email="test@example.com", password_hash="hash") + ctx = AuthContext(user=user, permissions=[]) + assert ctx.is_authenticated is True + assert ctx.has_permission("threads", "read") is False + + +def test_auth_context_has_permission(): + """AuthContext permission checking.""" + user = User(id=uuid4(), email="test@example.com", password_hash="hash") + perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE] + ctx = AuthContext(user=user, permissions=perms) + assert ctx.has_permission("threads", "read") is True + assert ctx.has_permission("threads", "write") is True + assert ctx.has_permission("threads", "delete") is False + assert ctx.has_permission("runs", "read") is False + + +def test_auth_context_require_user_raises(): + """require_user raises 401 when not authenticated.""" + ctx = AuthContext(user=None, permissions=[]) + with pytest.raises(HTTPException) as exc_info: + ctx.require_user() + assert exc_info.value.status_code == 401 + + +def test_auth_context_require_user_returns_user(): + """require_user returns user when authenticated.""" + user = User(id=uuid4(), email="test@example.com", password_hash="hash") + ctx = AuthContext(user=user, permissions=[]) + returned = ctx.require_user() + assert returned == user + + +# ── get_auth_context helper ───────────────────────────────────────────────── + + +def test_get_auth_context_not_set(): + """get_auth_context returns None when auth not set on request.""" + mock_request = MagicMock() + # Make getattr return None (simulating attribute not set) + mock_request.state = MagicMock() + del mock_request.state.auth + assert get_auth_context(mock_request) is None + + +def test_get_auth_context_set(): + """get_auth_context returns the AuthContext from request.""" + user = User(id=uuid4(), email="test@example.com", password_hash="hash") + ctx = AuthContext(user=user, permissions=[Permissions.THREADS_READ]) + + mock_request = MagicMock() + mock_request.state.auth = ctx + + assert get_auth_context(mock_request) == ctx + + +# ── require_auth decorator ────────────────────────────────────────────────── + + +def test_require_auth_sets_auth_context(): + """require_auth sets auth context on request from cookie.""" + from fastapi import Request + + app = FastAPI() + + @app.get("/test") + @require_auth + async def endpoint(request: Request): + ctx = get_auth_context(request) + return {"authenticated": ctx.is_authenticated} + + with TestClient(app) as client: + # No cookie → anonymous + response = client.get("/test") + assert response.status_code == 200 + assert response.json()["authenticated"] is False + + +def test_require_auth_requires_request_param(): + """require_auth raises ValueError if request parameter is missing.""" + import asyncio + + @require_auth + async def bad_endpoint(): # Missing `request` parameter + pass + + with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"): + asyncio.run(bad_endpoint()) + + +# ── require_permission decorator ───────────────────────────────────────────── + + +def test_require_permission_requires_auth(): + """require_permission raises 401 when not authenticated.""" + from fastapi import Request + + app = FastAPI() + + @app.get("/test") + @require_permission("threads", "read") + async def endpoint(request: Request): + return {"ok": True} + + with TestClient(app) as client: + response = client.get("/test") + assert response.status_code == 401 + assert "Authentication required" in response.json()["detail"] + + +def test_require_permission_denies_wrong_permission(): + """User without required permission gets 403.""" + from fastapi import Request + + app = FastAPI() + user = User(id=uuid4(), email="test@example.com", password_hash="hash") + + @app.get("/test") + @require_permission("threads", "delete") + async def endpoint(request: Request): + return {"ok": True} + + mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ]) + + with patch("app.gateway.authz._authenticate", return_value=mock_auth): + with TestClient(app) as client: + response = client.get("/test") + assert response.status_code == 403 + assert "Permission denied" in response.json()["detail"] + + +# ── Weak JWT secret warning ────────────────────────────────────────────────── + + +# ── User Model Fields ────────────────────────────────────────────────────── + + +def test_user_model_has_needs_setup_default_false(): + """New users default to needs_setup=False.""" + user = User(email="test@example.com", password_hash="hash") + assert user.needs_setup is False + + +def test_user_model_has_token_version_default_zero(): + """New users default to token_version=0.""" + user = User(email="test@example.com", password_hash="hash") + assert user.token_version == 0 + + +def test_user_model_needs_setup_true(): + """Auto-created admin has needs_setup=True.""" + user = User(email="admin@example.com", password_hash="hash", needs_setup=True) + assert user.needs_setup is True + + +def test_sqlite_round_trip_new_fields(): + """needs_setup and token_version survive create → read round-trip.""" + import asyncio + import os + import tempfile + from pathlib import Path + + from app.gateway.auth.repositories import sqlite as sqlite_mod + + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test_users.db") + old_path = sqlite_mod._resolved_db_path + old_init = sqlite_mod._table_initialized + sqlite_mod._resolved_db_path = Path(db_path) + sqlite_mod._table_initialized = False + try: + repo = sqlite_mod.SQLiteUserRepository() + user = User( + email="setup@test.com", + password_hash="fakehash", + system_role="admin", + needs_setup=True, + token_version=3, + ) + created = asyncio.run(repo.create_user(user)) + assert created.needs_setup is True + assert created.token_version == 3 + + fetched = asyncio.run(repo.get_user_by_email("setup@test.com")) + assert fetched is not None + assert fetched.needs_setup is True + assert fetched.token_version == 3 + + fetched.needs_setup = False + fetched.token_version = 4 + asyncio.run(repo.update_user(fetched)) + refetched = asyncio.run(repo.get_user_by_id(str(fetched.id))) + assert refetched.needs_setup is False + assert refetched.token_version == 4 + finally: + sqlite_mod._resolved_db_path = old_path + sqlite_mod._table_initialized = old_init + + +# ── Token Versioning ─────────────────────────────────────────────────────── + + +def test_jwt_encodes_ver(): + """JWT payload includes ver field.""" + import os + + from app.gateway.auth.errors import TokenError + + os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars" + token = create_access_token(str(uuid4()), token_version=3) + payload = decode_token(token) + assert not isinstance(payload, TokenError) + assert payload.ver == 3 + + +def test_jwt_default_ver_zero(): + """JWT ver defaults to 0.""" + import os + + from app.gateway.auth.errors import TokenError + + os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars" + token = create_access_token(str(uuid4())) + payload = decode_token(token) + assert not isinstance(payload, TokenError) + assert payload.ver == 0 + + +def test_token_version_mismatch_rejects(): + """Token with stale ver is rejected by get_current_user_from_request.""" + import asyncio + import os + + os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars" + + user_id = str(uuid4()) + token = create_access_token(user_id, token_version=0) + + mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1) + + mock_request = MagicMock() + mock_request.cookies = {"access_token": token} + + with patch("app.gateway.deps.get_local_provider") as mock_provider_fn: + mock_provider = MagicMock() + mock_provider.get_user = AsyncMock(return_value=mock_user) + mock_provider_fn.return_value = mock_provider + + from app.gateway.deps import get_current_user_from_request + + with pytest.raises(HTTPException) as exc_info: + asyncio.run(get_current_user_from_request(mock_request)) + assert exc_info.value.status_code == 401 + assert "revoked" in str(exc_info.value.detail).lower() + + +# ── change-password extension ────────────────────────────────────────────── + + +def test_change_password_request_accepts_new_email(): + """ChangePasswordRequest model accepts optional new_email.""" + from app.gateway.routers.auth import ChangePasswordRequest + + req = ChangePasswordRequest( + current_password="old", + new_password="newpassword", + new_email="new@example.com", + ) + assert req.new_email == "new@example.com" + + +def test_change_password_request_new_email_optional(): + """ChangePasswordRequest model works without new_email.""" + from app.gateway.routers.auth import ChangePasswordRequest + + req = ChangePasswordRequest(current_password="old", new_password="newpassword") + assert req.new_email is None + + +def test_login_response_includes_needs_setup(): + """LoginResponse includes needs_setup field.""" + from app.gateway.routers.auth import LoginResponse + + resp = LoginResponse(expires_in=3600, needs_setup=True) + assert resp.needs_setup is True + resp2 = LoginResponse(expires_in=3600) + assert resp2.needs_setup is False + + +# ── Rate Limiting ────────────────────────────────────────────────────────── + + +def test_rate_limiter_allows_under_limit(): + """Requests under the limit are allowed.""" + from app.gateway.routers.auth import _check_rate_limit, _login_attempts + + _login_attempts.clear() + _check_rate_limit("192.168.1.1") # Should not raise + + +def test_rate_limiter_blocks_after_max_failures(): + """IP is blocked after 5 consecutive failures.""" + from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure + + _login_attempts.clear() + ip = "10.0.0.1" + for _ in range(5): + _record_login_failure(ip) + with pytest.raises(HTTPException) as exc_info: + _check_rate_limit(ip) + assert exc_info.value.status_code == 429 + + +def test_rate_limiter_resets_on_success(): + """Successful login clears the failure counter.""" + from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success + + _login_attempts.clear() + ip = "10.0.0.2" + for _ in range(4): + _record_login_failure(ip) + _record_login_success(ip) + _check_rate_limit(ip) # Should not raise + + +# ── Client IP extraction ───────────────────────────────────────────────── + + +def test_get_client_ip_direct_connection(): + """Without nginx (no X-Real-IP), falls back to request.client.host.""" + from app.gateway.routers.auth import _get_client_ip + + req = MagicMock() + req.client.host = "203.0.113.42" + req.headers = {} + assert _get_client_ip(req) == "203.0.113.42" + + +def test_get_client_ip_uses_x_real_ip(): + """X-Real-IP (set by nginx) is used when present.""" + from app.gateway.routers.auth import _get_client_ip + + req = MagicMock() + req.client.host = "10.0.0.1" # uvicorn may have replaced this with XFF[0] + req.headers = {"x-real-ip": "203.0.113.42"} + assert _get_client_ip(req) == "203.0.113.42" + + +def test_get_client_ip_xff_ignored(): + """X-Forwarded-For is never used; only X-Real-IP matters.""" + from app.gateway.routers.auth import _get_client_ip + + req = MagicMock() + req.client.host = "10.0.0.1" + req.headers = {"x-forwarded-for": "10.0.0.1, 198.51.100.5", "x-real-ip": "198.51.100.5"} + assert _get_client_ip(req) == "198.51.100.5" + + +def test_get_client_ip_no_real_ip_fallback(): + """No X-Real-IP → falls back to client.host (direct connection).""" + from app.gateway.routers.auth import _get_client_ip + + req = MagicMock() + req.client.host = "127.0.0.1" + req.headers = {} + assert _get_client_ip(req) == "127.0.0.1" + + +def test_get_client_ip_x_real_ip_always_preferred(): + """X-Real-IP is always preferred over client.host regardless of IP.""" + from app.gateway.routers.auth import _get_client_ip + + req = MagicMock() + req.client.host = "203.0.113.99" + req.headers = {"x-real-ip": "198.51.100.7"} + assert _get_client_ip(req) == "198.51.100.7" + + +# ── Weak JWT secret warning ────────────────────────────────────────────────── + + +def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog): + """get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset.""" + import logging + + import app.gateway.auth.config as config_module + + config_module._auth_config = None + monkeypatch.delenv("AUTH_JWT_SECRET", raising=False) + + with caplog.at_level(logging.WARNING): + config = config_module.get_auth_config() + + assert config.jwt_secret # non-empty ephemeral secret + assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages) + + # Cleanup + config_module._auth_config = None diff --git a/backend/tests/test_auth_middleware.py b/backend/tests/test_auth_middleware.py new file mode 100644 index 000000000..64f8604f0 --- /dev/null +++ b/backend/tests/test_auth_middleware.py @@ -0,0 +1,216 @@ +"""Tests for the global AuthMiddleware (fail-closed safety net).""" + +import pytest +from starlette.testclient import TestClient + +from app.gateway.auth_middleware import AuthMiddleware, _is_public + +# ── _is_public unit tests ───────────────────────────────────────────────── + + +@pytest.mark.parametrize( + "path", + [ + "/health", + "/health/", + "/docs", + "/docs/", + "/redoc", + "/openapi.json", + "/api/v1/auth/login/local", + "/api/v1/auth/register", + "/api/v1/auth/logout", + "/api/v1/auth/setup-status", + ], +) +def test_public_paths(path: str): + assert _is_public(path) is True + + +@pytest.mark.parametrize( + "path", + [ + "/api/models", + "/api/mcp/config", + "/api/memory", + "/api/skills", + "/api/threads/123", + "/api/threads/123/uploads", + "/api/agents", + "/api/channels", + "/api/runs/stream", + "/api/threads/123/runs", + "/api/v1/auth/me", + "/api/v1/auth/change-password", + ], +) +def test_protected_paths(path: str): + assert _is_public(path) is False + + +# ── Trailing slash / normalization edge cases ───────────────────────────── + + +@pytest.mark.parametrize( + "path", + [ + "/api/v1/auth/login/local/", + "/api/v1/auth/register/", + "/api/v1/auth/logout/", + "/api/v1/auth/setup-status/", + ], +) +def test_public_auth_paths_with_trailing_slash(path: str): + assert _is_public(path) is True + + +@pytest.mark.parametrize( + "path", + [ + "/api/models/", + "/api/v1/auth/me/", + "/api/v1/auth/change-password/", + ], +) +def test_protected_paths_with_trailing_slash(path: str): + assert _is_public(path) is False + + +def test_unknown_api_path_is_protected(): + """Fail-closed: any new /api/* path is protected by default.""" + assert _is_public("/api/new-feature") is False + assert _is_public("/api/v2/something") is False + assert _is_public("/api/v1/auth/new-endpoint") is False + + +# ── Middleware integration tests ────────────────────────────────────────── + + +def _make_app(): + """Create a minimal FastAPI app with AuthMiddleware for testing.""" + from fastapi import FastAPI + + app = FastAPI() + app.add_middleware(AuthMiddleware) + + @app.get("/health") + async def health(): + return {"status": "ok"} + + @app.get("/api/v1/auth/me") + async def auth_me(): + return {"id": "1", "email": "test@test.com"} + + @app.get("/api/v1/auth/setup-status") + async def setup_status(): + return {"needs_setup": False} + + @app.get("/api/models") + async def models_get(): + return {"models": []} + + @app.put("/api/mcp/config") + async def mcp_put(): + return {"ok": True} + + @app.delete("/api/threads/abc") + async def thread_delete(): + return {"ok": True} + + @app.patch("/api/threads/abc") + async def thread_patch(): + return {"ok": True} + + @app.post("/api/threads/abc/runs/stream") + async def stream(): + return {"ok": True} + + @app.get("/api/future-endpoint") + async def future(): + return {"ok": True} + + return app + + +@pytest.fixture +def client(): + return TestClient(_make_app()) + + +def test_public_path_no_cookie(client): + res = client.get("/health") + assert res.status_code == 200 + + +def test_public_auth_path_no_cookie(client): + """Public auth endpoints (login/register) pass without cookie.""" + res = client.get("/api/v1/auth/setup-status") + assert res.status_code == 200 + + +def test_protected_auth_path_no_cookie(client): + """/auth/me requires cookie even though it's under /api/v1/auth/.""" + res = client.get("/api/v1/auth/me") + assert res.status_code == 401 + + +def test_protected_path_no_cookie_returns_401(client): + res = client.get("/api/models") + assert res.status_code == 401 + body = res.json() + assert body["detail"]["code"] == "not_authenticated" + + +def test_protected_path_with_cookie_passes(client): + res = client.get("/api/models", cookies={"access_token": "some-token"}) + assert res.status_code == 200 + + +def test_protected_post_no_cookie_returns_401(client): + res = client.post("/api/threads/abc/runs/stream") + assert res.status_code == 401 + + +# ── Method matrix: PUT/DELETE/PATCH also protected ──────────────────────── + + +def test_protected_put_no_cookie(client): + res = client.put("/api/mcp/config") + assert res.status_code == 401 + + +def test_protected_delete_no_cookie(client): + res = client.delete("/api/threads/abc") + assert res.status_code == 401 + + +def test_protected_patch_no_cookie(client): + res = client.patch("/api/threads/abc") + assert res.status_code == 401 + + +def test_put_with_cookie_passes(client): + client.cookies.set("access_token", "tok") + res = client.put("/api/mcp/config") + assert res.status_code == 200 + + +def test_delete_with_cookie_passes(client): + client.cookies.set("access_token", "tok") + res = client.delete("/api/threads/abc") + assert res.status_code == 200 + + +# ── Fail-closed: unknown future endpoints ───────────────────────────────── + + +def test_unknown_endpoint_no_cookie_returns_401(client): + """Any new /api/* endpoint is blocked by default without cookie.""" + res = client.get("/api/future-endpoint") + assert res.status_code == 401 + + +def test_unknown_endpoint_with_cookie_passes(client): + client.cookies.set("access_token", "tok") + res = client.get("/api/future-endpoint") + assert res.status_code == 200 diff --git a/backend/tests/test_auth_type_system.py b/backend/tests/test_auth_type_system.py new file mode 100644 index 000000000..18b4542d0 --- /dev/null +++ b/backend/tests/test_auth_type_system.py @@ -0,0 +1,675 @@ +"""Tests for auth type system hardening. + +Covers structured error responses, typed decode_token callers, +CSRF middleware path matching, config-driven cookie security, +and unhappy paths / edge cases for all auth boundaries. +""" + +import os +import secrets +from datetime import UTC, datetime, timedelta +from unittest.mock import patch + +import jwt as pyjwt +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from pydantic import ValidationError + +from app.gateway.auth.config import AuthConfig, set_auth_config +from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError +from app.gateway.auth.jwt import decode_token +from app.gateway.csrf_middleware import ( + CSRF_COOKIE_NAME, + CSRF_HEADER_NAME, + CSRFMiddleware, + is_auth_endpoint, + should_check_csrf, +) + +# ── Setup ──────────────────────────────────────────────────────────── + +_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32" + + +def _setup_config(): + set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET)) + + +# ── CSRF Middleware Path Matching ──────────────────────────────────── + + +class _FakeRequest: + """Minimal request mock for CSRF path matching tests.""" + + def __init__(self, path: str, method: str = "POST"): + self.method = method + + class _URL: + def __init__(self, p): + self.path = p + + self.url = _URL(path) + self.cookies = {} + self.headers = {} + + +def test_csrf_exempts_login_local(): + """login/local (actual route) should be exempt from CSRF.""" + req = _FakeRequest("/api/v1/auth/login/local") + assert is_auth_endpoint(req) is True + + +def test_csrf_exempts_login_local_trailing_slash(): + """Trailing slash should also be exempt.""" + req = _FakeRequest("/api/v1/auth/login/local/") + assert is_auth_endpoint(req) is True + + +def test_csrf_exempts_logout(): + req = _FakeRequest("/api/v1/auth/logout") + assert is_auth_endpoint(req) is True + + +def test_csrf_exempts_register(): + req = _FakeRequest("/api/v1/auth/register") + assert is_auth_endpoint(req) is True + + +def test_csrf_does_not_exempt_old_login_path(): + """Old /api/v1/auth/login (without /local) should NOT be exempt.""" + req = _FakeRequest("/api/v1/auth/login") + assert is_auth_endpoint(req) is False + + +def test_csrf_does_not_exempt_me(): + req = _FakeRequest("/api/v1/auth/me") + assert is_auth_endpoint(req) is False + + +def test_csrf_skips_get_requests(): + req = _FakeRequest("/api/v1/auth/me", method="GET") + assert should_check_csrf(req) is False + + +def test_csrf_checks_post_to_protected(): + req = _FakeRequest("/api/v1/some/endpoint", method="POST") + assert should_check_csrf(req) is True + + +# ── Structured Error Response Format ──────────────────────────────── + + +def test_auth_error_response_has_code_and_message(): + """All auth errors should have structured {code, message} format.""" + err = AuthErrorResponse( + code=AuthErrorCode.INVALID_CREDENTIALS, + message="Wrong password", + ) + d = err.model_dump() + assert "code" in d + assert "message" in d + assert d["code"] == "invalid_credentials" + + +def test_auth_error_response_all_codes_serializable(): + """Every AuthErrorCode should be serializable in AuthErrorResponse.""" + for code in AuthErrorCode: + err = AuthErrorResponse(code=code, message=f"Test {code.value}") + d = err.model_dump() + assert d["code"] == code.value + + +# ── decode_token Caller Pattern ────────────────────────────────────── + + +def test_decode_token_expired_maps_to_token_expired_code(): + """TokenError.EXPIRED should map to AuthErrorCode.TOKEN_EXPIRED.""" + _setup_config() + from datetime import UTC, datetime, timedelta + + import jwt as pyjwt + + expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)} + token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256") + result = decode_token(token) + assert result == TokenError.EXPIRED + + # Verify the mapping pattern used in route handlers + code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID + assert code == AuthErrorCode.TOKEN_EXPIRED + + +def test_decode_token_invalid_sig_maps_to_token_invalid_code(): + """TokenError.INVALID_SIGNATURE should map to AuthErrorCode.TOKEN_INVALID.""" + _setup_config() + from datetime import UTC, datetime, timedelta + + import jwt as pyjwt + + payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)} + token = pyjwt.encode(payload, "wrong-key", algorithm="HS256") + result = decode_token(token) + assert result == TokenError.INVALID_SIGNATURE + + code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID + assert code == AuthErrorCode.TOKEN_INVALID + + +def test_decode_token_malformed_maps_to_token_invalid_code(): + """TokenError.MALFORMED should map to AuthErrorCode.TOKEN_INVALID.""" + _setup_config() + result = decode_token("garbage") + assert result == TokenError.MALFORMED + + code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID + assert code == AuthErrorCode.TOKEN_INVALID + + +# ── Login Response Format ──────────────────────────────────────────── + + +def test_login_response_model_has_no_access_token(): + """LoginResponse should NOT contain access_token field (RFC-001).""" + from app.gateway.routers.auth import LoginResponse + + resp = LoginResponse(expires_in=604800) + d = resp.model_dump() + assert "access_token" not in d + assert "expires_in" in d + assert d["expires_in"] == 604800 + + +def test_login_response_model_fields(): + """LoginResponse has expires_in and needs_setup.""" + from app.gateway.routers.auth import LoginResponse + + fields = set(LoginResponse.model_fields.keys()) + assert fields == {"expires_in", "needs_setup"} + + +# ── AuthConfig in Route ────────────────────────────────────────────── + + +def test_auth_config_token_expiry_used_in_login_response(): + """LoginResponse.expires_in should come from config.token_expiry_days.""" + from app.gateway.routers.auth import LoginResponse + + expected_seconds = 14 * 24 * 3600 + resp = LoginResponse(expires_in=expected_seconds) + assert resp.expires_in == expected_seconds + + +# ── UserResponse Type Preservation ─────────────────────────────────── + + +def test_user_response_system_role_literal(): + """UserResponse.system_role should only accept 'admin' or 'user'.""" + from app.gateway.auth.models import UserResponse + + # Valid roles + resp = UserResponse(id="1", email="a@b.com", system_role="admin") + assert resp.system_role == "admin" + + resp = UserResponse(id="1", email="a@b.com", system_role="user") + assert resp.system_role == "user" + + +def test_user_response_rejects_invalid_role(): + """UserResponse should reject invalid system_role values.""" + from app.gateway.auth.models import UserResponse + + with pytest.raises(ValidationError): + UserResponse(id="1", email="a@b.com", system_role="superadmin") + + +# ══════════════════════════════════════════════════════════════════════ +# UNHAPPY PATHS / EDGE CASES +# ══════════════════════════════════════════════════════════════════════ + + +# ── get_current_user structured 401 responses ──────────────────────── + + +def test_get_current_user_no_cookie_returns_not_authenticated(): + """No cookie → 401 with code=not_authenticated.""" + import asyncio + + from fastapi import HTTPException + + from app.gateway.deps import get_current_user_from_request + + mock_request = type("MockRequest", (), {"cookies": {}})() + with pytest.raises(HTTPException) as exc_info: + asyncio.run(get_current_user_from_request(mock_request)) + assert exc_info.value.status_code == 401 + detail = exc_info.value.detail + assert detail["code"] == "not_authenticated" + + +def test_get_current_user_expired_token_returns_token_expired(): + """Expired token → 401 with code=token_expired.""" + import asyncio + + from fastapi import HTTPException + + from app.gateway.deps import get_current_user_from_request + + _setup_config() + expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)} + token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256") + + mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})() + with pytest.raises(HTTPException) as exc_info: + asyncio.run(get_current_user_from_request(mock_request)) + assert exc_info.value.status_code == 401 + detail = exc_info.value.detail + assert detail["code"] == "token_expired" + + +def test_get_current_user_invalid_token_returns_token_invalid(): + """Bad signature → 401 with code=token_invalid.""" + import asyncio + + from fastapi import HTTPException + + from app.gateway.deps import get_current_user_from_request + + _setup_config() + payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)} + token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256") + + mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})() + with pytest.raises(HTTPException) as exc_info: + asyncio.run(get_current_user_from_request(mock_request)) + assert exc_info.value.status_code == 401 + detail = exc_info.value.detail + assert detail["code"] == "token_invalid" + + +def test_get_current_user_malformed_token_returns_token_invalid(): + """Garbage token → 401 with code=token_invalid.""" + import asyncio + + from fastapi import HTTPException + + from app.gateway.deps import get_current_user_from_request + + _setup_config() + mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})() + with pytest.raises(HTTPException) as exc_info: + asyncio.run(get_current_user_from_request(mock_request)) + assert exc_info.value.status_code == 401 + detail = exc_info.value.detail + assert detail["code"] == "token_invalid" + + +# ── decode_token edge cases ────────────────────────────────────────── + + +def test_decode_token_empty_string_returns_malformed(): + _setup_config() + result = decode_token("") + assert result == TokenError.MALFORMED + + +def test_decode_token_whitespace_returns_malformed(): + _setup_config() + result = decode_token(" ") + assert result == TokenError.MALFORMED + + +# ── AuthConfig validation edge cases ───────────────────────────────── + + +def test_auth_config_missing_jwt_secret_raises(): + """AuthConfig requires jwt_secret — no default allowed.""" + with pytest.raises(ValidationError): + AuthConfig() + + +def test_auth_config_token_expiry_zero_raises(): + """token_expiry_days must be >= 1.""" + with pytest.raises(ValidationError): + AuthConfig(jwt_secret="secret", token_expiry_days=0) + + +def test_auth_config_token_expiry_31_raises(): + """token_expiry_days must be <= 30.""" + with pytest.raises(ValidationError): + AuthConfig(jwt_secret="secret", token_expiry_days=31) + + +def test_auth_config_token_expiry_boundary_1_ok(): + config = AuthConfig(jwt_secret="secret", token_expiry_days=1) + assert config.token_expiry_days == 1 + + +def test_auth_config_token_expiry_boundary_30_ok(): + config = AuthConfig(jwt_secret="secret", token_expiry_days=30) + assert config.token_expiry_days == 30 + + +def test_get_auth_config_missing_env_var_generates_ephemeral(caplog): + """get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset.""" + import logging + + import app.gateway.auth.config as cfg + + old = cfg._auth_config + cfg._auth_config = None + try: + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("AUTH_JWT_SECRET", None) + with caplog.at_level(logging.WARNING): + config = cfg.get_auth_config() + assert config.jwt_secret + assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages) + finally: + cfg._auth_config = old + + +# ── CSRF middleware integration (unhappy paths) ────────────────────── + + +def _make_csrf_app(): + """Create a minimal FastAPI app with CSRFMiddleware for testing.""" + from fastapi import HTTPException as _HTTPException + from fastapi.responses import JSONResponse as _JSONResponse + + app = FastAPI() + + @app.exception_handler(_HTTPException) + async def _http_exc_handler(request, exc): + return _JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) + + app.add_middleware(CSRFMiddleware) + + @app.post("/api/v1/test/protected") + async def protected(): + return {"ok": True} + + @app.post("/api/v1/auth/login/local") + async def login(): + return {"ok": True} + + @app.get("/api/v1/test/read") + async def read_endpoint(): + return {"ok": True} + + return app + + +def test_csrf_middleware_blocks_post_without_token(): + """POST to protected endpoint without CSRF token → 403 with structured detail.""" + client = TestClient(_make_csrf_app()) + resp = client.post("/api/v1/test/protected") + assert resp.status_code == 403 + assert "CSRF" in resp.json()["detail"] + assert "missing" in resp.json()["detail"].lower() + + +def test_csrf_middleware_blocks_post_with_mismatched_token(): + """POST with mismatched CSRF cookie/header → 403 with mismatch detail.""" + client = TestClient(_make_csrf_app()) + client.cookies.set(CSRF_COOKIE_NAME, "token-a") + resp = client.post( + "/api/v1/test/protected", + headers={CSRF_HEADER_NAME: "token-b"}, + ) + assert resp.status_code == 403 + assert "mismatch" in resp.json()["detail"].lower() + + +def test_csrf_middleware_allows_post_with_matching_token(): + """POST with matching CSRF cookie/header → 200.""" + client = TestClient(_make_csrf_app()) + token = secrets.token_urlsafe(64) + client.cookies.set(CSRF_COOKIE_NAME, token) + resp = client.post( + "/api/v1/test/protected", + headers={CSRF_HEADER_NAME: token}, + ) + assert resp.status_code == 200 + + +def test_csrf_middleware_allows_get_without_token(): + """GET requests bypass CSRF check.""" + client = TestClient(_make_csrf_app()) + resp = client.get("/api/v1/test/read") + assert resp.status_code == 200 + + +def test_csrf_middleware_exempts_login_local(): + """POST to login/local is exempt from CSRF (no token yet).""" + client = TestClient(_make_csrf_app()) + resp = client.post("/api/v1/auth/login/local") + assert resp.status_code == 200 + + +def test_csrf_middleware_sets_cookie_on_auth_endpoint(): + """Auth endpoints should receive a CSRF cookie in response.""" + client = TestClient(_make_csrf_app()) + resp = client.post("/api/v1/auth/login/local") + assert CSRF_COOKIE_NAME in resp.cookies + + +# ── UserResponse edge cases ────────────────────────────────────────── + + +def test_user_response_missing_required_fields(): + """UserResponse with missing fields → ValidationError.""" + from app.gateway.auth.models import UserResponse + + with pytest.raises(ValidationError): + UserResponse(id="1") # missing email, system_role + + with pytest.raises(ValidationError): + UserResponse(id="1", email="a@b.com") # missing system_role + + +def test_user_response_empty_string_role_rejected(): + """Empty string is not a valid role.""" + from app.gateway.auth.models import UserResponse + + with pytest.raises(ValidationError): + UserResponse(id="1", email="a@b.com", system_role="") + + +# ══════════════════════════════════════════════════════════════════════ +# HTTP-LEVEL API CONTRACT TESTS +# ══════════════════════════════════════════════════════════════════════ + + +def _make_auth_app(): + """Create FastAPI app with auth routes for contract testing.""" + from app.gateway.app import create_app + + return create_app() + + +def _get_auth_client(): + """Get TestClient for auth API contract tests.""" + return TestClient(_make_auth_app()) + + +def test_api_auth_me_no_cookie_returns_structured_401(): + """/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}.""" + _setup_config() + client = _get_auth_client() + resp = client.get("/api/v1/auth/me") + assert resp.status_code == 401 + body = resp.json() + assert body["detail"]["code"] == "not_authenticated" + assert "message" in body["detail"] + + +def test_api_auth_me_expired_token_returns_structured_401(): + """/api/v1/auth/me with expired token → 401 with {code: 'token_expired'}.""" + _setup_config() + expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)} + token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256") + + client = _get_auth_client() + client.cookies.set("access_token", token) + resp = client.get("/api/v1/auth/me") + assert resp.status_code == 401 + body = resp.json() + assert body["detail"]["code"] == "token_expired" + + +def test_api_auth_me_invalid_sig_returns_structured_401(): + """/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}.""" + _setup_config() + payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)} + token = pyjwt.encode(payload, "wrong-key", algorithm="HS256") + + client = _get_auth_client() + client.cookies.set("access_token", token) + resp = client.get("/api/v1/auth/me") + assert resp.status_code == 401 + body = resp.json() + assert body["detail"]["code"] == "token_invalid" + + +def test_api_login_bad_credentials_returns_structured_401(): + """Login with wrong password → 401 with {code: 'invalid_credentials'}.""" + _setup_config() + client = _get_auth_client() + resp = client.post( + "/api/v1/auth/login/local", + data={"username": "nonexistent@test.com", "password": "wrongpassword"}, + ) + assert resp.status_code == 401 + body = resp.json() + assert body["detail"]["code"] == "invalid_credentials" + + +def test_api_login_success_no_token_in_body(): + """Successful login → response body has expires_in but NOT access_token.""" + _setup_config() + client = _get_auth_client() + # Register first + client.post( + "/api/v1/auth/register", + json={"email": "contract-test@test.com", "password": "securepassword123"}, + ) + # Login + resp = client.post( + "/api/v1/auth/login/local", + data={"username": "contract-test@test.com", "password": "securepassword123"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "expires_in" in body + assert "access_token" not in body + # Token should be in cookie, not body + assert "access_token" in resp.cookies + + +def test_api_register_duplicate_returns_structured_400(): + """Register with duplicate email → 400 with {code: 'email_already_exists'}.""" + _setup_config() + client = _get_auth_client() + email = "dup-contract-test@test.com" + # First register + client.post("/api/v1/auth/register", json={"email": email, "password": "password123"}) + # Duplicate + resp = client.post("/api/v1/auth/register", json={"email": email, "password": "password456"}) + assert resp.status_code == 400 + body = resp.json() + assert body["detail"]["code"] == "email_already_exists" + + +# ── Cookie security: HTTP vs HTTPS ──────────────────────────────────── + + +def _unique_email(prefix: str) -> str: + return f"{prefix}-{secrets.token_hex(4)}@test.com" + + +def _get_set_cookie_headers(resp) -> list[str]: + """Extract all set-cookie header values from a TestClient response.""" + return [v for k, v in resp.headers.multi_items() if k.lower() == "set-cookie"] + + +def test_register_http_cookie_httponly_true_secure_false(): + """HTTP register → access_token cookie is httponly=True, secure=False, no max_age.""" + _setup_config() + client = _get_auth_client() + resp = client.post( + "/api/v1/auth/register", + json={"email": _unique_email("http-cookie"), "password": "password123"}, + ) + assert resp.status_code == 201 + cookie_header = resp.headers.get("set-cookie", "") + assert "access_token=" in cookie_header + assert "httponly" in cookie_header.lower() + assert "secure" not in cookie_header.lower().replace("samesite", "") + + +def test_register_https_cookie_httponly_true_secure_true(): + """HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age.""" + _setup_config() + client = _get_auth_client() + resp = client.post( + "/api/v1/auth/register", + json={"email": _unique_email("https-cookie"), "password": "password123"}, + headers={"x-forwarded-proto": "https"}, + ) + assert resp.status_code == 201 + cookie_header = resp.headers.get("set-cookie", "") + assert "access_token=" in cookie_header + assert "httponly" in cookie_header.lower() + assert "secure" in cookie_header.lower() + assert "max-age" in cookie_header.lower() + + +def test_login_https_sets_secure_cookie(): + """HTTPS login → access_token cookie has secure flag.""" + _setup_config() + client = _get_auth_client() + email = _unique_email("https-login") + client.post("/api/v1/auth/register", json={"email": email, "password": "password123"}) + resp = client.post( + "/api/v1/auth/login/local", + data={"username": email, "password": "password123"}, + headers={"x-forwarded-proto": "https"}, + ) + assert resp.status_code == 200 + cookie_header = resp.headers.get("set-cookie", "") + assert "access_token=" in cookie_header + assert "httponly" in cookie_header.lower() + assert "secure" in cookie_header.lower() + + +def test_csrf_cookie_secure_on_https(): + """HTTPS register → csrf_token cookie has secure flag but NOT httponly.""" + _setup_config() + client = _get_auth_client() + resp = client.post( + "/api/v1/auth/register", + json={"email": _unique_email("csrf-https"), "password": "password123"}, + headers={"x-forwarded-proto": "https"}, + ) + assert resp.status_code == 201 + csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h] + assert csrf_cookies, "csrf_token cookie not set on HTTPS register" + csrf_header = csrf_cookies[0] + assert "secure" in csrf_header.lower() + assert "httponly" not in csrf_header.lower() + + +def test_csrf_cookie_not_secure_on_http(): + """HTTP register → csrf_token cookie does NOT have secure flag.""" + _setup_config() + client = _get_auth_client() + resp = client.post( + "/api/v1/auth/register", + json={"email": _unique_email("csrf-http"), "password": "password123"}, + ) + assert resp.status_code == 201 + csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h] + assert csrf_cookies, "csrf_token cookie not set on HTTP register" + csrf_header = csrf_cookies[0] + assert "secure" not in csrf_header.lower().replace("samesite", "") diff --git a/backend/tests/test_ensure_admin.py b/backend/tests/test_ensure_admin.py new file mode 100644 index 000000000..cf6448bcd --- /dev/null +++ b/backend/tests/test_ensure_admin.py @@ -0,0 +1,214 @@ +"""Tests for _ensure_admin_user() in app.py. + +Covers: first-boot admin creation, auto-reset on needs_setup=True, +no-op on needs_setup=False, migration, and edge cases. +""" + +import asyncio +import os +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32") + +from app.gateway.auth.config import AuthConfig, set_auth_config +from app.gateway.auth.models import User + +_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32" + + +@pytest.fixture(autouse=True) +def _setup_auth_config(): + set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET)) + yield + set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET)) + + +def _make_app_stub(store=None): + """Minimal app-like object with state.store.""" + app = SimpleNamespace() + app.state = SimpleNamespace() + app.state.store = store + return app + + +def _make_provider(user_count=0, admin_user=None): + p = AsyncMock() + p.count_users = AsyncMock(return_value=user_count) + p.create_user = AsyncMock( + side_effect=lambda **kw: User( + email=kw["email"], + password_hash="hashed", + system_role=kw.get("system_role", "user"), + needs_setup=kw.get("needs_setup", False), + ) + ) + p.get_user_by_email = AsyncMock(return_value=admin_user) + p.update_user = AsyncMock(side_effect=lambda u: u) + return p + + +# ── First boot: no users ───────────────────────────────────────────────── + + +def test_first_boot_creates_admin(): + """count_users==0 → create admin with needs_setup=True.""" + provider = _make_provider(user_count=0) + app = _make_app_stub() + + with patch("app.gateway.deps.get_local_provider", return_value=provider): + with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"): + from app.gateway.app import _ensure_admin_user + + asyncio.run(_ensure_admin_user(app)) + + provider.create_user.assert_called_once() + call_kwargs = provider.create_user.call_args[1] + assert call_kwargs["email"] == "admin@deerflow.dev" + assert call_kwargs["system_role"] == "admin" + assert call_kwargs["needs_setup"] is True + assert len(call_kwargs["password"]) > 10 # random password generated + + +def test_first_boot_triggers_migration_if_store_present(): + """First boot with store → _migrate_orphaned_threads called.""" + provider = _make_provider(user_count=0) + store = AsyncMock() + store.asearch = AsyncMock(return_value=[]) + app = _make_app_stub(store=store) + + with patch("app.gateway.deps.get_local_provider", return_value=provider): + with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"): + from app.gateway.app import _ensure_admin_user + + asyncio.run(_ensure_admin_user(app)) + + store.asearch.assert_called_once() + + +def test_first_boot_no_store_skips_migration(): + """First boot without store → no crash, migration skipped.""" + provider = _make_provider(user_count=0) + app = _make_app_stub(store=None) + + with patch("app.gateway.deps.get_local_provider", return_value=provider): + with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"): + from app.gateway.app import _ensure_admin_user + + asyncio.run(_ensure_admin_user(app)) + + provider.create_user.assert_called_once() + + +# ── Subsequent boot: needs_setup=True → auto-reset ─────────────────────── + + +def test_needs_setup_true_resets_password(): + """Existing admin with needs_setup=True → password reset + token_version bumped.""" + admin = User( + email="admin@deerflow.dev", + password_hash="old-hash", + system_role="admin", + needs_setup=True, + token_version=0, + created_at=datetime.now(UTC) - timedelta(seconds=30), + ) + provider = _make_provider(user_count=1, admin_user=admin) + app = _make_app_stub() + + with patch("app.gateway.deps.get_local_provider", return_value=provider): + with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"): + from app.gateway.app import _ensure_admin_user + + asyncio.run(_ensure_admin_user(app)) + + # Password was reset + provider.update_user.assert_called_once() + updated = provider.update_user.call_args[0][0] + assert updated.password_hash == "new-hash" + assert updated.token_version == 1 + + +def test_needs_setup_true_consecutive_resets_increment_version(): + """Two boots with needs_setup=True → token_version increments each time.""" + admin = User( + email="admin@deerflow.dev", + password_hash="hash", + system_role="admin", + needs_setup=True, + token_version=3, + created_at=datetime.now(UTC) - timedelta(seconds=30), + ) + provider = _make_provider(user_count=1, admin_user=admin) + app = _make_app_stub() + + with patch("app.gateway.deps.get_local_provider", return_value=provider): + with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"): + from app.gateway.app import _ensure_admin_user + + asyncio.run(_ensure_admin_user(app)) + + updated = provider.update_user.call_args[0][0] + assert updated.token_version == 4 + + +# ── Subsequent boot: needs_setup=False → no-op ────────────────────────── + + +def test_needs_setup_false_no_reset(): + """Admin with needs_setup=False → no password reset, no update.""" + admin = User( + email="admin@deerflow.dev", + password_hash="stable-hash", + system_role="admin", + needs_setup=False, + token_version=2, + ) + provider = _make_provider(user_count=1, admin_user=admin) + app = _make_app_stub() + + with patch("app.gateway.deps.get_local_provider", return_value=provider): + from app.gateway.app import _ensure_admin_user + + asyncio.run(_ensure_admin_user(app)) + + provider.update_user.assert_not_called() + assert admin.password_hash == "stable-hash" + assert admin.token_version == 2 + + +# ── Edge cases ─────────────────────────────────────────────────────────── + + +def test_no_admin_email_found_no_crash(): + """Users exist but no admin@deerflow.dev → no crash, no reset.""" + provider = _make_provider(user_count=3, admin_user=None) + app = _make_app_stub() + + with patch("app.gateway.deps.get_local_provider", return_value=provider): + from app.gateway.app import _ensure_admin_user + + asyncio.run(_ensure_admin_user(app)) + + provider.update_user.assert_not_called() + provider.create_user.assert_not_called() + + +def test_migration_failure_is_non_fatal(): + """_migrate_orphaned_threads exception is caught and logged.""" + provider = _make_provider(user_count=0) + store = AsyncMock() + store.asearch = AsyncMock(side_effect=RuntimeError("store crashed")) + app = _make_app_stub(store=store) + + with patch("app.gateway.deps.get_local_provider", return_value=provider): + with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"): + from app.gateway.app import _ensure_admin_user + + # Should not raise + asyncio.run(_ensure_admin_user(app)) + + provider.create_user.assert_called_once() diff --git a/backend/tests/test_langgraph_auth.py b/backend/tests/test_langgraph_auth.py new file mode 100644 index 000000000..41fbd0340 --- /dev/null +++ b/backend/tests/test_langgraph_auth.py @@ -0,0 +1,312 @@ +"""Tests for LangGraph Server auth handler (langgraph_auth.py). + +Validates that the LangGraph auth layer enforces the same rules as Gateway: + cookie → JWT decode → DB lookup → token_version check → owner filter +""" + +import asyncio +import os +from datetime import timedelta +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch +from uuid import uuid4 + +import pytest + +os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32") + +from langgraph_sdk import Auth + +from app.gateway.auth.config import AuthConfig, set_auth_config +from app.gateway.auth.jwt import create_access_token, decode_token +from app.gateway.auth.models import User +from app.gateway.langgraph_auth import add_owner_filter, authenticate + +# ── Helpers ─────────────────────────────────────────────────────────────── + +_JWT_SECRET = "test-secret-key-for-langgraph-auth-testing-min-32" + + +@pytest.fixture(autouse=True) +def _setup_auth_config(): + set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET)) + yield + set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET)) + + +def _req(cookies=None, method="GET", headers=None): + return SimpleNamespace(cookies=cookies or {}, method=method, headers=headers or {}) + + +def _user(user_id=None, token_version=0): + return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version) + + +def _mock_provider(user=None): + p = AsyncMock() + p.get_user = AsyncMock(return_value=user) + return p + + +# ── @auth.authenticate ─────────────────────────────────────────────────── + + +def test_no_cookie_raises_401(): + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run(authenticate(_req())) + assert exc.value.status_code == 401 + assert "Not authenticated" in str(exc.value.detail) + + +def test_invalid_jwt_raises_401(): + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run(authenticate(_req({"access_token": "garbage"}))) + assert exc.value.status_code == 401 + assert "Token error" in str(exc.value.detail) + + +def test_expired_jwt_raises_401(): + token = create_access_token("user-1", expires_delta=timedelta(seconds=-1)) + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run(authenticate(_req({"access_token": token}))) + assert exc.value.status_code == 401 + + +def test_user_not_found_raises_401(): + token = create_access_token("ghost") + with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)): + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run(authenticate(_req({"access_token": token}))) + assert exc.value.status_code == 401 + assert "User not found" in str(exc.value.detail) + + +def test_token_version_mismatch_raises_401(): + user = _user(token_version=2) + token = create_access_token(str(user.id), token_version=1) + with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)): + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run(authenticate(_req({"access_token": token}))) + assert exc.value.status_code == 401 + assert "revoked" in str(exc.value.detail).lower() + + +def test_valid_token_returns_user_id(): + user = _user(token_version=0) + token = create_access_token(str(user.id), token_version=0) + with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)): + result = asyncio.run(authenticate(_req({"access_token": token}))) + assert result == str(user.id) + + +def test_valid_token_matching_version(): + user = _user(token_version=5) + token = create_access_token(str(user.id), token_version=5) + with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)): + result = asyncio.run(authenticate(_req({"access_token": token}))) + assert result == str(user.id) + + +# ── @auth.authenticate edge cases ──────────────────────────────────────── + + +def test_provider_exception_propagates(): + """Provider raises → should not be swallowed silently.""" + token = create_access_token("user-1") + p = AsyncMock() + p.get_user = AsyncMock(side_effect=RuntimeError("DB down")) + with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p): + with pytest.raises(RuntimeError, match="DB down"): + asyncio.run(authenticate(_req({"access_token": token}))) + + +def test_jwt_missing_ver_defaults_to_zero(): + """JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0.""" + import jwt as pyjwt + + uid = str(uuid4()) + raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256") + user = _user(user_id=uid, token_version=0) + with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)): + result = asyncio.run(authenticate(_req({"access_token": raw}))) + assert result == uid + + +def test_jwt_missing_ver_rejected_when_user_version_nonzero(): + """JWT without 'ver' (defaults 0) vs user with token_version=1 → 401.""" + import jwt as pyjwt + + uid = str(uuid4()) + raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256") + user = _user(user_id=uid, token_version=1) + with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)): + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run(authenticate(_req({"access_token": raw}))) + assert exc.value.status_code == 401 + + +def test_wrong_secret_raises_401(): + """Token signed with different secret → 401.""" + import jwt as pyjwt + + raw = pyjwt.encode({"sub": "user-1", "exp": 9999999999, "ver": 0}, "wrong-secret-that-is-long-enough-32chars!", algorithm="HS256") + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run(authenticate(_req({"access_token": raw}))) + assert exc.value.status_code == 401 + + +# ── @auth.on (owner filter) ────────────────────────────────────────────── + + +class _FakeUser: + """Minimal BaseUser-compatible object without langgraph_api.config dependency.""" + + def __init__(self, identity: str): + self.identity = identity + self.is_authenticated = True + self.display_name = identity + + +def _make_ctx(user_id): + return Auth.types.AuthContext(resource="threads", action="create", user=_FakeUser(user_id), permissions=[]) + + +def test_filter_injects_user_id(): + value = {} + asyncio.run(add_owner_filter(_make_ctx("user-a"), value)) + assert value["metadata"]["owner_id"] == "user-a" + + +def test_filter_preserves_existing_metadata(): + value = {"metadata": {"title": "hello"}} + asyncio.run(add_owner_filter(_make_ctx("user-a"), value)) + assert value["metadata"]["owner_id"] == "user-a" + assert value["metadata"]["title"] == "hello" + + +def test_filter_returns_user_id_dict(): + result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {})) + assert result == {"owner_id": "user-x"} + + +def test_filter_read_write_consistency(): + value = {} + filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value)) + assert value["metadata"]["owner_id"] == filter_dict["owner_id"] + + +def test_different_users_different_filters(): + f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {})) + f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {})) + assert f_a["owner_id"] != f_b["owner_id"] + + +def test_filter_overrides_conflicting_user_id(): + """If value already has a different user_id in metadata, it gets overwritten.""" + value = {"metadata": {"owner_id": "attacker"}} + asyncio.run(add_owner_filter(_make_ctx("real-owner"), value)) + assert value["metadata"]["owner_id"] == "real-owner" + + +def test_filter_with_empty_metadata(): + """Explicit empty metadata dict is fine.""" + value = {"metadata": {}} + result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value)) + assert value["metadata"]["owner_id"] == "user-z" + assert result == {"owner_id": "user-z"} + + +# ── Gateway parity ─────────────────────────────────────────────────────── + + +def test_shared_jwt_secret(): + token = create_access_token("user-1", token_version=3) + payload = decode_token(token) + from app.gateway.auth.errors import TokenError + + assert not isinstance(payload, TokenError) + assert payload.sub == "user-1" + assert payload.ver == 3 + + +def test_langgraph_json_has_auth_path(): + import json + + config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text()) + assert "auth" in config + assert "langgraph_auth" in config["auth"]["path"] + + +def test_auth_handler_has_both_layers(): + from app.gateway.langgraph_auth import auth + + assert auth._authenticate_handler is not None + assert len(auth._global_handlers) == 1 + + +# ── CSRF in LangGraph auth ────────────────────────────────────────────── + + +def test_csrf_get_no_check(): + """GET requests skip CSRF — should proceed to JWT validation.""" + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run(authenticate(_req(method="GET"))) + # Rejected by missing cookie, NOT by CSRF + assert exc.value.status_code == 401 + assert "Not authenticated" in str(exc.value.detail) + + +def test_csrf_post_missing_token(): + """POST without CSRF token → 403.""" + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run(authenticate(_req(method="POST", cookies={"access_token": "some-jwt"}))) + assert exc.value.status_code == 403 + assert "CSRF token missing" in str(exc.value.detail) + + +def test_csrf_post_mismatched_token(): + """POST with mismatched CSRF tokens → 403.""" + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run( + authenticate( + _req( + method="POST", + cookies={"access_token": "some-jwt", "csrf_token": "real-token"}, + headers={"x-csrf-token": "wrong-token"}, + ) + ) + ) + assert exc.value.status_code == 403 + assert "mismatch" in str(exc.value.detail) + + +def test_csrf_post_matching_token_proceeds_to_jwt(): + """POST with matching CSRF tokens passes CSRF check, then fails on JWT.""" + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run( + authenticate( + _req( + method="POST", + cookies={"access_token": "garbage", "csrf_token": "same-token"}, + headers={"x-csrf-token": "same-token"}, + ) + ) + ) + # Past CSRF, rejected by JWT decode + assert exc.value.status_code == 401 + assert "Token error" in str(exc.value.detail) + + +def test_csrf_put_requires_token(): + """PUT also requires CSRF.""" + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run(authenticate(_req(method="PUT", cookies={"access_token": "jwt"}))) + assert exc.value.status_code == 403 + + +def test_csrf_delete_requires_token(): + """DELETE also requires CSRF.""" + with pytest.raises(Auth.exceptions.HTTPException) as exc: + asyncio.run(authenticate(_req(method="DELETE", cookies={"access_token": "jwt"}))) + assert exc.value.status_code == 403 diff --git a/frontend/package.json b/frontend/package.json index 83f69b4e3..8d4ac8526 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -52,7 +52,6 @@ "@xyflow/react": "^12.10.0", "ai": "^6.0.33", "best-effort-json-parser": "^1.2.1", - "better-auth": "^1.3", "canvas-confetti": "^1.9.4", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index e317aaa64..3279f0665 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -113,9 +113,6 @@ importers: best-effort-json-parser: specifier: ^1.2.1 version: 1.2.1 - better-auth: - specifier: ^1.3 - version: 1.4.18(next@16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(vue@3.5.28(typescript@5.9.3)) canvas-confetti: specifier: ^1.9.4 version: 1.9.4 @@ -317,27 +314,6 @@ packages: resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==} engines: {node: '>=6.9.0'} - '@better-auth/core@1.4.18': - resolution: {integrity: sha512-q+awYgC7nkLEBdx2sW0iJjkzgSHlIxGnOpsN1r/O1+a4m7osJNHtfK2mKJSL1I+GfNyIlxJF8WvD/NLuYMpmcg==} - peerDependencies: - '@better-auth/utils': 0.3.0 - '@better-fetch/fetch': 1.1.21 - better-call: 1.1.8 - jose: ^6.1.0 - kysely: ^0.28.5 - nanostores: ^1.0.1 - - '@better-auth/telemetry@1.4.18': - resolution: {integrity: sha512-e5rDF8S4j3Um/0LIVATL2in9dL4lfO2fr2v1Wio4qTMRbfxqnUDTa+6SZtwdeJrbc4O+a3c+IyIpjG9Q/6GpfQ==} - peerDependencies: - '@better-auth/core': 1.4.18 - - '@better-auth/utils@0.3.0': - resolution: {integrity: sha512-W+Adw6ZA6mgvnSnhOki270rwJ42t4XzSK6YWGF//BbVXL6SwCLWfyzBc1lN2m/4RM28KubdBKQ4X5VMoLRNPQw==} - - '@better-fetch/fetch@1.1.21': - resolution: {integrity: sha512-/ImESw0sskqlVR94jB+5+Pxjf+xBwDZF/N5+y2/q4EqD7IARUTSpPfIo8uf39SYpCxyOCtbyYpUrZ3F/k0zT4A==} - '@braintree/sanitize-url@7.1.2': resolution: {integrity: sha512-jigsZK+sMF/cuiB7sERuo9V7N9jx+dhmHHnQyDSVdpZwVutaBu7WvNYqMDLSgFgfB30n452TP3vjDAvFC973mA==} @@ -1116,14 +1092,6 @@ packages: cpu: [x64] os: [win32] - '@noble/ciphers@2.1.1': - resolution: {integrity: sha512-bysYuiVfhxNJuldNXlFEitTVdNnYUc+XNJZd7Qm2a5j1vZHgY+fazadNFWFaMK/2vye0JVlxV3gHmC0WDfAOQw==} - engines: {node: '>= 20.19.0'} - - '@noble/hashes@2.0.1': - resolution: {integrity: sha512-XlOlEbQcE9fmuXxrVTXCTlG2nlRXa9Rj3rr5Ue/+tX+nmkgbX720YHh0VR3hBF9xDvwnb8D2shVGOwNx+ulArw==} - engines: {node: '>= 20.19.0'} - '@nodelib/fs.scandir@2.1.5': resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==} engines: {node: '>= 8'} @@ -2696,76 +2664,6 @@ packages: best-effort-json-parser@1.2.1: resolution: {integrity: sha512-UICSLibQdzS1f+PBsi3u2YE3SsdXcWicHUg3IMvfuaePS2AYnZJdJeKhGv5OM8/mqJwPt79aDrEJ1oa84tELvw==} - better-auth@1.4.18: - resolution: {integrity: sha512-bnyifLWBPcYVltH3RhS7CM62MoelEqC6Q+GnZwfiDWNfepXoQZBjEvn4urcERC7NTKgKq5zNBM8rvPvRBa6xcg==} - peerDependencies: - '@lynx-js/react': '*' - '@prisma/client': ^5.0.0 || ^6.0.0 || ^7.0.0 - '@sveltejs/kit': ^2.0.0 - '@tanstack/react-start': ^1.0.0 - '@tanstack/solid-start': ^1.0.0 - better-sqlite3: ^12.0.0 - drizzle-kit: '>=0.31.4' - drizzle-orm: '>=0.41.0' - mongodb: ^6.0.0 || ^7.0.0 - mysql2: ^3.0.0 - next: ^14.0.0 || ^15.0.0 || ^16.0.0 - pg: ^8.0.0 - prisma: ^5.0.0 || ^6.0.0 || ^7.0.0 - react: ^18.0.0 || ^19.0.0 - react-dom: ^18.0.0 || ^19.0.0 - solid-js: ^1.0.0 - svelte: ^4.0.0 || ^5.0.0 - vitest: ^2.0.0 || ^3.0.0 || ^4.0.0 - vue: ^3.0.0 - peerDependenciesMeta: - '@lynx-js/react': - optional: true - '@prisma/client': - optional: true - '@sveltejs/kit': - optional: true - '@tanstack/react-start': - optional: true - '@tanstack/solid-start': - optional: true - better-sqlite3: - optional: true - drizzle-kit: - optional: true - drizzle-orm: - optional: true - mongodb: - optional: true - mysql2: - optional: true - next: - optional: true - pg: - optional: true - prisma: - optional: true - react: - optional: true - react-dom: - optional: true - solid-js: - optional: true - svelte: - optional: true - vitest: - optional: true - vue: - optional: true - - better-call@1.1.8: - resolution: {integrity: sha512-XMQ2rs6FNXasGNfMjzbyroSwKwYbZ/T3IxruSS6U2MJRsSYh3wYtG3o6H00ZlKZ/C/UPOAD97tqgQJNsxyeTXw==} - peerDependencies: - zod: ^4.0.0 - peerDependenciesMeta: - zod: - optional: true - better-react-mathjax@2.3.0: resolution: {integrity: sha512-K0ceQC+jQmB+NLDogO5HCpqmYf18AU2FxDbLdduYgkHYWZApFggkHE4dIaXCV1NqeoscESYXXo1GSkY6fA295w==} peerDependencies: @@ -3973,9 +3871,6 @@ packages: resolution: {integrity: sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ==} hasBin: true - jose@6.1.3: - resolution: {integrity: sha512-0TpaTfihd4QMNwrz/ob2Bp7X04yuxJkjRGi4aKmOqwhov54i6u79oCv7T+C7lo70MKH6BesI3vscD1yb/yzKXQ==} - js-tiktoken@1.0.21: resolution: {integrity: sha512-biOj/6M5qdgx5TKjDnFT1ymSpM5tbd3ylwDtrQvFQSu0Z7bBYko2dF+W/aUkXUPuk6IVpRxk/3Q2sHOzGlS36g==} @@ -4026,10 +3921,6 @@ packages: knitwork@1.3.0: resolution: {integrity: sha512-4LqMNoONzR43B1W0ek0fhXMsDNW/zxa1NdFAVMY+k28pgZLovR4G3PB5MrpTxCy1QaZCqNoiaKPr5w5qZHfSNw==} - kysely@0.28.11: - resolution: {integrity: sha512-zpGIFg0HuoC893rIjYX1BETkVWdDnzTzF5e0kWXJFg5lE0k1/LfNWBejrcnOFu8Q2Rfq/hTDTU7XLUM8QOrpzg==} - engines: {node: '>=20.0.0'} - langium@3.3.1: resolution: {integrity: sha512-QJv/h939gDpvT+9SiLVlY7tZC3xB2qK57v0J04Sh9wpMb6MP1q8gB21L3WIo8T5P1MSMg3Ep14L7KkDCFG3y4w==} engines: {node: '>=16.0.0'} @@ -4458,10 +4349,6 @@ packages: engines: {node: ^18 || >=20} hasBin: true - nanostores@1.1.0: - resolution: {integrity: sha512-yJBmDJr18xy47dbNVlHcgdPrulSn1nhSE6Ns9vTG+Nx9VPT6iV1MD6aQFp/t52zpf82FhLLTXAXr30NuCnxvwA==} - engines: {node: ^20.0.0 || >=22.0.0} - napi-postinstall@0.3.4: resolution: {integrity: sha512-PHI5f1O0EP5xJ9gQmFGMS6IZcrVvTjpXjz7Na41gTE7eE2hK11lg04CECCYEEjdc17EV4DO+fkGEtt7TpTaTiQ==} engines: {node: ^12.20.0 || ^14.18.0 || >=16.0.0} @@ -5050,9 +4937,6 @@ packages: engines: {node: '>=18.0.0', npm: '>=8.0.0'} hasBin: true - rou3@0.7.12: - resolution: {integrity: sha512-iFE4hLDuloSWcD7mjdCDhx2bKcIsYbtOTpfH5MHHLSKMOUyjqQXTeZVa289uuwEGEKFoE/BAPbhaU4B774nceg==} - roughjs@4.6.6: resolution: {integrity: sha512-ZUz/69+SYpFN/g/lUlo2FXcIjRkSu3nDarreVdGGndHEBJ6cXPdKguS8JGxwj5HA5xIbVKSmLgr5b3AWxtRfvQ==} @@ -5105,9 +4989,6 @@ packages: server-only@0.0.1: resolution: {integrity: sha512-qepMx2JxAa5jjfzxG79yPPq+8BuFToHd1hm7kI+Z4zAq1ftQiP7HcxMhDDItrbtwVeLg/cY2JnKnrcFkmiswNA==} - set-cookie-parser@2.7.2: - resolution: {integrity: sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==} - set-function-length@1.2.2: resolution: {integrity: sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==} engines: {node: '>= 0.4'} @@ -5802,27 +5683,6 @@ snapshots: '@babel/helper-string-parser': 7.27.1 '@babel/helper-validator-identifier': 7.28.5 - '@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0)': - dependencies: - '@better-auth/utils': 0.3.0 - '@better-fetch/fetch': 1.1.21 - '@standard-schema/spec': 1.1.0 - better-call: 1.1.8(zod@4.3.6) - jose: 6.1.3 - kysely: 0.28.11 - nanostores: 1.1.0 - zod: 4.3.6 - - '@better-auth/telemetry@1.4.18(@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0))': - dependencies: - '@better-auth/core': 1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0) - '@better-auth/utils': 0.3.0 - '@better-fetch/fetch': 1.1.21 - - '@better-auth/utils@0.3.0': {} - - '@better-fetch/fetch@1.1.21': {} - '@braintree/sanitize-url@7.1.2': {} '@cfworker/json-schema@4.1.1': {} @@ -6671,10 +6531,6 @@ snapshots: '@next/swc-win32-x64-msvc@16.1.7': optional: true - '@noble/ciphers@2.1.1': {} - - '@noble/hashes@2.0.1': {} - '@nodelib/fs.scandir@2.1.5': dependencies: '@nodelib/fs.stat': 2.0.5 @@ -8242,35 +8098,6 @@ snapshots: best-effort-json-parser@1.2.1: {} - better-auth@1.4.18(next@16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(vue@3.5.28(typescript@5.9.3)): - dependencies: - '@better-auth/core': 1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0) - '@better-auth/telemetry': 1.4.18(@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0)) - '@better-auth/utils': 0.3.0 - '@better-fetch/fetch': 1.1.21 - '@noble/ciphers': 2.1.1 - '@noble/hashes': 2.0.1 - better-call: 1.1.8(zod@4.3.6) - defu: 6.1.4 - jose: 6.1.3 - kysely: 0.28.11 - nanostores: 1.1.0 - zod: 4.3.6 - optionalDependencies: - next: 16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - react: 19.2.4 - react-dom: 19.2.4(react@19.2.4) - vue: 3.5.28(typescript@5.9.3) - - better-call@1.1.8(zod@4.3.6): - dependencies: - '@better-auth/utils': 0.3.0 - '@better-fetch/fetch': 1.1.21 - rou3: 0.7.12 - set-cookie-parser: 2.7.2 - optionalDependencies: - zod: 4.3.6 - better-react-mathjax@2.3.0(react@19.2.4): dependencies: mathjax-full: 3.2.2 @@ -9786,8 +9613,6 @@ snapshots: jiti@2.6.1: {} - jose@6.1.3: {} - js-tiktoken@1.0.21: dependencies: base64-js: 1.5.1 @@ -9833,8 +9658,6 @@ snapshots: knitwork@1.3.0: {} - kysely@0.28.11: {} - langium@3.3.1: dependencies: chevrotain: 11.0.3 @@ -10529,8 +10352,6 @@ snapshots: nanoid@5.1.6: {} - nanostores@1.1.0: {} - napi-postinstall@0.3.4: {} natural-compare@1.4.0: {} @@ -11305,8 +11126,6 @@ snapshots: '@rollup/rollup-win32-x64-msvc': 4.60.0 fsevents: 2.3.3 - rou3@0.7.12: {} - roughjs@4.6.6: dependencies: hachure-fill: 0.5.2 @@ -11373,8 +11192,6 @@ snapshots: server-only@0.0.1: {} - set-cookie-parser@2.7.2: {} - set-function-length@1.2.2: dependencies: define-data-property: 1.1.4 diff --git a/frontend/src/app/(auth)/layout.tsx b/frontend/src/app/(auth)/layout.tsx new file mode 100644 index 000000000..b916def52 --- /dev/null +++ b/frontend/src/app/(auth)/layout.tsx @@ -0,0 +1,45 @@ +import Link from "next/link"; +import { redirect } from "next/navigation"; +import { type ReactNode } from "react"; + +import { AuthProvider } from "@/core/auth/AuthProvider"; +import { getServerSideUser } from "@/core/auth/server"; +import { assertNever } from "@/core/auth/types"; + +export const dynamic = "force-dynamic"; + +export default async function AuthLayout({ + children, +}: { + children: ReactNode; +}) { + const result = await getServerSideUser(); + + switch (result.tag) { + case "authenticated": + redirect("/workspace"); + case "needs_setup": + // Allow access to setup page + return {children}; + case "unauthenticated": + return {children}; + case "gateway_unavailable": + return ( +
+

+ Service temporarily unavailable. +

+ + Retry + +
+ ); + case "config_error": + throw new Error(result.message); + default: + assertNever(result); + } +} diff --git a/frontend/src/app/(auth)/login/page.tsx b/frontend/src/app/(auth)/login/page.tsx new file mode 100644 index 000000000..90ca15238 --- /dev/null +++ b/frontend/src/app/(auth)/login/page.tsx @@ -0,0 +1,183 @@ +"use client"; + +import Link from "next/link"; +import { useRouter, useSearchParams } from "next/navigation"; +import { useEffect, useState } from "react"; + +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { useAuth } from "@/core/auth/AuthProvider"; +import { parseAuthError } from "@/core/auth/types"; + +/** + * Validate next parameter + * Prevent open redirect attacks + * Per RFC-001: Only allow relative paths starting with / + */ +function validateNextParam(next: string | null): string | null { + if (!next) { + return null; + } + + // Need start with / (relative path) + if (!next.startsWith("/")) { + return null; + } + + // Disallow protocol-relative URLs + if ( + next.startsWith("//") || + next.startsWith("http://") || + next.startsWith("https://") + ) { + return null; + } + + // Disallow URLs with different protocols (e.g., javascript:, data:, etc) + if (next.includes(":") && !next.startsWith("/")) { + return null; + } + + // Valid relative path + return next; +} + +export default function LoginPage() { + const router = useRouter(); + const searchParams = useSearchParams(); + const { isAuthenticated } = useAuth(); + + const [email, setEmail] = useState(""); + const [password, setPassword] = useState(""); + const [isLogin, setIsLogin] = useState(true); + const [error, setError] = useState(""); + const [loading, setLoading] = useState(false); + + // Get next parameter for validated redirect + const nextParam = searchParams.get("next"); + const redirectPath = validateNextParam(nextParam) ?? "/workspace"; + + // Redirect if already authenticated (client-side, post-login) + useEffect(() => { + if (isAuthenticated) { + router.push(redirectPath); + } + }, [isAuthenticated, redirectPath, router]); + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + setError(""); + setLoading(true); + + try { + const endpoint = isLogin + ? "/api/v1/auth/login/local" + : "/api/v1/auth/register"; + const body = isLogin + ? `username=${encodeURIComponent(email)}&password=${encodeURIComponent(password)}` + : JSON.stringify({ email, password }); + + const headers: HeadersInit = isLogin + ? { "Content-Type": "application/x-www-form-urlencoded" } + : { "Content-Type": "application/json" }; + + const res = await fetch(endpoint, { + method: "POST", + headers, + body, + credentials: "include", // Important: include HttpOnly cookie + }); + + if (!res.ok) { + const data = await res.json(); + const authError = parseAuthError(data); + setError(authError.message); + return; + } + + // Both login and register set a cookie — redirect to workspace + router.push(redirectPath); + } catch (_err) { + setError("Network error. Please try again."); + } finally { + setLoading(false); + } + }; + + return ( +
+
+
+

DeerFlow

+

+ {isLogin ? "Sign in to your account" : "Create a new account"} +

+
+ +
+
+ + setEmail(e.target.value)} + placeholder="you@example.com" + required + className="mt-1 bg-white text-black" + /> +
+ +
+ + setPassword(e.target.value)} + placeholder="•••••••" + required + minLength={isLogin ? 6 : 8} + className="mt-1 bg-white text-black" + /> +
+ + {error &&

{error}

} + + +
+ +
+ +
+ +
+ + ← Back to home + +
+
+
+ ); +} diff --git a/frontend/src/app/(auth)/setup/page.tsx b/frontend/src/app/(auth)/setup/page.tsx new file mode 100644 index 000000000..e70d1efc6 --- /dev/null +++ b/frontend/src/app/(auth)/setup/page.tsx @@ -0,0 +1,115 @@ +"use client"; + +import { useRouter } from "next/navigation"; +import { useState } from "react"; + +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { getCsrfHeaders } from "@/core/api/fetcher"; +import { parseAuthError } from "@/core/auth/types"; + +export default function SetupPage() { + const router = useRouter(); + const [email, setEmail] = useState(""); + const [newPassword, setNewPassword] = useState(""); + const [confirmPassword, setConfirmPassword] = useState(""); + const [currentPassword, setCurrentPassword] = useState(""); + const [error, setError] = useState(""); + const [loading, setLoading] = useState(false); + + const handleSetup = async (e: React.FormEvent) => { + e.preventDefault(); + setError(""); + + if (newPassword !== confirmPassword) { + setError("Passwords do not match"); + return; + } + if (newPassword.length < 8) { + setError("Password must be at least 8 characters"); + return; + } + + setLoading(true); + try { + const res = await fetch("/api/v1/auth/change-password", { + method: "POST", + headers: { + "Content-Type": "application/json", + ...getCsrfHeaders(), + }, + credentials: "include", + body: JSON.stringify({ + current_password: currentPassword, + new_password: newPassword, + new_email: email || undefined, + }), + }); + + if (!res.ok) { + const data = await res.json(); + const authError = parseAuthError(data); + setError(authError.message); + return; + } + + router.push("/workspace"); + } catch { + setError("Network error. Please try again."); + } finally { + setLoading(false); + } + }; + + return ( +
+
+
+

DeerFlow

+

+ Complete admin account setup +

+

+ Set your real email and a new password. +

+
+
+ setEmail(e.target.value)} + required + /> + setCurrentPassword(e.target.value)} + required + /> + setNewPassword(e.target.value)} + required + minLength={8} + /> + setConfirmPassword(e.target.value)} + required + minLength={8} + /> + {error &&

{error}

} + +
+
+
+ ); +} diff --git a/frontend/src/app/api/auth/[...all]/route.ts b/frontend/src/app/api/auth/[...all]/route.ts deleted file mode 100644 index cde6018a8..000000000 --- a/frontend/src/app/api/auth/[...all]/route.ts +++ /dev/null @@ -1,5 +0,0 @@ -import { toNextJsHandler } from "better-auth/next-js"; - -import { auth } from "@/server/better-auth"; - -export const { GET, POST } = toNextJsHandler(auth.handler); diff --git a/frontend/src/app/workspace/layout.tsx b/frontend/src/app/workspace/layout.tsx index 417c933d4..fa19025a0 100644 --- a/frontend/src/app/workspace/layout.tsx +++ b/frontend/src/app/workspace/layout.tsx @@ -1,47 +1,58 @@ -"use client"; +import Link from "next/link"; +import { redirect } from "next/navigation"; -import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; -import { useCallback, useEffect, useLayoutEffect, useState } from "react"; -import { Toaster } from "sonner"; +import { AuthProvider } from "@/core/auth/AuthProvider"; +import { getServerSideUser } from "@/core/auth/server"; +import { assertNever } from "@/core/auth/types"; -import { SidebarInset, SidebarProvider } from "@/components/ui/sidebar"; -import { CommandPalette } from "@/components/workspace/command-palette"; -import { WorkspaceSidebar } from "@/components/workspace/workspace-sidebar"; -import { getLocalSettings, useLocalSettings } from "@/core/settings"; +import { WorkspaceContent } from "./workspace-content"; -const queryClient = new QueryClient(); +export const dynamic = "force-dynamic"; -export default function WorkspaceLayout({ +export default async function WorkspaceLayout({ children, }: Readonly<{ children: React.ReactNode }>) { - const [settings, setSettings] = useLocalSettings(); - const [open, setOpen] = useState(false); // SSR default: open (matches server render) - useLayoutEffect(() => { - // Runs synchronously before first paint on the client — no visual flash - setOpen(!getLocalSettings().layout.sidebar_collapsed); - }, []); - useEffect(() => { - setOpen(!settings.layout.sidebar_collapsed); - }, [settings.layout.sidebar_collapsed]); - const handleOpenChange = useCallback( - (open: boolean) => { - setOpen(open); - setSettings("layout", { sidebar_collapsed: !open }); - }, - [setSettings], - ); - return ( - - - - {children} - - - - - ); + const result = await getServerSideUser(); + + switch (result.tag) { + case "authenticated": + return ( + + {children} + + ); + case "needs_setup": + redirect("/setup"); + case "unauthenticated": + redirect("/login"); + case "gateway_unavailable": + return ( +
+

+ Service temporarily unavailable. +

+

+ The backend may be restarting. Please wait a moment and try again. +

+
+ + Retry + + + Logout & Reset + +
+
+ ); + case "config_error": + throw new Error(result.message); + default: + assertNever(result); + } } diff --git a/frontend/src/app/workspace/workspace-content.tsx b/frontend/src/app/workspace/workspace-content.tsx new file mode 100644 index 000000000..960ad28a2 --- /dev/null +++ b/frontend/src/app/workspace/workspace-content.tsx @@ -0,0 +1,50 @@ +"use client"; + +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { useCallback, useEffect, useLayoutEffect, useState } from "react"; +import { Toaster } from "sonner"; + +import { SidebarInset, SidebarProvider } from "@/components/ui/sidebar"; +import { CommandPalette } from "@/components/workspace/command-palette"; +import { WorkspaceSidebar } from "@/components/workspace/workspace-sidebar"; +import { getLocalSettings, useLocalSettings } from "@/core/settings"; + +export function WorkspaceContent({ + children, +}: Readonly<{ children: React.ReactNode }>) { + const [queryClient] = useState(() => new QueryClient()); + const [settings, setSettings] = useLocalSettings(); + const [open, setOpen] = useState(false); // SSR default: open (matches server render) + + useLayoutEffect(() => { + // Runs synchronously before first paint on the client — no visual flash + setOpen(!getLocalSettings().layout.sidebar_collapsed); + }, []); + + useEffect(() => { + setOpen(!settings.layout.sidebar_collapsed); + }, [settings.layout.sidebar_collapsed]); + + const handleOpenChange = useCallback( + (open: boolean) => { + setOpen(open); + setSettings("layout", { sidebar_collapsed: !open }); + }, + [setSettings], + ); + + return ( + + + + {children} + + + + + ); +} diff --git a/frontend/src/core/api/fetcher.ts b/frontend/src/core/api/fetcher.ts new file mode 100644 index 000000000..a9dd30387 --- /dev/null +++ b/frontend/src/core/api/fetcher.ts @@ -0,0 +1,39 @@ +import { buildLoginUrl } from "@/core/auth/types"; + +/** + * Fetch with credentials. Automatically redirects to login on 401. + */ +export async function fetchWithAuth( + input: RequestInfo | string, + init?: RequestInit, +): Promise { + const url = typeof input === "string" ? input : input.url; + const res = await fetch(url, { + ...init, + credentials: "include", + }); + + if (res.status === 401) { + window.location.href = buildLoginUrl(window.location.pathname); + throw new Error("Unauthorized"); + } + + return res; +} + +/** + * Build headers for CSRF-protected requests + * Per RFC-001: Double Submit Cookie pattern + */ +export function getCsrfHeaders(): HeadersInit { + const token = getCsrfToken(); + return token ? { "X-CSRF-Token": token } : {}; +} + +/** + * Get CSRF token from cookie + */ +function getCsrfToken(): string | null { + const match = /csrf_token=([^;]+)/.exec(document.cookie); + return match?.[1] ?? null; +} diff --git a/frontend/src/core/auth/AuthProvider.tsx b/frontend/src/core/auth/AuthProvider.tsx new file mode 100644 index 000000000..652cc49b8 --- /dev/null +++ b/frontend/src/core/auth/AuthProvider.tsx @@ -0,0 +1,165 @@ +"use client"; + +import { useRouter, usePathname } from "next/navigation"; +import React, { + createContext, + useContext, + useState, + useCallback, + useEffect, + type ReactNode, +} from "react"; + +import { type User, buildLoginUrl } from "./types"; + +// Re-export for consumers +export type { User }; + +/** + * Authentication context provided to consuming components + */ +interface AuthContextType { + user: User | null; + isAuthenticated: boolean; + isLoading: boolean; + logout: () => Promise; + refreshUser: () => Promise; +} + +const AuthContext = createContext(undefined); + +interface AuthProviderProps { + children: ReactNode; + initialUser: User | null; +} + +/** + * AuthProvider - Unified authentication context for the application + * + * Per RFC-001: + * - Only holds display information (user), never JWT or tokens + * - initialUser comes from server-side guard, avoiding client flicker + * - Provides logout and refresh capabilities + */ +export function AuthProvider({ children, initialUser }: AuthProviderProps) { + const [user, setUser] = useState(initialUser); + const [isLoading, setIsLoading] = useState(false); + const router = useRouter(); + const pathname = usePathname(); + + const isAuthenticated = user !== null; + + /** + * Fetch current user from FastAPI + * Used when initialUser might be stale (e.g., after tab was inactive) + */ + const refreshUser = useCallback(async () => { + try { + setIsLoading(true); + const res = await fetch("/api/v1/auth/me", { + credentials: "include", + }); + + if (res.ok) { + const data = await res.json(); + setUser(data); + } else if (res.status === 401) { + // Session expired or invalid + setUser(null); + // Redirect to login if on a protected route + if (pathname?.startsWith("/workspace")) { + router.push(buildLoginUrl(pathname)); + } + } + } catch (err) { + console.error("Failed to refresh user:", err); + setUser(null); + } finally { + setIsLoading(false); + } + }, [pathname, router]); + + /** + * Logout - call FastAPI logout endpoint and clear local state + * Per RFC-001: Immediately clear local state, don't wait for server confirmation + */ + const logout = useCallback(async () => { + // Immediately clear local state to prevent UI flicker + setUser(null); + + try { + await fetch("/api/v1/auth/logout", { + method: "POST", + credentials: "include", + }); + } catch (err) { + console.error("Logout request failed:", err); + // Still redirect even if logout request fails + } + + // Redirect to home page + router.push("/"); + }, [router]); + + /** + * Handle visibility change - refresh user when tab becomes visible again. + * Throttled to at most once per 60 s to avoid spamming the backend on rapid tab switches. + */ + const lastCheckRef = React.useRef(0); + + useEffect(() => { + const handleVisibilityChange = () => { + if (document.visibilityState !== "visible" || user === null) return; + const now = Date.now(); + if (now - lastCheckRef.current < 60_000) return; + lastCheckRef.current = now; + void refreshUser(); + }; + + document.addEventListener("visibilitychange", handleVisibilityChange); + return () => { + document.removeEventListener("visibilitychange", handleVisibilityChange); + }; + }, [user, refreshUser]); + + const value: AuthContextType = { + user, + isAuthenticated, + isLoading, + logout, + refreshUser, + }; + + return {children}; +} + +/** + * Hook to access authentication context + * Throws if used outside AuthProvider - this is intentional for proper usage + */ +export function useAuth(): AuthContextType { + const context = useContext(AuthContext); + if (context === undefined) { + throw new Error("useAuth must be used within an AuthProvider"); + } + return context; +} + +/** + * Hook to require authentication - redirects to login if not authenticated + * Useful for client-side checks in addition to server-side guards + */ +export function useRequireAuth(): AuthContextType { + const auth = useAuth(); + const router = useRouter(); + const pathname = usePathname(); + + useEffect(() => { + // Only redirect if we're sure user is not authenticated (not just loading) + if (!auth.isLoading && !auth.isAuthenticated) { + router.push(buildLoginUrl(pathname || "/workspace")); + } + }, [auth.isAuthenticated, auth.isLoading, router, pathname]); + + return auth; +} diff --git a/frontend/src/core/auth/gateway-config.ts b/frontend/src/core/auth/gateway-config.ts new file mode 100644 index 000000000..61c6ae850 --- /dev/null +++ b/frontend/src/core/auth/gateway-config.ts @@ -0,0 +1,34 @@ +import { z } from "zod"; + +const gatewayConfigSchema = z.object({ + internalGatewayUrl: z.string().url(), + trustedOrigins: z.array(z.string()).min(1), +}); + +export type GatewayConfig = z.infer; + +let _cached: GatewayConfig | null = null; + +export function getGatewayConfig(): GatewayConfig { + if (_cached) return _cached; + + const isDev = process.env.NODE_ENV === "development"; + + const rawUrl = process.env.DEER_FLOW_INTERNAL_GATEWAY_BASE_URL?.trim(); + const internalGatewayUrl = + rawUrl?.replace(/\/+$/, "") ?? + (isDev ? "http://localhost:8001" : undefined); + + const rawOrigins = process.env.DEER_FLOW_TRUSTED_ORIGINS?.trim(); + const trustedOrigins = rawOrigins + ? rawOrigins + .split(",") + .map((s) => s.trim()) + .filter(Boolean) + : isDev + ? ["http://localhost:3000"] + : undefined; + + _cached = gatewayConfigSchema.parse({ internalGatewayUrl, trustedOrigins }); + return _cached; +} diff --git a/frontend/src/core/auth/proxy-policy.ts b/frontend/src/core/auth/proxy-policy.ts new file mode 100644 index 000000000..9e6f1f424 --- /dev/null +++ b/frontend/src/core/auth/proxy-policy.ts @@ -0,0 +1,55 @@ +export interface ProxyPolicy { + /** Allowed upstream path prefixes */ + readonly allowedPaths: readonly string[]; + /** Request headers to strip before forwarding */ + readonly strippedRequestHeaders: ReadonlySet; + /** Response headers to strip before returning */ + readonly strippedResponseHeaders: ReadonlySet; + /** Credential mode: which cookie to forward */ + readonly credential: { readonly type: "cookie"; readonly name: string }; + /** Timeout in ms */ + readonly timeoutMs: number; + /** CSRF: required for non-GET/HEAD */ + readonly csrf: boolean; +} + +export const LANGGRAPH_COMPAT_POLICY: ProxyPolicy = { + allowedPaths: [ + "threads", + "runs", + "assistants", + "store", + "models", + "mcp", + "skills", + "memory", + ], + strippedRequestHeaders: new Set([ + "host", + "connection", + "keep-alive", + "transfer-encoding", + "te", + "trailer", + "upgrade", + "authorization", + "x-api-key", + "origin", + "referer", + "proxy-authorization", + "proxy-authenticate", + ]), + strippedResponseHeaders: new Set([ + "connection", + "keep-alive", + "transfer-encoding", + "te", + "trailer", + "upgrade", + "content-length", + "set-cookie", + ]), + credential: { type: "cookie", name: "access_token" }, + timeoutMs: 120_000, + csrf: true, +}; diff --git a/frontend/src/core/auth/server.ts b/frontend/src/core/auth/server.ts new file mode 100644 index 000000000..4229143aa --- /dev/null +++ b/frontend/src/core/auth/server.ts @@ -0,0 +1,57 @@ +import { cookies } from "next/headers"; + +import { getGatewayConfig } from "./gateway-config"; +import { type AuthResult, userSchema } from "./types"; + +const SSR_AUTH_TIMEOUT_MS = 5_000; + +/** + * Fetch the authenticated user from the gateway using the request's cookies. + * Returns a tagged AuthResult — callers use exhaustive switch, no try/catch. + */ +export async function getServerSideUser(): Promise { + const cookieStore = await cookies(); + const sessionCookie = cookieStore.get("access_token"); + + let internalGatewayUrl: string; + try { + internalGatewayUrl = getGatewayConfig().internalGatewayUrl; + } catch (err) { + return { tag: "config_error", message: String(err) }; + } + + if (!sessionCookie) return { tag: "unauthenticated" }; + + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), SSR_AUTH_TIMEOUT_MS); + + try { + const res = await fetch(`${internalGatewayUrl}/api/v1/auth/me`, { + headers: { Cookie: `access_token=${sessionCookie.value}` }, + cache: "no-store", + signal: controller.signal, + }); + clearTimeout(timeout); // Clear immediately — covers all response branches + + if (res.ok) { + const parsed = userSchema.safeParse(await res.json()); + if (!parsed.success) { + console.error("[SSR auth] Malformed /auth/me response:", parsed.error); + return { tag: "gateway_unavailable" }; + } + if (parsed.data.needs_setup) { + return { tag: "needs_setup", user: parsed.data }; + } + return { tag: "authenticated", user: parsed.data }; + } + if (res.status === 401 || res.status === 403) { + return { tag: "unauthenticated" }; + } + console.error(`[SSR auth] /api/v1/auth/me responded ${res.status}`); + return { tag: "gateway_unavailable" }; + } catch (err) { + clearTimeout(timeout); + console.error("[SSR auth] Failed to reach gateway:", err); + return { tag: "gateway_unavailable" }; + } +} diff --git a/frontend/src/core/auth/types.ts b/frontend/src/core/auth/types.ts new file mode 100644 index 000000000..4cf42583e --- /dev/null +++ b/frontend/src/core/auth/types.ts @@ -0,0 +1,72 @@ +import { z } from "zod"; + +// ── User schema (single source of truth) ────────────────────────── + +export const userSchema = z.object({ + id: z.string(), + email: z.string().email(), + system_role: z.enum(["admin", "user"]), + needs_setup: z.boolean().optional().default(false), +}); + +export type User = z.infer; + +// ── SSR auth result (tagged union) ──────────────────────────────── + +export type AuthResult = + | { tag: "authenticated"; user: User } + | { tag: "needs_setup"; user: User } + | { tag: "unauthenticated" } + | { tag: "gateway_unavailable" } + | { tag: "config_error"; message: string }; + +export function assertNever(x: never): never { + throw new Error(`Unexpected auth result: ${JSON.stringify(x)}`); +} + +export function buildLoginUrl(returnPath: string): string { + return `/login?next=${encodeURIComponent(returnPath)}`; +} + +// ── Backend error response parsing ──────────────────────────────── + +const AUTH_ERROR_CODES = [ + "invalid_credentials", + "token_expired", + "token_invalid", + "user_not_found", + "email_already_exists", + "provider_not_found", + "not_authenticated", +] as const; + +export type AuthErrorCode = (typeof AUTH_ERROR_CODES)[number]; + +export interface AuthErrorResponse { + code: AuthErrorCode; + message: string; +} + +const authErrorSchema = z.object({ + code: z.enum(AUTH_ERROR_CODES), + message: z.string(), +}); + +export function parseAuthError(data: unknown): AuthErrorResponse { + // Try top-level {code, message} first + const parsed = authErrorSchema.safeParse(data); + if (parsed.success) return parsed.data; + + // Unwrap FastAPI's {detail: {code, message}} envelope + if (typeof data === "object" && data !== null && "detail" in data) { + const detail = (data as Record).detail; + const nested = authErrorSchema.safeParse(detail); + if (nested.success) return nested.data; + // Legacy string-detail responses + if (typeof detail === "string") { + return { code: "invalid_credentials", message: detail }; + } + } + + return { code: "invalid_credentials", message: "Authentication failed" }; +} diff --git a/frontend/src/env.js b/frontend/src/env.js index f00fa7a6c..ea90cac5d 100644 --- a/frontend/src/env.js +++ b/frontend/src/env.js @@ -7,12 +7,6 @@ export const env = createEnv({ * isn't built with invalid env vars. */ server: { - BETTER_AUTH_SECRET: - process.env.NODE_ENV === "production" - ? z.string() - : z.string().optional(), - BETTER_AUTH_GITHUB_CLIENT_ID: z.string().optional(), - BETTER_AUTH_GITHUB_CLIENT_SECRET: z.string().optional(), GITHUB_OAUTH_TOKEN: z.string().optional(), NODE_ENV: z .enum(["development", "test", "production"]) @@ -35,10 +29,6 @@ export const env = createEnv({ * middlewares) or client-side so we need to destruct manually. */ runtimeEnv: { - BETTER_AUTH_SECRET: process.env.BETTER_AUTH_SECRET, - BETTER_AUTH_GITHUB_CLIENT_ID: process.env.BETTER_AUTH_GITHUB_CLIENT_ID, - BETTER_AUTH_GITHUB_CLIENT_SECRET: - process.env.BETTER_AUTH_GITHUB_CLIENT_SECRET, NODE_ENV: process.env.NODE_ENV, NEXT_PUBLIC_BACKEND_BASE_URL: process.env.NEXT_PUBLIC_BACKEND_BASE_URL, diff --git a/frontend/src/server/better-auth/client.ts b/frontend/src/server/better-auth/client.ts deleted file mode 100644 index 493f84993..000000000 --- a/frontend/src/server/better-auth/client.ts +++ /dev/null @@ -1,5 +0,0 @@ -import { createAuthClient } from "better-auth/react"; - -export const authClient = createAuthClient(); - -export type Session = typeof authClient.$Infer.Session; diff --git a/frontend/src/server/better-auth/config.ts b/frontend/src/server/better-auth/config.ts deleted file mode 100644 index abf50faca..000000000 --- a/frontend/src/server/better-auth/config.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { betterAuth } from "better-auth"; - -export const auth = betterAuth({ - emailAndPassword: { - enabled: true, - }, -}); - -export type Session = typeof auth.$Infer.Session; diff --git a/frontend/src/server/better-auth/index.ts b/frontend/src/server/better-auth/index.ts deleted file mode 100644 index d705e873e..000000000 --- a/frontend/src/server/better-auth/index.ts +++ /dev/null @@ -1 +0,0 @@ -export { auth } from "./config"; diff --git a/frontend/src/server/better-auth/server.ts b/frontend/src/server/better-auth/server.ts deleted file mode 100644 index 064cd349c..000000000 --- a/frontend/src/server/better-auth/server.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { headers } from "next/headers"; -import { cache } from "react"; - -import { auth } from "."; - -export const getSession = cache(async () => - auth.api.getSession({ headers: await headers() }), -);