mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-10 02:38:26 +00:00
[security] fix(auth): reject cross-site auth POSTs (#2740)
* fix(security): reject cross-site auth posts * fix(auth): align secure cookie proxy scheme handling --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
parent
1336872b15
commit
2b0e62f679
@ -4,8 +4,10 @@ Per RFC-001:
|
||||
State-changing operations require CSRF protection.
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
@ -19,7 +21,7 @@ CSRF_TOKEN_LENGTH = 64 # bytes
|
||||
|
||||
def is_secure_request(request: Request) -> bool:
    """Detect whether the original client request was made over HTTPS.

    The scheme is resolved by ``_request_scheme``, which looks at the RFC 7239
    ``Forwarded`` header, then the first ``X-Forwarded-Proto`` value, and
    finally the scheme of the request URL itself. Comparing the raw
    ``X-Forwarded-Proto`` header directly would misclassify multi-hop values
    such as ``"https, http"`` and would ignore ``Forwarded`` entirely.

    Args:
        request: The incoming request.

    Returns:
        True when the resolved scheme is exactly ``"https"``.
    """
    return _request_scheme(request) == "https"
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
@ -61,6 +63,109 @@ def is_auth_endpoint(request: Request) -> bool:
|
||||
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
||||
|
||||
|
||||
def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
|
||||
"""Return normalized host[:port], omitting default ports."""
|
||||
host = hostname.lower()
|
||||
if ":" in host and not host.startswith("["):
|
||||
host = f"[{host}]"
|
||||
|
||||
if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
|
||||
return host
|
||||
return f"{host}:{port}"
|
||||
|
||||
|
||||
def _normalize_origin(origin: str) -> str | None:
|
||||
"""Return a normalized scheme://host[:port] origin, or None for invalid input."""
|
||||
try:
|
||||
parsed = urlsplit(origin.strip())
|
||||
port = parsed.port
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
scheme = parsed.scheme.lower()
|
||||
if scheme not in {"http", "https"} or not parsed.hostname:
|
||||
return None
|
||||
|
||||
# Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
|
||||
if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment:
|
||||
return None
|
||||
|
||||
return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}"
|
||||
|
||||
|
||||
def _configured_cors_origins() -> set[str]:
|
||||
"""Return explicit configured browser origins that may call auth routes."""
|
||||
origins = set()
|
||||
for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","):
|
||||
origin = raw_origin.strip()
|
||||
if not origin or origin == "*":
|
||||
continue
|
||||
normalized = _normalize_origin(origin)
|
||||
if normalized:
|
||||
origins.add(normalized)
|
||||
return origins
|
||||
|
||||
|
||||
def _first_header_value(value: str | None) -> str | None:
|
||||
"""Return the first value from a comma-separated proxy header."""
|
||||
if not value:
|
||||
return None
|
||||
first = value.split(",", 1)[0].strip()
|
||||
return first or None
|
||||
|
||||
|
||||
def _forwarded_param(request: Request, name: str) -> str | None:
    """Look up one parameter in the first entry of the ``Forwarded`` header.

    Only the first comma-separated entry is inspected. It is split on ``;``
    into ``key=value`` pairs; keys are compared case-insensitively against
    ``name``, so callers should pass ``name`` in lowercase.

    Args:
        request: The incoming request.
        name: Lowercase parameter name, e.g. ``"proto"`` or ``"host"``.

    Returns:
        The first matching value with surrounding whitespace and double quotes
        removed, or None when the header or parameter is absent or the value is
        empty.
    """
    first_entry = _first_header_value(request.headers.get("forwarded"))
    if not first_entry:
        return None

    for pair in first_entry.split(";"):
        key, separator, raw_value = pair.strip().partition("=")
        if not separator or key.lower() != name:
            continue
        # The first matching key decides the result, even if its value is empty.
        unquoted = raw_value.strip().strip('"')
        return unquoted or None
    return None
