"""CSRF protection middleware and helpers for cookie-based auth flows.""" import secrets from collections.abc import Callable from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse from starlette.types import ASGIApp CSRF_COOKIE_NAME = "csrf_token" CSRF_HEADER_NAME = "X-CSRF-Token" CSRF_TOKEN_LENGTH = 64 # bytes def is_secure_request(request: Request) -> bool: """Detect whether the original client request was made over HTTPS.""" return request.headers.get("x-forwarded-proto", request.url.scheme) == "https" def generate_csrf_token() -> str: """Generate a secure random CSRF token.""" return secrets.token_urlsafe(CSRF_TOKEN_LENGTH) def should_check_csrf(request: Request) -> bool: """Determine if a request needs CSRF validation.""" if request.method not in ("POST", "PUT", "DELETE", "PATCH"): return False path = request.url.path.rstrip("/") if path == "/api/v1/auth/me": return False return True _AUTH_EXEMPT_PATHS: frozenset[str] = frozenset( { "/api/v1/auth/login/local", "/api/v1/auth/logout", "/api/v1/auth/register", "/api/v1/auth/initialize", } ) def is_auth_endpoint(request: Request) -> bool: """Check if the request is to an auth endpoint.""" return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS class CSRFMiddleware(BaseHTTPMiddleware): """Implement CSRF protection using the double-submit cookie pattern.""" def __init__(self, app: ASGIApp) -> None: super().__init__(app) async def dispatch(self, request: Request, call_next: Callable) -> Response: _is_auth = is_auth_endpoint(request) if should_check_csrf(request) and not _is_auth: cookie_token = request.cookies.get(CSRF_COOKIE_NAME) header_token = request.headers.get(CSRF_HEADER_NAME) if not cookie_token or not header_token: return JSONResponse( status_code=403, content={"detail": "CSRF token missing. Include X-CSRF-Token header."}, ) if not secrets.compare_digest(cookie_token, header_token): return JSONResponse( status_code=403, content={"detail": "CSRF token mismatch."}, ) response = await call_next(request) if _is_auth and request.method == "POST": csrf_token = generate_csrf_token() response.set_cookie( key=CSRF_COOKIE_NAME, value=csrf_token, httponly=False, secure=is_secure_request(request), samesite="strict", ) return response def get_csrf_token(request: Request) -> str | None: """Get the CSRF token from the current request's cookies.""" return request.cookies.get(CSRF_COOKIE_NAME) __all__ = [ "CSRF_COOKIE_NAME", "CSRF_HEADER_NAME", "CSRFMiddleware", "generate_csrf_token", "get_csrf_token", "is_auth_endpoint", "is_secure_request", "should_check_csrf", ]