fix(security): harden auth system and fix run journal logic bug (#2593)

* fix(security): harden auth system and fix run journal logic bug

  - Fix inverted condition in RunJournal.on_chat_model_start that prevented
    first human message capture (not messages → messages)
  - Pre-hash passwords with SHA-256 before bcrypt to avoid silent 72-byte
    truncation vulnerability
  - Move load_dotenv() from module scope into get_auth_config() to prevent
    import-time os.environ mutation breaking test isolation
  - Return generic 'Invalid token' instead of exposing specific error
    variants (expired, malformed, invalid_signature) to clients
  - Make @require_auth independently enforce 401 instead of silently
    passing through when AuthMiddleware is absent
  - Rate-limit /setup-status endpoint with per-IP cooldown to mitigate
    initialization-state information leak
  - Document in-process rate limiter limitation for multi-worker deployments

* fix(security): return 429+Retry-After on setup-status rate limit, bound cooldown dict

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/070d0be8-99a5-46c8-85bb-6b81b5284021

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* fix(security): add versioned password hashes with auto-migration on login

  The SHA-256 pre-hash change silently broke verification for any existing
  bcrypt-only password hashes. Introduce a $dfv<N>$ prefix scheme so hashes
  are self-describing:

  - v2 (current): bcrypt(b64(sha256(password))) with $dfv2$ prefix
  - v1 (legacy): plain bcrypt, prefixed $dfv1$ or bare (no prefix)

  verify_password auto-detects the version and falls back to v1 for older
  hashes. LocalAuthProvider.authenticate() now rehashes legacy hashes to v2
  on successful login via needs_rehash(), so existing users upgrade
  transparently without a dedicated migration step.

* fix(auth): harden verify_password, best-effort rehash, update require_auth docstring, downgrade journal logging

- password.py: wrap bcrypt.checkpw in try/except → return False for malformed/corrupt hashes instead of crashing
- local_provider.py: wrap auto-rehash update_user() in try/except so transient DB errors don't fail valid logins
- authz.py: update require_auth docstring to reflect independent 401 enforcement
- journal.py: downgrade on_chat_model_start from INFO to DEBUG, log only metadata (batch_count, message_counts) instead of full serialized/messages content

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/48c5cf31-a4ab-418a-982a-6343c37bb299

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* fix(auth): address code review - narrow ValueError catch, add rehash warning log, rename num_batches

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/48c5cf31-a4ab-418a-982a-6343c37bb299

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
This commit is contained in:
Willem Jiang 2026-04-28 11:34:07 +08:00 committed by GitHub
parent b8bc4826d8
commit 4e4e4f92a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 245 additions and 22 deletions

View File

@ -4,11 +4,8 @@ import logging
import os
import secrets
from dotenv import load_dotenv
from pydantic import BaseModel, Field
load_dotenv()
logger = logging.getLogger(__name__)
@ -37,6 +34,9 @@ def get_auth_config() -> AuthConfig:
"""Get the global AuthConfig instance. Parses from env on first call."""
global _auth_config
if _auth_config is None:
from dotenv import load_dotenv
load_dotenv()
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
if not jwt_secret:
jwt_secret = secrets.token_urlsafe(32)

View File

@ -1,10 +1,14 @@
"""Local email/password authentication provider."""
import logging
from app.gateway.auth.models import User
from app.gateway.auth.password import hash_password_async, verify_password_async
from app.gateway.auth.password import hash_password_async, needs_rehash, verify_password_async
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
logger = logging.getLogger(__name__)
class LocalAuthProvider(AuthProvider):
"""Email/password authentication provider using local database."""
@ -43,6 +47,15 @@ class LocalAuthProvider(AuthProvider):
if not await verify_password_async(password, user.password_hash):
return None
if needs_rehash(user.password_hash):
try:
user.password_hash = await hash_password_async(password)
await self._repo.update_user(user)
except Exception:
# Rehash is an opportunistic upgrade; a transient DB error must not
# prevent an otherwise-valid login from succeeding.
logger.warning("Failed to rehash password for user %s; login will still succeed", user.email, exc_info=True)
return user
async def get_user(self, user_id: str) -> User | None:

View File

@ -1,18 +1,66 @@
"""Password hashing utilities using bcrypt directly."""
"""Password hashing utilities with versioned hash format.
Hash format: ``$dfv<N>$<bcrypt_hash>`` where ``<N>`` is the version.
- **v1** (legacy): ``bcrypt(password)`` plain bcrypt, susceptible to
72-byte silent truncation.
- **v2** (current): ``bcrypt(b64(sha256(password)))`` SHA-256 pre-hash
avoids the 72-byte truncation limit so the full password contributes
to the hash.
Verification auto-detects the version and falls back to v1 for hashes
without a prefix, so existing deployments upgrade transparently on next
login.
"""
import asyncio
import base64
import hashlib
import bcrypt
_CURRENT_VERSION = 2
_PREFIX_V2 = "$dfv2$"
_PREFIX_V1 = "$dfv1$"
def _pre_hash_v2(password: str) -> bytes:
"""SHA-256 pre-hash to bypass bcrypt's 72-byte limit."""
return base64.b64encode(hashlib.sha256(password.encode("utf-8")).digest())
def hash_password(password: str) -> str:
"""Hash a password using bcrypt."""
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
"""Hash a password (current version: v2 — SHA-256 + bcrypt)."""
raw = bcrypt.hashpw(_pre_hash_v2(password), bcrypt.gensalt()).decode("utf-8")
return f"{_PREFIX_V2}{raw}"
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
"""Verify a password, auto-detecting the hash version.
Accepts v2 (``$dfv2$``), v1 (``$dfv1$``), and bare bcrypt hashes
(treated as v1 for backward compatibility with pre-versioning data).
"""
try:
if hashed_password.startswith(_PREFIX_V2):
bcrypt_hash = hashed_password[len(_PREFIX_V2) :]
return bcrypt.checkpw(_pre_hash_v2(plain_password), bcrypt_hash.encode("utf-8"))
if hashed_password.startswith(_PREFIX_V1):
bcrypt_hash = hashed_password[len(_PREFIX_V1) :]
else:
bcrypt_hash = hashed_password
return bcrypt.checkpw(plain_password.encode("utf-8"), bcrypt_hash.encode("utf-8"))
except ValueError:
# bcrypt raises ValueError for malformed or corrupt hashes (e.g., invalid salt).
# Fail closed rather than crashing the request.
return False
def needs_rehash(hashed_password: str) -> bool:
"""Return True if the hash uses an older version and should be rehashed."""
return not hashed_password.startswith(_PREFIX_V2)
async def hash_password_async(password: str) -> str:

View File

@ -145,7 +145,11 @@ async def _authenticate(request: Request) -> AuthContext:
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
"""Decorator that authenticates the request and sets AuthContext.
"""Decorator that authenticates the request and enforces authentication.
Independently raises HTTP 401 for unauthenticated requests, regardless of
whether ``AuthMiddleware`` is present in the ASGI stack. Sets the resolved
``AuthContext`` on ``request.state.auth`` for downstream handlers.
Must be placed ABOVE other decorators (executes after them).
@ -158,7 +162,8 @@ def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
...
Raises:
ValueError: If 'request' parameter is missing
HTTPException: 401 if the request is unauthenticated.
ValueError: If 'request' parameter is missing.
"""
@functools.wraps(func)
@ -181,6 +186,9 @@ def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
auth_context = await _authenticate(request)
request.state.auth = auth_context
if not auth_context.is_authenticated:
raise HTTPException(status_code=401, detail="Authentication required")
return await func(*args, **kwargs)
return wrapper

View File

@ -73,7 +73,7 @@ async def authenticate(request):
if isinstance(payload, TokenError):
raise Auth.exceptions.HTTPException(
status_code=401,
detail=f"Token error: {payload.value}",
detail="Invalid token",
)
user = await get_local_provider().get_user(payload.sub)

View File

@ -146,7 +146,13 @@ def _set_session_cookie(response: Response, token: str, request: Request) -> Non
# ── Rate Limiting ────────────────────────────────────────────────────────
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
# In-process dict — not shared across workers.
#
# **Limitation**: with multi-worker deployments (e.g., gunicorn -w N), each
# worker maintains its own lockout table, so an attacker effectively gets
# N × _MAX_LOGIN_ATTEMPTS guesses before being locked out everywhere. For
# production multi-worker setups, replace this with a shared store (Redis,
# database-backed counter) to enforce a true per-IP limit.
_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300 # 5 minutes
@ -376,9 +382,37 @@ async def get_me(request: Request):
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
_SETUP_STATUS_COOLDOWN: dict[str, float] = {}
_SETUP_STATUS_COOLDOWN_SECONDS = 60
_MAX_TRACKED_SETUP_STATUS_IPS = 10000
@router.get("/setup-status")
async def setup_status():
async def setup_status(request: Request):
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
client_ip = _get_client_ip(request)
now = time.time()
last_check = _SETUP_STATUS_COOLDOWN.get(client_ip, 0)
elapsed = now - last_check
if elapsed < _SETUP_STATUS_COOLDOWN_SECONDS:
retry_after = max(1, int(_SETUP_STATUS_COOLDOWN_SECONDS - elapsed))
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Setup status check is rate limited",
headers={"Retry-After": str(retry_after)},
)
# Evict stale entries when dict grows too large to bound memory usage.
if len(_SETUP_STATUS_COOLDOWN) >= _MAX_TRACKED_SETUP_STATUS_IPS:
cutoff = now - _SETUP_STATUS_COOLDOWN_SECONDS
stale = [k for k, t in _SETUP_STATUS_COOLDOWN.items() if t < cutoff]
for k in stale:
del _SETUP_STATUS_COOLDOWN[k]
# If still too large after evicting expired entries, remove oldest half.
if len(_SETUP_STATUS_COOLDOWN) >= _MAX_TRACKED_SETUP_STATUS_IPS:
by_time = sorted(_SETUP_STATUS_COOLDOWN.items(), key=lambda kv: kv[1])
for k, _ in by_time[: len(by_time) // 2]:
del _SETUP_STATUS_COOLDOWN[k]
_SETUP_STATUS_COOLDOWN[client_ip] = now
admin_count = await get_local_provider().count_admin_users()
return {"needs_setup": admin_count == 0}

View File

@ -138,10 +138,16 @@ class RunJournal(BaseCallbackHandler):
# Mark this run_id as seen so on_llm_end knows not to increment again.
self._cached_prompts[rid] = []
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
logger.debug(
"on_chat_model_start %s: tags=%s num_batches=%d message_counts=%s",
run_id,
tags,
len(messages),
[len(batch) for batch in messages],
)
# Capture the first human message sent to any LLM in this run.
if not self._first_human_msg and not messages:
if not self._first_human_msg and messages:
for batch in messages.reversed():
for m in batch.reversed():
if isinstance(m, HumanMessage) and m.name != "summary":

View File

@ -4,12 +4,14 @@ from datetime import timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import bcrypt
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
from app.gateway.auth.models import User
from app.gateway.auth.password import needs_rehash
from app.gateway.authz import (
AuthContext,
Permissions,
@ -26,6 +28,7 @@ def test_hash_password_and_verify():
password = "s3cr3tP@ssw0rd!"
hashed = hash_password(password)
assert hashed != password
assert hashed.startswith("$dfv2$")
assert verify_password(password, hashed) is True
assert verify_password("wrongpassword", hashed) is False
@ -47,6 +50,47 @@ def test_verify_password_rejects_empty():
assert verify_password("", hashed) is False
def test_hash_produces_v2_prefix():
"""hash_password output starts with $dfv2$."""
hashed = hash_password("anypassword123")
assert hashed.startswith("$dfv2$")
def test_verify_v1_prefixed_hash():
"""verify_password handles $dfv1$ prefixed hashes (plain bcrypt)."""
password = "legacyP@ssw0rd"
raw_bcrypt = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
v1_hash = f"$dfv1${raw_bcrypt}"
assert verify_password(password, v1_hash) is True
assert verify_password("wrong", v1_hash) is False
def test_verify_bare_bcrypt_hash():
"""verify_password handles bare bcrypt hashes (no prefix) as v1."""
password = "oldstyleP@ss"
raw_bcrypt = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
assert verify_password(password, raw_bcrypt) is True
assert verify_password("wrong", raw_bcrypt) is False
def test_needs_rehash_returns_false_for_v2():
"""v2 hashes do not need rehashing."""
hashed = hash_password("something")
assert needs_rehash(hashed) is False
def test_needs_rehash_returns_true_for_v1():
"""v1-prefixed hashes need rehashing."""
raw = bcrypt.hashpw(b"pw", bcrypt.gensalt()).decode("utf-8")
assert needs_rehash(f"$dfv1${raw}") is True
def test_needs_rehash_returns_true_for_bare_bcrypt():
"""Bare bcrypt hashes (no prefix) need rehashing."""
raw = bcrypt.hashpw(b"pw", bcrypt.gensalt()).decode("utf-8")
assert needs_rehash(raw) is True
# ── JWT ─────────────────────────────────────────────────────────────────────
@ -166,7 +210,7 @@ def test_get_auth_context_set():
def test_require_auth_sets_auth_context():
"""require_auth sets auth context on request from cookie."""
"""require_auth rejects unauthenticated requests with 401."""
from fastapi import Request
app = FastAPI()
@ -178,10 +222,9 @@ def test_require_auth_sets_auth_context():
return {"authenticated": ctx.is_authenticated}
with TestClient(app) as client:
# No cookie → anonymous
# No cookie → 401 (require_auth independently enforces authentication)
response = client.get("/test")
assert response.status_code == 200
assert response.json()["authenticated"] is False
assert response.status_code == 401
def test_require_auth_requires_request_param():
@ -652,3 +695,57 @@ def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
# Cleanup
config_module._auth_config = None
# ── Auto-rehash on login ──────────────────────────────────────────────────
def test_authenticate_auto_rehashes_legacy_hash():
"""authenticate() upgrades a bare bcrypt hash to v2 on successful login."""
import asyncio
from app.gateway.auth.local_provider import LocalAuthProvider
password = "rehashTest123"
user = User(
id=uuid4(),
email="rehash@test.com",
password_hash=bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8"),
)
mock_repo = MagicMock()
mock_repo.get_user_by_email = AsyncMock(return_value=user)
mock_repo.update_user = AsyncMock(return_value=user)
provider = LocalAuthProvider(mock_repo)
result = asyncio.run(provider.authenticate({"email": "rehash@test.com", "password": password}))
assert result is not None
assert result.password_hash.startswith("$dfv2$")
mock_repo.update_user.assert_called_once()
def test_authenticate_skips_rehash_for_v2_hash():
"""authenticate() does NOT rehash when the stored hash is already v2."""
import asyncio
from app.gateway.auth.local_provider import LocalAuthProvider
password = "alreadyv2Pass!"
user = User(
id=uuid4(),
email="v2@test.com",
password_hash=hash_password(password),
)
mock_repo = MagicMock()
mock_repo.get_user_by_email = AsyncMock(return_value=user)
mock_repo.update_user = AsyncMock(return_value=user)
provider = LocalAuthProvider(mock_repo)
result = asyncio.run(provider.authenticate({"email": "v2@test.com", "password": password}))
assert result is not None
mock_repo.update_user.assert_not_called()

View File

@ -22,6 +22,7 @@ _TEST_SECRET = "test-secret-key-initialize-admin-min-32"
def _setup_auth(tmp_path):
"""Fresh SQLite engine + auth config per test."""
from app.gateway import deps
from app.gateway.routers.auth import _SETUP_STATUS_COOLDOWN
from deerflow.persistence.engine import close_engine, init_engine
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
@ -29,11 +30,13 @@ def _setup_auth(tmp_path):
asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
deps._cached_local_provider = None
deps._cached_repo = None
_SETUP_STATUS_COOLDOWN.clear()
try:
yield
finally:
deps._cached_local_provider = None
deps._cached_repo = None
_SETUP_STATUS_COOLDOWN.clear()
asyncio.run(close_engine())
@ -163,3 +166,17 @@ def test_setup_status_false_when_only_regular_user_exists(client):
resp = client.get("/api/v1/auth/setup-status")
assert resp.status_code == 200
assert resp.json()["needs_setup"] is True
def test_setup_status_rate_limited_on_second_call(client):
"""Second /setup-status call within the cooldown window returns 429 with Retry-After."""
# First call succeeds.
resp1 = client.get("/api/v1/auth/setup-status")
assert resp1.status_code == 200
# Immediate second call is rate-limited.
resp2 = client.get("/api/v1/auth/setup-status")
assert resp2.status_code == 429
assert "Retry-After" in resp2.headers
retry_after = int(resp2.headers["Retry-After"])
assert 1 <= retry_after <= 60

View File

@ -63,7 +63,7 @@ def test_invalid_jwt_raises_401():
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": "garbage"})))
assert exc.value.status_code == 401
assert "Token error" in str(exc.value.detail)
assert "Invalid token" in str(exc.value.detail)
def test_expired_jwt_raises_401():
@ -295,7 +295,7 @@ def test_csrf_post_matching_token_proceeds_to_jwt():
)
# Past CSRF, rejected by JWT decode
assert exc.value.status_code == 401
assert "Token error" in str(exc.value.detail)
assert "Invalid token" in str(exc.value.detail)
def test_csrf_put_requires_token():