|
||||
|
||||
|
||||
def _request_scheme(request: Request) -> str:
    """Resolve the original request scheme, preferring proxy headers.

    Sources are tried in order: the ``proto`` parameter of the ``Forwarded``
    header, the first ``X-Forwarded-Proto`` value, then the scheme of the
    request URL. The headers are read as-is; this function does not itself
    check that they were set by a trusted proxy.

    Returns:
        The chosen scheme, lowercased.
    """
    scheme = _forwarded_param(request, "proto")
    if not scheme:
        scheme = _first_header_value(request.headers.get("x-forwarded-proto"))
    if not scheme:
        scheme = request.url.scheme
    return scheme.lower()
|
||||
|
||||
|
||||
def _request_origin(request: Request) -> str | None:
    """Build the origin of the URL the browser is targeting.

    The scheme comes from ``_request_scheme``. The host is the first available
    of: the ``Forwarded`` ``host`` parameter, the first ``X-Forwarded-Host``
    value, the ``Host`` header, and the request URL's netloc. When
    ``X-Forwarded-Port`` is present and the chosen host has no port of its own,
    that port is appended.

    Returns:
        The normalized origin, or None if the assembled value is not a valid
        origin.
    """
    headers = request.headers
    scheme = _request_scheme(request)
    host = (
        _forwarded_param(request, "host")
        or _first_header_value(headers.get("x-forwarded-host"))
        or headers.get("host")
        or request.url.netloc
    )

    port_override = _first_header_value(headers.get("x-forwarded-port"))
    if port_override:
        # Only look after the closing bracket so that the colons inside an
        # IPv6 literal are not mistaken for a port separator.
        after_bracket = host.rsplit("]", 1)[-1]
        if ":" not in after_bracket:
            host = f"{host}:{port_override}"

    return _normalize_origin(f"{scheme}://{host}")
|
||||
|
||||
|
||||
def is_allowed_auth_origin(request: Request) -> bool:
    """Decide whether an auth POST may proceed based on its Origin header.

    Login/register/initialize skip the double-submit token check because a
    first-time browser client has no CSRF token yet. Those routes still create
    a session cookie, so a browser request carrying a foreign Origin has to be
    refused to prevent login CSRF / session fixation.

    Returns:
        True when no Origin header is sent (non-browser clients such as curl
        or mobile integrations), when the Origin is one of the explicitly
        configured origins, or when it equals the origin the request targets.
        False when the Origin header is malformed or matches neither.
    """
    raw_origin = request.headers.get("origin")
    if not raw_origin:
        return True

    candidate = _normalize_origin(raw_origin)
    if candidate is None:
        return False

    target_origin = _request_origin(request)
    if candidate in _configured_cors_origins():
        return True
    return target_origin is not None and candidate == target_origin
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
||||
|
||||
@ -70,6 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
_is_auth = is_auth_endpoint(request)
|
||||
|
||||
if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Cross-site auth request denied."},
|
||||
)
|
||||
|
||||
if should_check_csrf(request) and not _is_auth:
|
||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||
header_token = request.headers.get(CSRF_HEADER_NAME)
|
||||
|
||||
219
backend/tests/test_csrf_middleware.py
Normal file
219
backend/tests/test_csrf_middleware.py
Normal file
@ -0,0 +1,219 @@
|
||||
"""Tests for CSRF middleware."""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||
|
||||
|
||||
def _make_app() -> FastAPI:
    """Create a minimal FastAPI app wrapped in ``CSRFMiddleware`` for the tests.

    It exposes two auth POST routes (local login and register) and one
    non-auth POST route; each handler returns ``{"ok": True}``.
    """
    application = FastAPI()
    application.add_middleware(CSRFMiddleware)

    # Auth route: local login.
    @application.post("/api/v1/auth/login/local")
    async def login_local():
        return {"ok": True}

    # Auth route: registration.
    @application.post("/api/v1/auth/register")
    async def register():
        return {"ok": True}

    # Ordinary state-changing route, not an auth endpoint.
    @application.post("/api/threads/abc/runs/stream")
    async def protected_mutation():
        return {"ok": True}

    return application
