[security] fix(auth): reject cross-site auth POSTs (#2740)

* fix(security): reject cross-site auth posts

* fix(auth): align secure cookie proxy scheme handling

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Hinotobi 2026-05-07 07:58:06 +08:00 committed by GitHub
parent 1336872b15
commit 2b0e62f679
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 331 additions and 1 deletions

View File

@ -4,8 +4,10 @@ Per RFC-001:
State-changing operations require CSRF protection.
"""
import os
import secrets
from collections.abc import Callable
from urllib.parse import urlsplit
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
@ -19,7 +21,7 @@ CSRF_TOKEN_LENGTH = 64 # bytes
def is_secure_request(request: Request) -> bool:
"""Detect whether the original client request was made over HTTPS."""
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
return _request_scheme(request) == "https"
def generate_csrf_token() -> str:
@ -61,6 +63,109 @@ def is_auth_endpoint(request: Request) -> bool:
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
"""Return normalized host[:port], omitting default ports."""
host = hostname.lower()
if ":" in host and not host.startswith("["):
host = f"[{host}]"
if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
return host
return f"{host}:{port}"
def _normalize_origin(origin: str) -> str | None:
"""Return a normalized scheme://host[:port] origin, or None for invalid input."""
try:
parsed = urlsplit(origin.strip())
port = parsed.port
except ValueError:
return None
scheme = parsed.scheme.lower()
if scheme not in {"http", "https"} or not parsed.hostname:
return None
# Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment:
return None
return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}"
def _configured_cors_origins() -> set[str]:
"""Return explicit configured browser origins that may call auth routes."""
origins = set()
for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","):
origin = raw_origin.strip()
if not origin or origin == "*":
continue
normalized = _normalize_origin(origin)
if normalized:
origins.add(normalized)
return origins
def _first_header_value(value: str | None) -> str | None:
"""Return the first value from a comma-separated proxy header."""
if not value:
return None
first = value.split(",", 1)[0].strip()
return first or None
def _forwarded_param(request: Request, name: str) -> str | None:
"""Extract a parameter from the first RFC 7239 Forwarded header entry."""
forwarded = _first_header_value(request.headers.get("forwarded"))
if not forwarded:
return None
for part in forwarded.split(";"):
key, sep, value = part.strip().partition("=")
if sep and key.lower() == name:
return value.strip().strip('"') or None
return None
def _request_scheme(request: Request) -> str:
"""Resolve the original request scheme from trusted proxy headers."""
scheme = _forwarded_param(request, "proto") or _first_header_value(request.headers.get("x-forwarded-proto")) or request.url.scheme
return scheme.lower()
def _request_origin(request: Request) -> str | None:
"""Build the origin for the URL the browser is targeting."""
scheme = _request_scheme(request)
host = _forwarded_param(request, "host") or _first_header_value(request.headers.get("x-forwarded-host")) or request.headers.get("host") or request.url.netloc
forwarded_port = _first_header_value(request.headers.get("x-forwarded-port"))
if forwarded_port and ":" not in host.rsplit("]", 1)[-1]:
host = f"{host}:{forwarded_port}"
return _normalize_origin(f"{scheme}://{host}")
def is_allowed_auth_origin(request: Request) -> bool:
"""Allow auth POSTs only from the same origin or explicit configured origins.
Login/register/initialize are exempt from the double-submit token because
first-time browser clients do not have a CSRF token yet. They still create
a session cookie, so browser requests with a hostile Origin header must be
rejected to prevent login CSRF / session fixation. Requests without Origin
are allowed for non-browser clients such as curl and mobile integrations.
"""
origin = request.headers.get("origin")
if not origin:
return True
normalized_origin = _normalize_origin(origin)
if normalized_origin is None:
return False
request_origin = _request_origin(request)
return normalized_origin in _configured_cors_origins() or (request_origin is not None and normalized_origin == request_origin)
class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
@ -70,6 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request)
if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
return JSONResponse(
status_code=403,
content={"detail": "Cross-site auth request denied."},
)
if should_check_csrf(request) and not _is_auth:
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
header_token = request.headers.get(CSRF_HEADER_NAME)

View File

@ -0,0 +1,219 @@
"""Tests for CSRF middleware."""
from fastapi import FastAPI
from starlette.testclient import TestClient
from app.gateway.csrf_middleware import CSRFMiddleware
def _make_app() -> FastAPI:
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/api/v1/auth/login/local")
async def login_local():
return {"ok": True}
@app.post("/api/v1/auth/register")
async def register():
return {"ok": True}
@app.post("/api/threads/abc/runs/stream")
async def protected_mutation():
return {"ok": True}
return app
def test_auth_post_rejects_cross_origin_browser_request():
"""CSRF-exempt auth routes must not accept hostile browser origins.
Login/register endpoints intentionally skip the double-submit token because
first-time callers do not have a token yet. They still set an auth session,
so a hostile cross-site form POST must be rejected to avoid login CSRF /
session fixation.
"""
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
def test_auth_post_allows_same_origin_browser_request():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_rejects_malformed_origin_with_path():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example/path"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None
def test_auth_post_rejects_malformed_origin_with_invalid_port():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:bad"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None
def test_auth_post_allows_same_origin_default_port_equivalence():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:443"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "deerflow.example, internal:8000",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_rfc_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"Forwarded": "proto=https;host=deerflow.example",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
assert "secure" in response.headers["set-cookie"].lower()
def test_auth_post_allows_explicit_configured_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "https://app.example")
client = TestClient(_make_app(), base_url="https://api.example")
response = client.post(
"/api/v1/auth/register",
headers={"Origin": "https://app.example"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_does_not_treat_wildcard_cors_as_allowed_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "*")
client = TestClient(_make_app(), base_url="https://api.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
def test_auth_post_sets_strict_samesite_csrf_cookie():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 200
set_cookie = response.headers["set-cookie"].lower()
assert "csrf_token=" in set_cookie
assert "samesite=strict" in set_cookie
assert "secure" in set_cookie
def test_auth_post_without_origin_still_allows_non_browser_clients():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post("/api/v1/auth/login/local")
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_non_auth_mutation_still_requires_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/threads/abc/runs/stream",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
def test_non_auth_mutation_allows_valid_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "known-token")
response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "known-token",
},
)
assert response.status_code == 200
def test_non_auth_mutation_rejects_mismatched_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "cookie-token")
response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "header-token",
},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token mismatch."