From 2b0e62f679eb6437ed5b07de6e140bcdbaa4f681 Mon Sep 17 00:00:00 2001 From: Hinotobi Date: Thu, 7 May 2026 07:58:06 +0800 Subject: [PATCH] [security] fix(auth): reject cross-site auth POSTs (#2740) * fix(security): reject cross-site auth posts * fix(auth): align secure cookie proxy scheme handling --------- Co-authored-by: Willem Jiang --- backend/app/gateway/csrf_middleware.py | 113 ++++++++++++- backend/tests/test_csrf_middleware.py | 219 +++++++++++++++++++++++++ 2 files changed, 331 insertions(+), 1 deletion(-) create mode 100644 backend/tests/test_csrf_middleware.py diff --git a/backend/app/gateway/csrf_middleware.py b/backend/app/gateway/csrf_middleware.py index 4c9b0f36a..08e95be4b 100644 --- a/backend/app/gateway/csrf_middleware.py +++ b/backend/app/gateway/csrf_middleware.py @@ -4,8 +4,10 @@ Per RFC-001: State-changing operations require CSRF protection. """ +import os import secrets from collections.abc import Callable +from urllib.parse import urlsplit from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware @@ -19,7 +21,7 @@ CSRF_TOKEN_LENGTH = 64 # bytes def is_secure_request(request: Request) -> bool: """Detect whether the original client request was made over HTTPS.""" - return request.headers.get("x-forwarded-proto", request.url.scheme) == "https" + return _request_scheme(request) == "https" def generate_csrf_token() -> str: @@ -61,6 +63,109 @@ def is_auth_endpoint(request: Request) -> bool: return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS +def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str: + """Return normalized host[:port], omitting default ports.""" + host = hostname.lower() + if ":" in host and not host.startswith("["): + host = f"[{host}]" + + if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443): + return host + return f"{host}:{port}" + + +def _normalize_origin(origin: str) -> str | None: + """Return a normalized scheme://host[:port] origin, or None for invalid input.""" + try: + parsed = urlsplit(origin.strip()) + port = parsed.port + except ValueError: + return None + + scheme = parsed.scheme.lower() + if scheme not in {"http", "https"} or not parsed.hostname: + return None + + # Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values. + if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment: + return None + + return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}" + + +def _configured_cors_origins() -> set[str]: + """Return explicit configured browser origins that may call auth routes.""" + origins = set() + for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","): + origin = raw_origin.strip() + if not origin or origin == "*": + continue + normalized = _normalize_origin(origin) + if normalized: + origins.add(normalized) + return origins + + +def _first_header_value(value: str | None) -> str | None: + """Return the first value from a comma-separated proxy header.""" + if not value: + return None + first = value.split(",", 1)[0].strip() + return first or None + + +def _forwarded_param(request: Request, name: str) -> str | None: + """Extract a parameter from the first RFC 7239 Forwarded header entry.""" + forwarded = _first_header_value(request.headers.get("forwarded")) + if not forwarded: + return None + + for part in forwarded.split(";"): + key, sep, value = part.strip().partition("=") + if sep and key.lower() == name: + return value.strip().strip('"') or None + return None + + +def _request_scheme(request: Request) -> str: + """Resolve the original request scheme from trusted proxy headers.""" + scheme = _forwarded_param(request, "proto") or _first_header_value(request.headers.get("x-forwarded-proto")) or request.url.scheme + return scheme.lower() + + +def _request_origin(request: Request) -> str | None: + """Build the origin for the URL the browser is targeting.""" + scheme = _request_scheme(request) + host = _forwarded_param(request, "host") or _first_header_value(request.headers.get("x-forwarded-host")) or request.headers.get("host") or request.url.netloc + + forwarded_port = _first_header_value(request.headers.get("x-forwarded-port")) + if forwarded_port and ":" not in host.rsplit("]", 1)[-1]: + host = f"{host}:{forwarded_port}" + + return _normalize_origin(f"{scheme}://{host}") + + +def is_allowed_auth_origin(request: Request) -> bool: + """Allow auth POSTs only from the same origin or explicit configured origins. + + Login/register/initialize are exempt from the double-submit token because + first-time browser clients do not have a CSRF token yet. They still create + a session cookie, so browser requests with a hostile Origin header must be + rejected to prevent login CSRF / session fixation. Requests without Origin + are allowed for non-browser clients such as curl and mobile integrations. + """ + origin = request.headers.get("origin") + if not origin: + return True + + normalized_origin = _normalize_origin(origin) + if normalized_origin is None: + return False + + request_origin = _request_origin(request) + return normalized_origin in _configured_cors_origins() or (request_origin is not None and normalized_origin == request_origin) + + class CSRFMiddleware(BaseHTTPMiddleware): """Middleware that implements CSRF protection using Double Submit Cookie pattern.""" @@ -70,6 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next: Callable) -> Response: _is_auth = is_auth_endpoint(request) + if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request): + return JSONResponse( + status_code=403, + content={"detail": "Cross-site auth request denied."}, + ) + if should_check_csrf(request) and not _is_auth: cookie_token = request.cookies.get(CSRF_COOKIE_NAME) header_token = request.headers.get(CSRF_HEADER_NAME) diff --git a/backend/tests/test_csrf_middleware.py b/backend/tests/test_csrf_middleware.py new file mode 100644 index 000000000..247e24cda --- /dev/null +++ b/backend/tests/test_csrf_middleware.py @@ -0,0 +1,219 @@ +"""Tests for CSRF middleware.""" + +from fastapi import FastAPI +from starlette.testclient import TestClient + +from app.gateway.csrf_middleware import CSRFMiddleware + + +def _make_app() -> FastAPI: + app = FastAPI() + app.add_middleware(CSRFMiddleware) + + @app.post("/api/v1/auth/login/local") + async def login_local(): + return {"ok": True} + + @app.post("/api/v1/auth/register") + async def register(): + return {"ok": True} + + @app.post("/api/threads/abc/runs/stream") + async def protected_mutation(): + return {"ok": True} + + return app + + +def test_auth_post_rejects_cross_origin_browser_request(): + """CSRF-exempt auth routes must not accept hostile browser origins. + + Login/register endpoints intentionally skip the double-submit token because + first-time callers do not have a token yet. They still set an auth session, + so a hostile cross-site form POST must be rejected to avoid login CSRF / + session fixation. + """ + client = TestClient(_make_app(), base_url="https://deerflow.example") + + response = client.post( + "/api/v1/auth/login/local", + headers={"Origin": "https://evil.example"}, + ) + + assert response.status_code == 403 + assert response.json()["detail"] == "Cross-site auth request denied." + + +def test_auth_post_allows_same_origin_browser_request(): + client = TestClient(_make_app(), base_url="https://deerflow.example") + + response = client.post( + "/api/v1/auth/login/local", + headers={"Origin": "https://deerflow.example"}, + ) + + assert response.status_code == 200 + assert response.cookies.get("csrf_token") + + +def test_auth_post_rejects_malformed_origin_with_path(): + client = TestClient(_make_app(), base_url="https://deerflow.example") + + response = client.post( + "/api/v1/auth/login/local", + headers={"Origin": "https://deerflow.example/path"}, + ) + + assert response.status_code == 403 + assert response.json()["detail"] == "Cross-site auth request denied." + assert response.cookies.get("csrf_token") is None + + +def test_auth_post_rejects_malformed_origin_with_invalid_port(): + client = TestClient(_make_app(), base_url="https://deerflow.example") + + response = client.post( + "/api/v1/auth/login/local", + headers={"Origin": "https://deerflow.example:bad"}, + ) + + assert response.status_code == 403 + assert response.json()["detail"] == "Cross-site auth request denied." + assert response.cookies.get("csrf_token") is None + + +def test_auth_post_allows_same_origin_default_port_equivalence(): + client = TestClient(_make_app(), base_url="https://deerflow.example") + + response = client.post( + "/api/v1/auth/login/local", + headers={"Origin": "https://deerflow.example:443"}, + ) + + assert response.status_code == 200 + assert response.cookies.get("csrf_token") + + +def test_auth_post_allows_forwarded_same_origin(): + client = TestClient(_make_app(), base_url="http://internal:8000") + + response = client.post( + "/api/v1/auth/login/local", + headers={ + "Origin": "https://deerflow.example", + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "deerflow.example, internal:8000", + }, + ) + + assert response.status_code == 200 + assert response.cookies.get("csrf_token") + + +def test_auth_post_allows_rfc_forwarded_same_origin(): + client = TestClient(_make_app(), base_url="http://internal:8000") + + response = client.post( + "/api/v1/auth/login/local", + headers={ + "Origin": "https://deerflow.example", + "Forwarded": "proto=https;host=deerflow.example", + }, + ) + + assert response.status_code == 200 + assert response.cookies.get("csrf_token") + assert "secure" in response.headers["set-cookie"].lower() + + +def test_auth_post_allows_explicit_configured_origin(monkeypatch): + monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "https://app.example") + client = TestClient(_make_app(), base_url="https://api.example") + + response = client.post( + "/api/v1/auth/register", + headers={"Origin": "https://app.example"}, + ) + + assert response.status_code == 200 + assert response.cookies.get("csrf_token") + + +def test_auth_post_does_not_treat_wildcard_cors_as_allowed_origin(monkeypatch): + monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "*") + client = TestClient(_make_app(), base_url="https://api.example") + + response = client.post( + "/api/v1/auth/login/local", + headers={"Origin": "https://evil.example"}, + ) + + assert response.status_code == 403 + assert response.json()["detail"] == "Cross-site auth request denied." + + +def test_auth_post_sets_strict_samesite_csrf_cookie(): + client = TestClient(_make_app(), base_url="https://deerflow.example") + + response = client.post( + "/api/v1/auth/login/local", + headers={"Origin": "https://deerflow.example"}, + ) + + assert response.status_code == 200 + set_cookie = response.headers["set-cookie"].lower() + assert "csrf_token=" in set_cookie + assert "samesite=strict" in set_cookie + assert "secure" in set_cookie + + +def test_auth_post_without_origin_still_allows_non_browser_clients(): + client = TestClient(_make_app(), base_url="https://deerflow.example") + + response = client.post("/api/v1/auth/login/local") + + assert response.status_code == 200 + assert response.cookies.get("csrf_token") + + +def test_non_auth_mutation_still_requires_double_submit_token(): + client = TestClient(_make_app(), base_url="https://deerflow.example") + + response = client.post( + "/api/threads/abc/runs/stream", + headers={"Origin": "https://deerflow.example"}, + ) + + assert response.status_code == 403 + assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header." + + +def test_non_auth_mutation_allows_valid_double_submit_token(): + client = TestClient(_make_app(), base_url="https://deerflow.example") + client.cookies.set("csrf_token", "known-token") + + response = client.post( + "/api/threads/abc/runs/stream", + headers={ + "Origin": "https://deerflow.example", + "X-CSRF-Token": "known-token", + }, + ) + + assert response.status_code == 200 + + +def test_non_auth_mutation_rejects_mismatched_double_submit_token(): + client = TestClient(_make_app(), base_url="https://deerflow.example") + client.cookies.set("csrf_token", "cookie-token") + + response = client.post( + "/api/threads/abc/runs/stream", + headers={ + "Origin": "https://deerflow.example", + "X-CSRF-Token": "header-token", + }, + ) + + assert response.status_code == 403 + assert response.json()["detail"] == "CSRF token mismatch."