|
||||
|
||||
|
||||
def test_auth_post_rejects_cross_origin_browser_request():
    """A login POST carrying a foreign Origin header is answered with 403.

    Login/register endpoints skip the double-submit token on purpose, because
    a first-time caller has no token yet. They still establish an auth
    session, so a hostile cross-site form POST has to be refused to avoid
    login CSRF / session fixation.
    """
    hostile_headers = {"Origin": "https://evil.example"}
    browser = TestClient(_make_app(), base_url="https://deerflow.example")

    reply = browser.post("/api/v1/auth/login/local", headers=hostile_headers)

    assert reply.status_code == 403
    assert reply.json()["detail"] == "Cross-site auth request denied."
|
||||
|
||||
|
||||
def test_auth_post_allows_same_origin_browser_request():
    """A login POST whose Origin equals the target origin succeeds and sets a CSRF cookie."""
    same_origin_headers = {"Origin": "https://deerflow.example"}
    browser = TestClient(_make_app(), base_url="https://deerflow.example")

    reply = browser.post("/api/v1/auth/login/local", headers=same_origin_headers)

    assert reply.status_code == 200
    assert reply.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_rejects_malformed_origin_with_path():
    """An Origin value that carries a path is rejected and no CSRF cookie is issued."""
    url_shaped_origin = {"Origin": "https://deerflow.example/path"}
    browser = TestClient(_make_app(), base_url="https://deerflow.example")

    reply = browser.post("/api/v1/auth/login/local", headers=url_shaped_origin)

    assert reply.status_code == 403
    assert reply.json()["detail"] == "Cross-site auth request denied."
    assert reply.cookies.get("csrf_token") is None
|
||||
|
||||
|
||||
def test_auth_post_rejects_malformed_origin_with_invalid_port():
    """An Origin value with a non-numeric port is rejected and no CSRF cookie is issued."""
    bad_port_origin = {"Origin": "https://deerflow.example:bad"}
    browser = TestClient(_make_app(), base_url="https://deerflow.example")

    reply = browser.post("/api/v1/auth/login/local", headers=bad_port_origin)

    assert reply.status_code == 403
    assert reply.json()["detail"] == "Cross-site auth request denied."
    assert reply.cookies.get("csrf_token") is None
|
||||
|
||||
|
||||
def test_auth_post_allows_same_origin_default_port_equivalence():
    """An Origin with an explicit ``:443`` matches the same https host without a port."""
    explicit_default_port = {"Origin": "https://deerflow.example:443"}
    browser = TestClient(_make_app(), base_url="https://deerflow.example")

    reply = browser.post("/api/v1/auth/login/local", headers=explicit_default_port)

    assert reply.status_code == 200
    assert reply.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_allows_forwarded_same_origin():
    """``X-Forwarded-Proto``/``X-Forwarded-Host`` define the target origin behind a proxy.

    The client talks to an internal http address, while the forwarded headers
    describe the public https host that the Origin header matches. The
    ``X-Forwarded-Host`` value is a comma-separated list whose first item is
    the public host.
    """
    proxied_headers = {
        "Origin": "https://deerflow.example",
        "X-Forwarded-Proto": "https",
        "X-Forwarded-Host": "deerflow.example, internal:8000",
    }
    browser = TestClient(_make_app(), base_url="http://internal:8000")

    reply = browser.post("/api/v1/auth/login/local", headers=proxied_headers)

    assert reply.status_code == 200
    assert reply.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_allows_rfc_forwarded_same_origin():
    """The RFC 7239 ``Forwarded`` header defines the target origin and scheme.

    The request reaches an internal http address; ``Forwarded`` declares https
    and the public host, so the Origin matches and the issued cookie is
    expected to be marked secure.
    """
    proxied_headers = {
        "Origin": "https://deerflow.example",
        "Forwarded": "proto=https;host=deerflow.example",
    }
    browser = TestClient(_make_app(), base_url="http://internal:8000")

    reply = browser.post("/api/v1/auth/login/local", headers=proxied_headers)

    assert reply.status_code == 200
    assert reply.cookies.get("csrf_token")
    assert "secure" in reply.headers["set-cookie"].lower()
|
||||
|
||||
|
||||
def test_auth_post_allows_explicit_configured_origin(monkeypatch):
    """An origin listed in ``GATEWAY_CORS_ORIGINS`` may POST to an auth route on another host."""
    monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "https://app.example")
    configured_origin = {"Origin": "https://app.example"}
    browser = TestClient(_make_app(), base_url="https://api.example")

    reply = browser.post("/api/v1/auth/register", headers=configured_origin)

    assert reply.status_code == 200
    assert reply.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_auth_post_does_not_treat_wildcard_cors_as_allowed_origin(monkeypatch):
    """A ``*`` entry in ``GATEWAY_CORS_ORIGINS`` does not whitelist a foreign origin."""
    monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "*")
    hostile_headers = {"Origin": "https://evil.example"}
    browser = TestClient(_make_app(), base_url="https://api.example")

    reply = browser.post("/api/v1/auth/login/local", headers=hostile_headers)

    assert reply.status_code == 403
    assert reply.json()["detail"] == "Cross-site auth request denied."
|
||||
|
||||
|
||||
def test_auth_post_sets_strict_samesite_csrf_cookie():
    """A successful https login sets ``csrf_token`` with SameSite=Strict and Secure."""
    same_origin_headers = {"Origin": "https://deerflow.example"}
    browser = TestClient(_make_app(), base_url="https://deerflow.example")

    reply = browser.post("/api/v1/auth/login/local", headers=same_origin_headers)

    assert reply.status_code == 200
    # Compare case-insensitively: attribute casing in Set-Cookie may vary.
    cookie_header = reply.headers["set-cookie"].lower()
    assert "csrf_token=" in cookie_header
    assert "samesite=strict" in cookie_header
    assert "secure" in cookie_header
|
||||
|
||||
|
||||
def test_auth_post_without_origin_still_allows_non_browser_clients():
    """A login POST that sends no Origin header at all is still accepted."""
    api_client = TestClient(_make_app(), base_url="https://deerflow.example")

    reply = api_client.post("/api/v1/auth/login/local")

    assert reply.status_code == 200
    assert reply.cookies.get("csrf_token")
|
||||
|
||||
|
||||
def test_non_auth_mutation_still_requires_double_submit_token():
    """A same-origin POST to a non-auth route without a CSRF token is rejected."""
    same_origin_headers = {"Origin": "https://deerflow.example"}
    browser = TestClient(_make_app(), base_url="https://deerflow.example")

    reply = browser.post("/api/threads/abc/runs/stream", headers=same_origin_headers)

    assert reply.status_code == 403
    assert reply.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
|
||||
|
||||
|
||||
def test_non_auth_mutation_allows_valid_double_submit_token():
    """A non-auth POST passes when the ``csrf_token`` cookie equals ``X-CSRF-Token``."""
    shared_token = "known-token"
    browser = TestClient(_make_app(), base_url="https://deerflow.example")
    browser.cookies.set("csrf_token", shared_token)
    request_headers = {
        "Origin": "https://deerflow.example",
        "X-CSRF-Token": shared_token,
    }

    reply = browser.post("/api/threads/abc/runs/stream", headers=request_headers)

    assert reply.status_code == 200
|
||||
|
||||
|
||||
def test_non_auth_mutation_rejects_mismatched_double_submit_token():
    """A non-auth POST is rejected when the cookie token and header token differ."""
    browser = TestClient(_make_app(), base_url="https://deerflow.example")
    browser.cookies.set("csrf_token", "cookie-token")
    request_headers = {
        "Origin": "https://deerflow.example",
        "X-CSRF-Token": "header-token",
    }

    reply = browser.post("/api/threads/abc/runs/stream", headers=request_headers)

    assert reply.status_code == 403
    assert reply.json()["detail"] == "CSRF token mismatch."
|
||||
Loading…
x
Reference in New Issue
Block a user